From a193b7a7e094458dbc6781857ba18971de3b0870 Mon Sep 17 00:00:00 2001 From: Emily Veenhuis Date: Mon, 4 May 2026 13:11:42 -0400 Subject: [PATCH 1/8] Update random effects tests based on 2026-01-21 ADEPT weights --- .../tests/test_alignment_adm_component.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/align_system/tests/test_alignment_adm_component.py b/align_system/tests/test_alignment_adm_component.py index bf427581..72e255ec 100644 --- a/align_system/tests/test_alignment_adm_component.py +++ b/align_system/tests/test_alignment_adm_component.py @@ -5,7 +5,9 @@ from align_system.algorithms.alignment_adm_component import ( MedicalUrgencyAlignmentADMComponent, MedicalUrgencyAlignmentWeightedADMComponent, - RandomEffectsModelAlignmentADMComponent) + RandomEffectsModelAlignmentADMComponent, + MultinomialRandomEffectsModelAlignmentADMComponent, +) @pytest.mark.parametrize( ("alignment_fn_class"), @@ -902,9 +904,9 @@ class TestRandomEffectsModelAlignmentADMComponent: "Treat Patient A": {"medical": 0.947157191, "merit": 0.0}, "Treat Patient B": {"medical": 0.012495865, "merit": 1.0}, # Medical delta = 0.947157191-0.012495865 = 0.934661326 - # Z-scaled medical delta = (0.934661326 - 0.433409) / 0.308294 = 1.62589063037 + # Z-scaled medical delta = (0.934661326 - 0.428961) / 0.301250 = 1.67867328133 # Attribute score = 0.0 - # Z-scaled attribute = (0.0 - 0.357632) / 0.27947 = -1.27967939314 + # Z-scaled attribute = (0.0 - 0.337618) / 0.272520 = -1.23887421107 }, None, { @@ -919,8 +921,8 @@ class TestRandomEffectsModelAlignmentADMComponent: ] }, ], - # Y_ij = 0.5 + 0.85*1.62589063037-0.3*-1.27967939314 = 2.26591085376 - # P_choose_a = e^2.26591085376/(1+e^2.26591085376) = 0.90601416437 + # Y_ij = 0.5 + 0.85*1.67867328133-0.3*-1.23887421107 = 2.29853455245 + # P_choose_a = e^2.29853455245/(1+e^2.29853455245) = 0.90875559851 }, "Treat Patient A", does_not_raise(), @@ -931,9 +933,9 @@ class TestRandomEffectsModelAlignmentADMComponent: "Treat Patient A": {"medical": 0.947157191, "merit": 0.0, "affiliation": 0.5}, "Treat Patient B": {"medical": 0.012495865, "merit": 1.0, "affiliation": 0.25}, # Medical delta = 0.947157191-0.012495865 = 0.934661326 - # Z-scaled medical delta = (0.934661326 - 0.433409) / 0.308294 = 1.62589063037 + # Z-scaled medical delta = (0.934661326 - 0.428961) / 0.301250 = 1.67867328133 # Attribute score = 0.0 - # Z-scaled attribute = (0.0 - 0.357632) / 0.27947 = -1.27967939314 + # Z-scaled attribute = (0.0 - 0.337618) / 0.272520 = -1.23887421107 }, { "merit": 1.0, @@ -960,8 +962,8 @@ class TestRandomEffectsModelAlignmentADMComponent: ] } ], - # Y_ij = 0.5 + 0.85*1.62589063037-0.3*-1.27967939314 = 2.26591085376 - # P_choose_a = e^2.26591085376/(1+e^2.26591085376) = 0.90601416437 + # Y_ij = 0.5 + 0.85*1.67867328133-0.3*-1.23887421107= 2.29853455245 + # P_choose_a = e^2.29853455245/(1+e^2.29853455245) = 0.90875559851 }, "Treat Patient A", does_not_raise(), @@ -1072,11 +1074,11 @@ def test_choice_selection( ( "merit", 0.5, 0.85, -0.3, 0.934661326, 0.0, - # Z-scaled medical delta = (0.934661326 - 0.433409) / 0.308294 = 1.62589063037 - # Z-scaled attribute = (0.0 - 0.357632) / 0.27947 = -1.27967939314 - # Y_ij = 0.5 + 0.85*1.62589063037-0.3*-1.27967939314 = 2.26591085376 - # P_choose_a = e^2.26591085376/(1+e^2.26591085376) = 0.90601416437 - 0.90601416437 + # Z-scaled medical delta = (0.934661326 - 0.428961) / 0.301250 = 1.67867328133 + # Z-scaled attribute = (0.0 - 0.337618) / 0.272520 = -1.23887421107 + # Y_ij = 0.5 + 0.85*1.67867328133-0.3*-1.23887421107 = 2.29853455245 + # P_choose_a = e^2.29853455245/(1+e^2.29853455245) = 0.90601416437 + 0.90875559851 ), ( "affiliation", 2.1875, 2.36875, 0.015625, From d89bd96ababed52fc239837c569a20f86fd71b2a Mon Sep 17 00:00:00 2001 From: Emily Veenhuis Date: Wed, 6 May 2026 10:44:37 -0400 Subject: [PATCH 2/8] Add multinomial random effects alignment function --- CHANGELOG.md | 14 +- .../algorithms/alignment_adm_component.py | 146 ++++++- .../tests/test_alignment_adm_component.py | 386 ++++++++++++++++++ 3 files changed, 532 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 86c12d76..c52fe123 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,16 @@ This changelog follows the specifications detailed in: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), although we have not yet reached a `1.0.0` release. +## Unreleased + +### Added + +* Added a multinomial version of the random effects alignment function (note this is still not multi-kdma) + +### Fixed + +* Updated unit tests for the random effects alignment function based on updated merit focus weights provided by ADEPT 2026-01-21 + ## 0.5.10 ### Added @@ -27,7 +37,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm * Added support for raw text system prompts for the pipeline baseline ADM * Added tagging ADM configs - + ### Changed * Refactored ICL selection strategies to reduce duplication; factored out similarity strategies @@ -36,7 +46,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm * Removed deprecated `tags` property from Dialog * Removed lots of old pre-Hydra/Outlines code * Made comparative regression reasoning length configurable - + ### Fixed * Removed non-determinism from midpoint alignment functions diff --git a/align_system/algorithms/alignment_adm_component.py b/align_system/algorithms/alignment_adm_component.py index 9cafbbb5..14a4fd50 100644 --- a/align_system/algorithms/alignment_adm_component.py +++ b/align_system/algorithms/alignment_adm_component.py @@ -1,4 +1,5 @@ import math +import numpy as np from align_system.utils import call_with_coerced_args, logging from align_system.algorithms.abstracts import ADMComponent @@ -355,21 +356,14 @@ def _compute_p_choose_a(self, kdma, intercept, medical_weight, attr_weight, raw_ y_ij = intercept + medical_weight*medical_delta + attr_weight*attr_score return math.exp(y_ij) / (1 + math.exp(y_ij)) - def run( + def _preproccess_predictions( self, attribute_prediction_scores, alignment_target, attribute_relevance=None, ): """ - Align using ADEPT's random effects model theory - - attribute_prediction_scores: dict[str, dict[str, float | list[float]]] - Dictionary of choices mapped to KMDA value predictions, including medical - urgency prediction - alignment_target: alignment target info - attribute_relevance: dict[str, float | list[float]] - Dictionary of probe level KDMA relevance predictions + Preprocess run input arguments for future usage, while also completing some input validation. """ if alignment_target is None: raise RuntimeError( @@ -382,8 +376,6 @@ def run( target_kdmas = [dict(t) for t in target_kdmas] choices = list(attribute_prediction_scores.keys()) - if len(choices) != 2: - raise NotImplementedError("This alignment function has not yet been implemented for !=2 choices") # Compute averages of predicted values predictions = [] @@ -410,7 +402,30 @@ def run( if len(relevant_kdmas) != 1: raise RuntimeError("This alignment function can only be used when 1 attribute is relevant") - # Guaranteed to only have 2 choices at this point due to earlier checks + return target_kdmas, choices, predictions, probe_relevance, relevant_kdmas + + def run( + self, + attribute_prediction_scores, + alignment_target, + attribute_relevance=None, + ): + """ + Align using ADEPT's random effects model theory + + attribute_prediction_scores: dict[str, dict[str, float | list[float]]] + Dictionary of choices mapped to KMDA value predictions, including medical + urgency prediction + alignment_target: alignment target info + attribute_relevance: dict[str, float | list[float]] + Dictionary of probe level KDMA relevance predictions + """ + target_kdmas, choices, predictions, probe_relevance, relevant_kdmas = self._preproccess_predictions(attribute_prediction_scores, alignment_target, attribute_relevance) + + # This alignment function only works for binary (2-choice probes) + if len(choices) != 2: + raise NotImplementedError("This alignment function has not yet been implemented for !=2 choices") + opt_a, opt_b = predictions for target_kdma in target_kdmas: @@ -456,3 +471,110 @@ def run( return (opt_a["choice"], best_sample_idx, alignment_info) else: return (opt_b["choice"], best_sample_idx, alignment_info) + + +class MultinomialRandomEffectsModelAlignmentADMComponent(RandomEffectsModelAlignmentADMComponent): + def __init__( + self, + attributes=None + ): + super().__init__(attributes) + + def _log_odds(self, p, eps=1e-12): + p = np.clip(p, eps, 1-eps) # avoid division by 0 or log(0) + log_odds = np.log(p / (1 - p)) + np.fill_diagonal(log_odds, 0) # explicitly set diagonal to 0 so that it doesn't affect later computations + return log_odds + + def _stable_softmax(self, scores): + e_scores = np.exp(scores - np.max(scores)) # Subtracting the max for numerical stability + return e_scores / e_scores.sum(axis=0) + + def _composite_probs(self, p_matrix): + """Combines sub-problem probabities into per-choice composite probabilities""" + log_odds = self._log_odds(p_matrix) + scores = np.sum(log_odds, axis=1) + return self._stable_softmax(scores) + + def run( + self, + attribute_prediction_scores, + alignment_target, + attribute_relevance=None, + ): + """ + Align using a multinomial expansion of ADEPT's random effects model theory + + attribute_prediction_scores: dict[str, dict[str, float | list[float]]] + Dictionary of choices mapped to KMDA value predictions, including medical + urgency prediction + alignment_target: alignment target info + attribute_relevance: dict[str, float | list[float]] + Dictionary of probe level KDMA relevance predictions + """ + target_kdmas, choices, predictions, probe_relevance, relevant_kdmas = self._preproccess_predictions(attribute_prediction_scores, alignment_target, attribute_relevance) + + # Only one option, decision always has to be the same + if len(choices) == 1: + return ( + predictions[0]["choice"], + 0, # TODO: best sample index + { + "source": type(self).__name__, + "p_choices": np.ones((1,)), + }, + ) + + # We iterate to find the relevant KDMA, this isn't multi-kdma yet (raised in preprocess) + for target_kdma in target_kdmas: + kdma = target_kdma["kdma"] + if kdma != relevant_kdmas[0]: + continue + + intercept = None + medical_weight = None + attr_weight = None + if target_kdma["parameters"] is not None: + for param in target_kdma["parameters"]: + if param["name"] == "intercept": + intercept = param["value"] + if param["name"] == "medical_weight": + medical_weight = param["value"] + if param["name"] == "attr_weight": + attr_weight = param["value"] + if intercept is None or medical_weight is None or attr_weight is None: + raise RuntimeError("This alignment function requires an intercept, medical weight, and attr weight") + + # Loop over options pairwise + p_matrix = np.ones((len(choices), len(choices))) + for choice_idx_a, opt_a in enumerate(predictions[:-1]): + for choice_idx_b, opt_b in enumerate(predictions[choice_idx_a+1:]): + raw_medical_delta = opt_a[med_urg_str] - opt_b[med_urg_str] + + # Choices should be sorted by descending medical need due to model assumptions + flip_order = False + if raw_medical_delta < 0: + flip_order = True + raw_medical_delta *= -1 + primary, secondary = (opt_a, opt_b) if not flip_order else (opt_b, opt_a) + raw_attr_score = secondary[kdma] if kdma == "search" else primary[kdma] + + p_choose_primary = self._compute_p_choose_a( + kdma, intercept, medical_weight, attr_weight, raw_medical_delta, raw_attr_score) + + p_matrix[choice_idx_a][choice_idx_b] = p_choose_primary if not flip_order else 1 - p_choose_primary + p_matrix[choice_idx_b][choice_idx_a] = 1 - p_choose_primary if not flip_order else p_choose_primary + + p_choices = self._composite_probs(p_matrix) + + # TODO: Figure out what it means to be the best prediction for this alignment function + best_sample_idx = 0 + + alignment_info = { + "source": type(self).__name__, + "p_choices": p_choices, + } + + max_idx = np.argmax(p_choices) + + return (predictions[max_idx]["choice"], best_sample_idx, alignment_info) diff --git a/align_system/tests/test_alignment_adm_component.py b/align_system/tests/test_alignment_adm_component.py index 72e255ec..edaf0cfb 100644 --- a/align_system/tests/test_alignment_adm_component.py +++ b/align_system/tests/test_alignment_adm_component.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import unittest.mock as mock from contextlib import nullcontext as does_not_raise @@ -1118,3 +1119,388 @@ def test_compute_p_choose_a(self, kdma, intercept, medical_weight, attr_weight, alignment_fn._compute_p_choose_a(kdma, intercept, medical_weight, attr_weight, raw_medical_delta, raw_attr_score) == pytest.approx(exp_value) ) + + +class TestMultinomialRandomEffectsModelAlignmentADMComponent: + attribute_definitions = { + "KDMA_A": { + "name": "Merit Focus", + "kdma": "merit", + "description": "Test merit focus KDMA", + }, + "KDMA_B": { + "name": "Affiliation Focus", + "kdma": "affiliation", + "description": "Test affiliation focus KDMA", + }, + "KDMA_C": { + "name": "Personal Safety", + "kdma": "personal_safety", + "description": "Test personal safety KDMA", + }, + "KDMA_D": { + "name": "Search vs Stay", + "kdma": "search", + "description": "Test search vs stay KDMA", + }, + } + + @pytest.mark.parametrize( + ("attribute_prediction_scores", "attribute_relevance", "alignment_target", "exp_choice", "exp_raises"), + [ + # No alignment target + ( + { + "Choice 0": {"medical": 0.1, "KDMA_A": 0.1, "KDMA_B": 0.8}, + "Choice 1": {"medical": 0.6, "KDMA_A": 0.3, "KDMA_B": 0.5}, + }, + None, + None, + None, # Raise expected so doesn't matter + pytest.raises(RuntimeError, match=r"Assumption violated: `alignment_target` was None"), + ), + # No medical predictions + ( + { + "Choice 0": {"KDMA_A": 0.1, "KDMA_B": 0.8}, + "Choice 1": {"KDMA_A": 0.3, "KDMA_B": 0.5}, + }, + None, + { + "kdma_values": + [ + { + "kdma": "KDMA_A", + "value": None, + "parameters": [ + {"name": "intercept", "value": 0.75}, + {"name": "medical_weight", "value": 0.5}, + {"name": "attr_weight", "value": -0.25}, + ] + }, + ], + }, + None, # Raise expected so doesn't matter + pytest.raises(RuntimeError, match=r"Medical Urgency predictions required"), + ), + # Target missing parameters + ( + { + "Choice 0": {"medical": 0.9, "KDMA_A": 0.1}, + "Choice 1": {"medical": 0.4, "KDMA_A": 0.9}, + }, + None, + { + "kdma_values": [{"kdma": "KDMA_A", "value": 0.7}], + }, + None, # Raise expected so doesn't matter + pytest.raises(RuntimeError, match=r"This alignment function requires an intercept, medical weight, and attr weight"), + ), + # Target missing intercept + ( + { + "Choice 0": {"medical": 0.9, "KDMA_A": 0.1}, + "Choice 1": {"medical": 0.4, "KDMA_A": 0.9}, + }, + None, + { + "kdma_values": [ + { + "kdma": "KDMA_A", + "value": None, + "parameters": [ + {"name": "medical_weight", "value": 0.5}, + {"name": "attr_weight", "value": -0.25}, + ] + } + ], + }, + None, # Raise expected so doesn't matter + pytest.raises(RuntimeError, match=r"This alignment function requires an intercept, medical weight, and attr weight"), + ), + # Target missing medical weight + ( + { + "Choice 0": {"medical": 0.9, "KDMA_A": 0.1}, + "Choice 1": {"medical": 0.4, "KDMA_A": 0.9}, + }, + None, + { + "kdma_values": [ + { + "kdma": "KDMA_A", + "value": None, + "parameters": [ + {"name": "intercept", "value": 0.75}, + {"name": "attr_weight", "value": -0.25}, + ] + } + ], + }, + None, # Raise expected so doesn't matter + pytest.raises(RuntimeError, match=r"This alignment function requires an intercept, medical weight, and attr weight"), + ), + # Target missing attr_weight + ( + { + "Choice 0": {"medical": 0.9, "KDMA_A": 0.1}, + "Choice 1": {"medical": 0.4, "KDMA_A": 0.9}, + }, + None, + { + "kdma_values": [ + { + "kdma": "KDMA_A", + "value": None, + "parameters": [ + {"name": "intercept", "value": 0.75}, + {"name": "medical_weight", "value": 0.5}, + ] + } + ], + }, + None, # Raise expected so doesn't matter + pytest.raises(RuntimeError, match=r"This alignment function requires an intercept, medical weight, and attr weight"), + ), + # Multiple KDMAs relevant + ( + { + "Choice 0": {"medical": 0.9, "KDMA_A": 0.1}, + "Choice 1": {"medical": 0.4, "KDMA_A": 0.9}, + }, + None, + { + "kdma_values": [ + { + "kdma": "KDMA_A", + "value": None, + "parameters": [ + {"name": "intercept", "value": 0.75}, + {"name": "medical_weight", "value": 0.5}, + {"name": "attr_weight", "value": -0.25}, + ] + }, + { + "kdma": "KDMA_B", + "value": None, + "parameters": [ + {"name": "intercept", "value": 0.75}, + {"name": "medical_weight", "value": 0.5}, + {"name": "attr_weight", "value": -0.25}, + ] + } + ], + }, + None, # Raise expected so doesn't matter + pytest.raises(RuntimeError, match=r"This alignment function can only be used when 1 attribute is relevant"), + ), + # Worked example with ADEPT + ( + { + "Treat Patient A": {"medical": 0.947157191, "merit": 0.0}, + "Treat Patient B": {"medical": 0.012495865, "merit": 1.0}, + # Medical delta = 0.947157191-0.012495865 = 0.934661326 + # Z-scaled medical delta = (0.934661326 - 0.428961) / 0.301250 = 1.67867328133 + # Attribute score = 0.0 + # Z-scaled attribute = (0.0 - 0.337618) / 0.272520 = -1.23887421107 + }, + None, + { + "kdma_values": [ + { + "kdma": "KDMA_A", + "value": None, + "parameters": [ + {"name": "intercept", "value": 0.5}, + {"name": "medical_weight", "value": 0.85}, + {"name": "attr_weight", "value": -0.3}, + ] + }, + ], + # Y_ij = 0.5 + 0.85*1.67867328133-0.3*-1.23887421107 = 2.29853455245 + # P_choose_a = e^2.29853455245/(1+e^2.29853455245) = 0.90875559851 + }, + "Treat Patient A", + does_not_raise(), + ), + # Multi-KDMA, should fallback to merit + ( + { + "Treat Patient A": {"medical": 0.947157191, "merit": 0.0, "affiliation": 0.5}, + "Treat Patient B": {"medical": 0.012495865, "merit": 1.0, "affiliation": 0.25}, + # Medical delta = 0.947157191-0.012495865 = 0.934661326 + # Z-scaled medical delta = (0.934661326 - 0.428961) / 0.301250 = 1.67867328133 + # Attribute score = 0.0 + # Z-scaled attribute = (0.0 - 0.337618) / 0.272520 = -1.23887421107 + }, + { + "merit": 1.0, + "affiliation": 0.0 + }, + { + "kdma_values": [ + { + "kdma": "KDMA_A", + "value": None, + "parameters": [ + {"name": "intercept", "value": 0.5}, + {"name": "medical_weight", "value": 0.85}, + {"name": "attr_weight", "value": -0.3}, + ] + }, + { + "kdma": "KDMA_B", + "value": None, + "parameters": [ + {"name": "intercept", "value": 0.75}, + {"name": "medical_weight", "value": 0.5}, + {"name": "attr_weight", "value": -0.25}, + ] + } + ], + # Y_ij = 0.5 + 0.85*1.67867328133-0.3*-1.23887421107 = 2.29853455245 + # P_choose_a = e^2.29853455245/(1+e^2.29853455245) = 0.90875559851 + }, + "Treat Patient A", + does_not_raise(), + ), + # <2 choices + ( + { + "Choice 0": {"medical": 0.1, "KDMA_B": 0.8}, + }, + None, + { + "kdma_values": + [ + { + "kdma": "KDMA_A", + "value": None, + "parameters": [ + {"name": "intercept", "value": 0.75}, + {"name": "medical_weight", "value": 0.5}, + {"name": "attr_weight", "value": -0.25}, + ] + }, + ], + }, + "Choice 0", + does_not_raise(), + ), + # >2 choices + ( + { + "Choice 0": {"medical": 0.1, "merit": 0.2}, + "Choice 1": {"medical": 0.2, "merit": 0.1}, + "Choice 2": {"medical": 0.9, "merit": 0.9}, + }, + None, + { + "kdma_values": + [ + { + "kdma": "KDMA_A", + "value": None, + "parameters": [ + {"name": "intercept", "value": 0.75}, + {"name": "medical_weight", "value": 0.5}, + {"name": "attr_weight", "value": -0.25}, + ] + }, + ], + }, + "Choice 2", # Choice 2 wins for both medical and attribute + does_not_raise(), + ), + ], + ids=[ + "no target", "no medical preds", "target missing parameters", "missing intercept", "missing medical weight", + "missing attr weight", "multiple relevant KDMAs", "worked example", "multi-kdma", "<2 choices", ">2 choices", + ], + ) + def test_run(self, attribute_prediction_scores, attribute_relevance, alignment_target, exp_choice, exp_raises): + alignment_fn = MultinomialRandomEffectsModelAlignmentADMComponent( + TestMultinomialRandomEffectsModelAlignmentADMComponent.attribute_definitions + ) + + with exp_raises: + # Only checking selected choice as best sample index not yet implemented + assert alignment_fn.run(attribute_prediction_scores, alignment_target, attribute_relevance)[0] == exp_choice + + @pytest.mark.parametrize( + ("p_matrix", "exp_value"), + [ + ( + np.array([[1, 0.25, 0.6], [0.75, 1, 0.8], [0.4, 0.2, 1]]), + np.array([[0, -1.09861228867, 0.4054651081], [1.09861228867, 0, 1.38629436112], [-0.4054651081, -1.38629436112, 0]]) + ), + ( + np.array([[1, 0.55, 0.36], [0.45, 1, 0.24], [0.64, 0.76, 1]]), + np.array([[0, 0.20067069546, -0.5753641449], [-0.20067069546, 0, -1.15267950994], [0.5753641449, 1.15267950994, 0]]) + ), + ( + np.array([[1, 0.6, 0.6], [0.4, 1, 0.6], [0.4, 0.4, 1]]), + np.array([[0, 0.4054651081, 0.4054651081], [-0.4054651081, 0, 0.4054651081], [-0.4054651081, -0.4054651081, 0]]) + ), + ( + np.array([[1, 0.6, 0.6, 0.5], [0.4, 1, 0.6, 0.5], [0.4, 0.4, 1, 0.5], [0.5, 0.5, 0.5, 1]]), + np.array([[0, 0.4054651081, 0.4054651081, 0], [-0.4054651081, 0, 0.4054651081, 0], [-0.4054651081, -0.4054651081, 0, 0], [0, 0, 0, 0]]) + ) + ], + ) + def test_log_odds(self, p_matrix, exp_value): + alignment_fn = MultinomialRandomEffectsModelAlignmentADMComponent( + TestMultinomialRandomEffectsModelAlignmentADMComponent.attribute_definitions + ) + + assert np.allclose(alignment_fn._log_odds(p_matrix), exp_value) + + + @pytest.mark.parametrize( + ("scores", "exp_value"), + [ + ( + np.array([0.6, -0.3, 0.4]), + np.array([0.44937752864, 0.18270326891, 0.36791920244]) + ), + ( + np.array([0.6, -0.3, 0.4, -0.8]), + np.array([0.40454753882, 0.16447675521, 0.33121551111, 0.09976019484]) + ), + ( + np.array([-0.69314718057, 2.48490664979, -1.79175946922]), + np.array([0.03947368421, 0.94736842105, 0.01315789473]) + ), + ( + np.array([-0.37469344944, -1.3533502054, 1.72804365484]), + np.array([0.10455474162, 0.03929330002, 0.85615195834]) + ), + ], + ) + def test_softmax(self, scores, exp_value): + alignment_fn = MultinomialRandomEffectsModelAlignmentADMComponent( + TestMultinomialRandomEffectsModelAlignmentADMComponent.attribute_definitions + ) + + assert np.allclose(alignment_fn._stable_softmax(scores), exp_value) + assert np.isclose(1, np.sum(exp_value)) + + @pytest.mark.parametrize( + ("p_matrix", "exp_value"), + [ + ( + np.array([[1, 0.25, 0.6], [0.75, 1, 0.8], [0.4, 0.2, 1]]), + np.array([0.03947368421, 0.94736842105, 0.01315789473]) + ), + ( + np.array([[1, 0.55, 0.36], [0.45, 1, 0.24], [0.64, 0.76, 1]]), + np.array([0.10455474162, 0.03929330002, 0.85615195834]) + ) + ], + ) + def test_composite_probs(self, p_matrix, exp_value): + alignment_fn = MultinomialRandomEffectsModelAlignmentADMComponent( + TestMultinomialRandomEffectsModelAlignmentADMComponent.attribute_definitions + ) + + assert np.allclose(alignment_fn._composite_probs(p_matrix), exp_value) From 69639ee69b577623a8104f1c202796e3e8da7b52 Mon Sep 17 00:00:00 2001 From: Emily Veenhuis Date: Tue, 12 May 2026 14:22:32 -0400 Subject: [PATCH 3/8] Add initial open world driver --- align_system/drivers/itm_open_world.py | 386 +++++++++++++++++++++++++ 1 file changed, 386 insertions(+) create mode 100644 align_system/drivers/itm_open_world.py diff --git a/align_system/drivers/itm_open_world.py b/align_system/drivers/itm_open_world.py new file mode 100644 index 00000000..88b5cd59 --- /dev/null +++ b/align_system/drivers/itm_open_world.py @@ -0,0 +1,386 @@ +import os +import json +import random +import re +from copy import deepcopy + +from rich.highlighter import JSONHighlighter +import hydra +from omegaconf import DictConfig, OmegaConf +from swagger_client.models import ActionTypeEnum +from timeit import default_timer as timer + +from align_system.utils import logging +from align_system.utils.version import get_version +from align_system.exceptions import SceneSkipException + +log = logging.getLogger(__name__) +JSON_HIGHLIGHTER = JSONHighlighter() + + +class ITMOpenWorldDriver: + def __init__(self, + apply_action_filtering=True, + sort_available_actions=False): + self.apply_action_filtering = apply_action_filtering + self.sort_available_actions = sort_available_actions + + def _expand_action_by_character(self, action, characters, restricted_character_ids): + expanded_actions = [] + + for character in characters: + if character.id in restricted_character_ids: + continue + + new_action = deepcopy(action) + new_action.character_id = character.id + new_action.unstructured = re.sub(r"(a )?Patient", character.name, action.unstructured) + + expanded_actions.append(new_action) + + return expanded_actions + + + def drive(self, cfg): + interface = cfg.interface + adm = cfg.adm.instance + + # Using the hydra generated output directory for the run + output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + + save_input_output_to_path = None + if cfg.save_input_output: + save_input_output_to_path = os.path.join(output_dir, "input_output.json") + + save_alignment_score_to_path = None + if cfg.save_scoring_output: + save_alignment_score_to_path = os.path.join(output_dir, "scores.json") + + save_alignment_targets_to_path = None + if cfg.save_alignment_targets: + save_alignment_targets_to_path = os.path.join(output_dir, "targets") + os.mkdir(save_alignment_targets_to_path) + + save_timing_to_path = None + if cfg.save_timing: + save_timing_to_path = os.path.join(output_dir, "timing.json") + + if cfg.get('force_determinism', False) or self.sort_available_actions: + log.info("Setting `sort_available_actions` to True") + sort_available_actions = True + else: + sort_available_actions = False + + # HACK: need to invoke 'load_model' for ADMs that require it, + # maybe it makes more sense to load_model in the init method for + # those ADMs + if hasattr(adm, 'load_model'): + adm.load_model() + + # Capture inputs and outputs in a similar format to what's used by + # our internal evaluation framework code + inputs_outputs = [] + + # Write version sidecar once at the start of the run + meta = {"version": get_version()} + username = getattr(interface, 'username', None) + if username is not None: + meta["username"] = username + with open(os.path.join(output_dir, "meta.json"), 'w') as f: + json.dump(meta, f, indent=2) + + session_alignment_scores = [] + + # Capture time it takes to choose each action + action_times = { "scenarios": [] } + def _compute_time_stats(times_s): + n_times = len(times_s) + total_time_s = sum(times_s) + return { + "n_actions_taken": n_times, + "total_time_s": total_time_s, + "avg_time_s": total_time_s / n_times if n_times else 0., + "max_time_s": max(times_s) if n_times else 0., + "raw_times_s": times_s + } + + # Loop through available scenarios + while scenario := interface.start_scenario(): + if scenario.id() == '': + log.info("Next scenario ID is blank, assuming we're done, exiting") + break + log.info(f'[bold]*Scenario ID*[/bold]: {scenario.id()}') + + # Reset any decision or chat history for a new scenario + if hasattr(adm, 'reset_history'): + log.info("[bold]*Resetting choice history*[/bold]") + adm.reset_history() + + if 'alignment_target' in cfg: + alignment_target = cfg.alignment_target + # Alignment targets specified in hydra configs require + # some nested conversion to dict (from OmegaConf objects) + # otherwise this can cause some downstream issues with + # serialization + alignment_target.kdma_values = [OmegaConf.to_container(c) + if isinstance(c, DictConfig) else c + for c in alignment_target.kdma_values] + elif cfg.align_to_target: + alignment_target = scenario.get_alignment_target() + else: + alignment_target = None + + log.info('[bold]*ALIGNMENT TARGET*[/bold]') + if alignment_target is None: + log.info('Alignment target is `None`') + else: + log.info(alignment_target) + if save_alignment_targets_to_path is not None: + alignment_target_path = os.path.join(save_alignment_targets_to_path, f"{alignment_target.id}.json") + + with open(alignment_target_path, "w") as f: + json.dump(alignment_target.to_dict(), f, indent=2) + + current_state = scenario.get_state() + scenario_complete = current_state.scenario_complete + + sce_times_s = [] + + last_scene_id = None + + treated_patients = [] + evac_patients = [] + + while not scenario_complete: + current_scene_id = current_state.meta_info.scene_id + if last_scene_id != current_scene_id: + log.info(f"[bold]*CHANGED SCENE TO*: {current_scene_id}[/bold]", + extra={"markup": True}) + last_scene_id = current_scene_id + + available_actions = scenario.get_available_actions() + + if sort_available_actions: + # Impose a fixed ordering of available actions to help + # with determinism + available_actions = sorted(available_actions, key=lambda a: a.unstructured) + + log.debug("[bold]*AVAILABLE ACTIONS*[/bold]", + extra={"markup": True}) + log.debug(json.dumps([a.to_dict() if hasattr(a, "to_dict") else a._asdict() for a in available_actions], indent=4), + extra={"highlighter": JSON_HIGHLIGHTER}) + + if not self.apply_action_filtering: + available_actions_filtered = available_actions + else: + available_actions_filtered = [] + end_scene_idx = None + for idx, a in enumerate(available_actions): + if a.action_type == ActionTypeEnum.END_SCENE: + # We want to restrict end scene until all characters have been treated + end_scene_idx = idx + continue + + if a.action_type == ActionTypeEnum.TAG_CHARACTER: + # Don't let ADM choose to tag a character unless there are + # still untagged characters + untagged_characters = [c for c in current_state.characters + if c.tag is None and not c.unseen] + + available_actions_filtered.extend(self._expand_action_by_character( + action=a, + characters=untagged_characters, + restricted_character_ids=[], + )) + + if a.action_type == ActionTypeEnum.TREAT_PATIENT: + treatable_patients = [c for c in current_state.characters if not c.unseen] + + available_actions_filtered.extend(self._expand_action_by_character( + action=a, + characters=treatable_patients, + restricted_character_ids=treated_patients, + )) + + + if a.action_type == ActionTypeEnum.MOVE_TO_EVAC: + evacable_patients = [c for c in current_state.characters if not c.unseen] + + available_actions_filtered.extend(self._expand_action_by_character( + action=a, + characters=evacable_patients, + restricted_character_ids=evac_patients, + )) + + if len(available_actions_filtered) == 0: + if end_scene_idx is not None: # All patients have been tagged and treated + log.info("** All patients have been tagged and treated, ending scene") + action_to_take = available_actions[end_scene_idx] + action_to_take.justification = "All patients have been tagged and treated" + else: + raise RuntimeError("No available actions from filtered list!") + elif len(available_actions_filtered) == 1: + log.info("** Choosing only available (filtered) action") + action_to_take = available_actions_filtered[0] + action_to_take.justification = "Only available (filtered) action" + else: + start_choose_action = timer() + + try: + # Passing in a copy of available actions to + # prevent ADMs from modifying the originals (should + # considering doing the same for current_state and + # alignment_target) + choose_action_result = adm.choose_action( + current_state, + [deepcopy(a) for a in available_actions_filtered], + alignment_target if cfg.align_to_target else None, + scenario_id=scenario.id(), + **cfg.adm.get('inference_kwargs', {})) + + # Handle choose action result (for backwards compatibility if no choice_info) + if isinstance(choose_action_result, tuple): + action_to_take, choice_info = choose_action_result + if 'choice_info' in choice_info: + # Handle pipeline_adm + choice_info = choice_info['choice_info'] + else: + action_to_take = choose_action_result + choice_info = {} + + except SceneSkipException as e: + log.error(f"Scene skipped due to component failure: {e}") + log.info(f"Component {e.component_name} failed - choosing random action to advance scene") + + # Choose a random action from available_actions_filtered to advance the scenario + action_to_take = random.choice(available_actions_filtered) + action_to_take.justification = f"Random action chosen due to component failure: {e.component_name}" + choice_info = {} + + log.warning(f"Taking random action to advance: {action_to_take.action_type if hasattr(action_to_take, 'action_type') else 'unknown'}") + + # Common code for both success and exception paths + end_choose_action = timer() + sce_times_s.append(end_choose_action - start_choose_action) + log.debug(f"choose_action took {end_choose_action - start_choose_action} seconds") + + log.info("[bold]*ACTION BEING TAKEN*[/bold]", + extra={"markup": True}) + if isinstance(action_to_take, dict): + log.info(json.dumps(action_to_take, indent=4), + extra={"highlighter": JSON_HIGHLIGHTER}) + else: + log.info(json.dumps(action_to_take.to_dict() if hasattr(action_to_take, "to_dict") else action_to_take._asdict(), indent=4), + extra={"highlighter": JSON_HIGHLIGHTER}) + + action_choice_idx = None + for i, a in enumerate(available_actions): + if a.action_id == action_to_take.action_id: + action_choice_idx = i + break + + # Ensure that 'actions' stored in 'choice_info' are serializable + for info in choice_info.values(): + if 'action' in info: + info['action'] = info['action'].to_dict() + + inputs_outputs.append({'input': {'scenario_id': scenario.id(), + 'alignment_target_id': alignment_target.id if cfg.align_to_target else None, + 'full_state': current_state.to_dict() if hasattr(current_state, "to_dict") else current_state._asdict(), + 'state': current_state.unstructured, + 'choices': [a.to_dict() if hasattr(a, "to_dict") else a._asdict() for a in available_actions]}, + 'label': [{} if a.kdma_association is None else a.kdma_association for a in available_actions], + 'choice_info': choice_info, + 'output': {'choice': action_choice_idx, + 'action': action_to_take.to_dict() if hasattr(action_to_take, "to_dict") else action_to_take._asdict()}}) + # Save input_output after each action (gets overwritten + # each time) so that we don't lose everything if the run + # crashes or is interrupted. Could treat this as we do + # the logfile and open the file handle once and close + # `atexit` and write each line as it's generated (and make + # it a .jsonl file; would need to remove the indent=2) + if save_input_output_to_path is not None: + with open(save_input_output_to_path, 'w') as f: + json.dump(inputs_outputs, f, indent=2) + + try: + if hasattr(action_to_take, "intent_action") and action_to_take.intent_action: + current_state = scenario.intend_action(action_to_take) + else: + current_state = scenario.take_action(action_to_take) + except Exception as e: + log.info(action_to_take) + raise e + + # If we treated a patient, record that treatment so we can ensure we treat everyone + if action_to_take.action_type == ActionTypeEnum.TREAT_PATIENT: + treated_patients.append(action_to_take.character_id) + # If we evaced a patient, record that so we don't try to evac them again + if action_to_take.action_type == ActionTypeEnum.MOVE_TO_EVAC: + evac_patients.append(action_to_take.character_id) + + scenario_complete = current_state.scenario_complete + + if scenario_complete: + log.info("*Final state unstructured*: {}".format( + current_state.unstructured)) + + if cfg.get('save_last_unstructured_state_per_scenario', False): + if alignment_target is None: + scenario_alignment_target = scenario.get_alignment_target() + + if scenario_alignment_target is not None: + alignment_target_id = scenario_alignment_target.id + else: + alignment_target_id = None + else: + alignment_target_id = alignment_target.id + + final_scenario_state_output_path = os.path.join( + output_dir, "{}.{}.final_state_unstructured.json".format( + scenario.id(), alignment_target_id)) + with open(final_scenario_state_output_path, "w") as f: + print(current_state.unstructured, file=f) + + if save_timing_to_path is not None: + action_times["scenarios"].append(_compute_time_stats(sce_times_s)) + + if alignment_target is not None: + try: + session_alignment = interface.get_session_alignment( + alignment_target) + except Exception: + # Could be more specific about what kind of exceptions + # to expect here + session_alignment = None + + if session_alignment is None: + log.info("Couldn't get session alignment from interface") + else: + session_alignment_scores.append(session_alignment) + + if isinstance(session_alignment, dict): + session_alignment_dict = session_alignment + else: + session_alignment_dict = session_alignment.to_dict() + + log.info("[bold]*TA1 Alignment Score*[/bold]", + extra={"markup": True}) + log.info(json.dumps(session_alignment_dict, indent=4), + extra={"highlighter": JSON_HIGHLIGHTER}) + + if save_timing_to_path is not None: + all_times = [] + for sce in action_times["scenarios"]: + all_times.extend(sce["raw_times_s"]) + + action_times.update(_compute_time_stats(all_times)) + + with open(save_timing_to_path, 'w') as f: + json.dump(action_times, f, indent=2) + + if len(session_alignment_scores) > 0: + if save_alignment_score_to_path is not None: + with open(save_alignment_score_to_path, 'w') as f: + json.dump([(s if isinstance(s, dict) else s.to_dict()) + for s in session_alignment_scores], f, indent=2) From dae984618c49090ac024433df9d9d11b79383b0d Mon Sep 17 00:00:00 2001 From: David Joy <10147749+dmjoy@users.noreply.github.com> Date: Wed, 13 May 2026 09:43:03 -0400 Subject: [PATCH 4/8] Minor TA3 interface tweaks for 0.5.3 client --- .../interfaces/ta3_caci_action_based_service.py | 14 ++++++++------ pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/align_system/interfaces/ta3_caci_action_based_service.py b/align_system/interfaces/ta3_caci_action_based_service.py index 32341c23..e2983fcc 100644 --- a/align_system/interfaces/ta3_caci_action_based_service.py +++ b/align_system/interfaces/ta3_caci_action_based_service.py @@ -121,9 +121,10 @@ def _take_or_intend_action(self, action, take_or_intend): session_id=self.session_id, action=action) - updated_state.unstructured = "{}\n{}".format( - updated_state.threat_state.unstructured, - updated_state.unstructured) + if updated_state.threat_state is not None: + updated_state.unstructured = "{}\n{}".format( + updated_state.threat_state.unstructured, + updated_state.unstructured) else: updated_state = take_or_intend( session_id=self.session_id, @@ -146,8 +147,9 @@ def get_state(self): session_id=self.session_id, scenario_id=self.scenario.id) if self.domain == "p2triage": - state.unstructured = "{}\n{}".format( - state.threat_state.unstructured, - state.unstructured) + if state.threat_state is not None: + state.unstructured = "{}\n{}".format( + state.threat_state.unstructured, + state.unstructured) return state diff --git a/pyproject.toml b/pyproject.toml index 6a742be7..d33afd89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "pytest>=9.0.2", "rouge-score>=0.1.2", "scikit-learn>=1.7.2", - "swagger-client==0.5.2", + "swagger-client==0.5.3", "transformers>=4.56.0,<5", # Pinned to be compatible with vllm 0.19.0 "ubelt>=1.4.1", "setuptools-scm>=8.0,<9", From 92c77c055ded0f545ccfb1aba74165a081b9086b Mon Sep 17 00:00:00 2001 From: David Joy <10147749+dmjoy@users.noreply.github.com> Date: Tue, 9 Jun 2026 12:38:22 -0400 Subject: [PATCH 5/8] Add Random ADM support for 2026 OW scenarios --- .../algorithms/random_adm_component.py | 41 +++++++++++++++++++ ...ow_random_action_parameter_completion.yaml | 1 + .../phase2_feb_openworld/phase2_random.yaml | 30 ++++++++++++++ 3 files changed, 72 insertions(+) create mode 100644 align_system/configs/adm_component/misc/ow_random_action_parameter_completion.yaml create mode 100644 align_system/configs/experiment/phase2_feb_openworld/phase2_random.yaml diff --git a/align_system/algorithms/random_adm_component.py b/align_system/algorithms/random_adm_component.py index 7dc27591..7c3b5cf0 100644 --- a/align_system/algorithms/random_adm_component.py +++ b/align_system/algorithms/random_adm_component.py @@ -67,6 +67,9 @@ def run(self, # Action requires an aid ID elif chosen_action.action_type == ActionTypeEnum.MOVE_TO_EVAC: + if chosen_action.parameters is None: + chosen_action.parameters = {} + if "aid_id" not in chosen_action.parameters: chosen_action.parameters["aid_id"] = random.choice([ aid.id @@ -77,3 +80,41 @@ def run(self, chosen_action.justification = "Random choice" return chosen_action + + +class OWRandomParameterCompletionADMComponent(ADMComponent): + def run_returns(self): + return 'chosen_action' + + def run(self, + scenario_state, + choices, + actions, + chosen_choice, + chosen_action=None): + if chosen_action is None: + chosen_choice_idx = choices.index(chosen_choice) + chosen_action = actions[chosen_choice_idx] + + # Action requires a character ID + if chosen_action.action_type in {'TREAT_PATIENT', + ActionTypeEnum.MOVE_TO_EVAC, + ActionTypeEnum.TAG_CHARACTER}: + if chosen_action.character_id is None: + chosen_action.character_id = random.choice([ + c.id + for c in scenario_state.characters + if hasattr(c, "unseen") and not c.unseen + ]) + + if chosen_action.action_type == ActionTypeEnum.TAG_CHARACTER: + if chosen_action.parameters is None: + chosen_action.parameters = {} + + if 'category' not in chosen_action.parameters: + chosen_action.parameters['category'] = random.choice( + get_swagger_class_enum_values(CharacterTagEnum)) + + chosen_action.justification = "Random choice" + + return chosen_action diff --git a/align_system/configs/adm_component/misc/ow_random_action_parameter_completion.yaml b/align_system/configs/adm_component/misc/ow_random_action_parameter_completion.yaml new file mode 100644 index 00000000..b92be61f --- /dev/null +++ b/align_system/configs/adm_component/misc/ow_random_action_parameter_completion.yaml @@ -0,0 +1 @@ +_target_: align_system.algorithms.random_adm_component.OWRandomParameterCompletionADMComponent diff --git a/align_system/configs/experiment/phase2_feb_openworld/phase2_random.yaml b/align_system/configs/experiment/phase2_feb_openworld/phase2_random.yaml new file mode 100644 index 00000000..4a540bb5 --- /dev/null +++ b/align_system/configs/experiment/phase2_feb_openworld/phase2_random.yaml @@ -0,0 +1,30 @@ +# @package _global_ +defaults: + - override /adm: pipeline_random + - override /interface: ta3 + - override /driver: itm_phase2_ow + - override /adm_component/misc@adm.step_definitions.random_action_parameter_completion: ow_random_action_parameter_completion + +interface: + api_endpoint: 'http://127.0.0.1:8081' + session_type: eval + training_session: null + username: "testrun-pipeline_random" + domain: "p2triage" + +adm: + instance: + steps: + # Reference the step instances we want to use in order + - ${ref:adm.step_definitions.format_choices} + - ${ref:adm.step_definitions.random_choice} + - ${ref:adm.step_definitions.random_action_parameter_completion} + # - ${ref:adm.step_definitions.action_parameter_completion} + - ${ref:adm.step_definitions.ensure_chosen_action} + - ${ref:adm.step_definitions.populate_choice_info} + +driver: + apply_action_filtering: false + +force_determinism: true +align_to_target: false From 6fa5f7bc4d8897f0f75b9462811e7e3adbd3e9fd Mon Sep 17 00:00:00 2001 From: David Joy <10147749+dmjoy@users.noreply.github.com> Date: Tue, 9 Jun 2026 14:12:39 -0400 Subject: [PATCH 6/8] Separate action expansion and filtering in open world driver --- .../phase2_feb_openworld/phase2_random.yaml | 3 +- align_system/drivers/itm_open_world.py | 133 +++++++++++++----- 2 files changed, 96 insertions(+), 40 deletions(-) diff --git a/align_system/configs/experiment/phase2_feb_openworld/phase2_random.yaml b/align_system/configs/experiment/phase2_feb_openworld/phase2_random.yaml index 4a540bb5..b11dd3c5 100644 --- a/align_system/configs/experiment/phase2_feb_openworld/phase2_random.yaml +++ b/align_system/configs/experiment/phase2_feb_openworld/phase2_random.yaml @@ -24,7 +24,8 @@ adm: - ${ref:adm.step_definitions.populate_choice_info} driver: - apply_action_filtering: false + expand_actions: true + apply_action_filtering: true force_determinism: true align_to_target: false diff --git a/align_system/drivers/itm_open_world.py b/align_system/drivers/itm_open_world.py index 88b5cd59..1f0f797a 100644 --- a/align_system/drivers/itm_open_world.py +++ b/align_system/drivers/itm_open_world.py @@ -10,28 +10,32 @@ from swagger_client.models import ActionTypeEnum from timeit import default_timer as timer +from align_system.utils import get_swagger_class_enum_values from align_system.utils import logging from align_system.utils.version import get_version from align_system.exceptions import SceneSkipException +from align_system.data_models.compat.ta3_ph1_client_models import ( + CharacterTagEnum) + log = logging.getLogger(__name__) JSON_HIGHLIGHTER = JSONHighlighter() +DEFAULT_TAGS = get_swagger_class_enum_values(CharacterTagEnum) class ITMOpenWorldDriver: def __init__(self, apply_action_filtering=True, + expand_actions=False, sort_available_actions=False): self.apply_action_filtering = apply_action_filtering + self.expand_actions = expand_actions self.sort_available_actions = sort_available_actions - def _expand_action_by_character(self, action, characters, restricted_character_ids): + def _expand_action_by_character(self, action, characters): expanded_actions = [] for character in characters: - if character.id in restricted_character_ids: - continue - new_action = deepcopy(action) new_action.character_id = character.id new_action.unstructured = re.sub(r"(a )?Patient", character.name, action.unstructured) @@ -40,6 +44,28 @@ def _expand_action_by_character(self, action, characters, restricted_character_i return expanded_actions + def _expand_action_by_tag(self, action, possible_tags=DEFAULT_TAGS): + assert action.action_type == ActionTypeEnum.TAG_CHARACTER + + expanded_actions = [] + + for tag in possible_tags: + new_action = deepcopy(action) + if new_action.parameters is None: + new_action.parameters = {} + + new_action.parameters['category'] = tag + if re.match(r'^[aeiou]', tag, re.I): + prefix = "an" + else: + prefix = "a" + + new_action.unstructured = re.sub(r"a triage tag", f"{prefix} {tag} triage tag", action.unstructured) + + expanded_actions.append(new_action) + + return expanded_actions + def drive(self, cfg): interface = cfg.interface @@ -148,8 +174,8 @@ def _compute_time_stats(times_s): last_scene_id = None - treated_patients = [] - evac_patients = [] + treated_patients = set() + evac_patients = set() while not scenario_complete: current_scene_id = current_state.meta_info.scene_id @@ -170,47 +196,80 @@ def _compute_time_stats(times_s): log.debug(json.dumps([a.to_dict() if hasattr(a, "to_dict") else a._asdict() for a in available_actions], indent=4), extra={"highlighter": JSON_HIGHLIGHTER}) + if not self.expand_actions: + available_actions_expanded = available_actions + if self.expand_actions: + available_actions_expanded = [] + for idx, a in enumerate(available_actions): + if a.action_type == ActionTypeEnum.TAG_CHARACTER: + # Expanding twice here, once for + # characters, and again for possible tags + for char_expanded_action in self._expand_action_by_character( + action=a, + characters=current_state.characters): + available_actions_expanded.extend(self._expand_action_by_tag( + action=char_expanded_action)) + + elif a.action_type == ActionTypeEnum.TREAT_PATIENT: + available_actions_expanded.extend(self._expand_action_by_character( + action=a, + characters=current_state.characters + )) + + + elif a.action_type == ActionTypeEnum.MOVE_TO_EVAC: + available_actions_expanded.extend(self._expand_action_by_character( + action=a, + characters=current_state.characters + )) + + else: + available_actions_expanded.append(a) + + log.debug("[bold]*AVAILABLE ACTIONS EXPANDED*[/bold]", + extra={"markup": True}) + log.debug(json.dumps([a.to_dict() if hasattr(a, "to_dict") else a._asdict() for a in available_actions_expanded], indent=4), + extra={"highlighter": JSON_HIGHLIGHTER}) + + if not self.apply_action_filtering: - available_actions_filtered = available_actions + available_actions_filtered = available_actions_expanded else: available_actions_filtered = [] end_scene_idx = None - for idx, a in enumerate(available_actions): + + untagged_characters = {c.id for c in current_state.characters + if c.tag is None and not c.unseen} + # HACK: Current TA3 server doesn't track what + # patients have been treated or evac'd (via + # c.unseen, or any other means); need to track it + # manually + # treatable_patients = {c.id for c in current_state.characters if not c.unseen} + # evacable_patients = {c.id for c in current_state.characters if not c.unseen} + treatable_patients = {c.id for c in current_state.characters if c.id not in treated_patients} + evacable_patients = {c.id for c in current_state.characters if c.id not in evac_patients} + + for idx, a in enumerate(available_actions_expanded): if a.action_type == ActionTypeEnum.END_SCENE: # We want to restrict end scene until all characters have been treated end_scene_idx = idx continue - if a.action_type == ActionTypeEnum.TAG_CHARACTER: + elif a.action_type == ActionTypeEnum.TAG_CHARACTER: # Don't let ADM choose to tag a character unless there are # still untagged characters - untagged_characters = [c for c in current_state.characters - if c.tag is None and not c.unseen] + if a.character_id not in untagged_characters: + continue - available_actions_filtered.extend(self._expand_action_by_character( - action=a, - characters=untagged_characters, - restricted_character_ids=[], - )) - - if a.action_type == ActionTypeEnum.TREAT_PATIENT: - treatable_patients = [c for c in current_state.characters if not c.unseen] + elif a.action_type == ActionTypeEnum.TREAT_PATIENT: + if a.character_id not in treatable_patients: + continue - available_actions_filtered.extend(self._expand_action_by_character( - action=a, - characters=treatable_patients, - restricted_character_ids=treated_patients, - )) + elif a.action_type == ActionTypeEnum.MOVE_TO_EVAC: + if a.character_id not in evacable_patients: + continue - - if a.action_type == ActionTypeEnum.MOVE_TO_EVAC: - evacable_patients = [c for c in current_state.characters if not c.unseen] - - available_actions_filtered.extend(self._expand_action_by_character( - action=a, - characters=evacable_patients, - restricted_character_ids=evac_patients, - )) + available_actions_filtered.append(a) if len(available_actions_filtered) == 0: if end_scene_idx is not None: # All patients have been tagged and treated @@ -219,10 +278,6 @@ def _compute_time_stats(times_s): action_to_take.justification = "All patients have been tagged and treated" else: raise RuntimeError("No available actions from filtered list!") - elif len(available_actions_filtered) == 1: - log.info("** Choosing only available (filtered) action") - action_to_take = available_actions_filtered[0] - action_to_take.justification = "Only available (filtered) action" else: start_choose_action = timer() @@ -314,10 +369,10 @@ def _compute_time_stats(times_s): # If we treated a patient, record that treatment so we can ensure we treat everyone if action_to_take.action_type == ActionTypeEnum.TREAT_PATIENT: - treated_patients.append(action_to_take.character_id) + treated_patients.add(action_to_take.character_id) # If we evaced a patient, record that so we don't try to evac them again if action_to_take.action_type == ActionTypeEnum.MOVE_TO_EVAC: - evac_patients.append(action_to_take.character_id) + evac_patients.add(action_to_take.character_id) scenario_complete = current_state.scenario_complete From 4de38e09fcd645369b9dc8b29a06b5935d72e19e Mon Sep 17 00:00:00 2001 From: David Joy <10147749+dmjoy@users.noreply.github.com> Date: Tue, 9 Jun 2026 14:17:35 -0400 Subject: [PATCH 7/8] Add open world baseline config --- .../phase2_feb_openworld/phase2_baseline.yaml | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 align_system/configs/experiment/phase2_feb_openworld/phase2_baseline.yaml diff --git a/align_system/configs/experiment/phase2_feb_openworld/phase2_baseline.yaml b/align_system/configs/experiment/phase2_feb_openworld/phase2_baseline.yaml new file mode 100644 index 00000000..5665f4dc --- /dev/null +++ b/align_system/configs/experiment/phase2_feb_openworld/phase2_baseline.yaml @@ -0,0 +1,39 @@ +# @package _global_ +defaults: + - override /adm: pipeline_baseline + - override /inference_engine@adm.structured_inference_engine: outlines_structured_greedy + - override /interface: ta3 + - override /driver: itm_phase2_ow + +interface: + api_endpoint: 'http://127.0.0.1:8081' + session_type: eval + training_session: null + username: "testrun-ALIGN-ADM-OutlinesBaseline-Mistral-7B-Instruct-v0.3" + domain: "p2triage" + +adm: + step_definitions: + outlines_baseline: + scenario_description_template: + _target_: align_system.prompt_engineering.outlines_prompts.Phase2ScenarioDescription + prompt_template: + _target_: align_system.prompt_engineering.outlines_prompts.Phase2BaselinePrompt + + enable_caching: true + + instance: + steps: + # Reference the step instances we want to use in order + - ${ref:adm.step_definitions.format_choices} + - ${ref:adm.step_definitions.outlines_baseline} + # - ${ref:adm.step_definitions.action_parameter_completion} + - ${ref:adm.step_definitions.ensure_chosen_action} + - ${ref:adm.step_definitions.populate_choice_info} + +driver: + expand_actions: true + apply_action_filtering: true + +force_determinism: true +align_to_target: false From c9c6c8552db7506145e2748f0d33d3a0cc8d5e44 Mon Sep 17 00:00:00 2001 From: David Joy <10147749+dmjoy@users.noreply.github.com> Date: Tue, 9 Jun 2026 15:42:43 -0400 Subject: [PATCH 8/8] Tweaks for open world comparative regression --- .../algorithms/alignment_adm_component.py | 2 +- .../multinomial_random_effects_tuple.yaml | 3 ++ ...comparative_regression_random_effects.yaml | 37 +++++++++++++++++++ .../phase2_w_casualty_info.yaml | 1 + .../prompt_engineering/outlines_prompts.py | 18 ++++++++- 5 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 align_system/configs/adm_component/alignment/multinomial_random_effects_tuple.yaml create mode 100644 align_system/configs/experiment/phase2_feb_openworld/phase2_pipeline_fewshot_comparative_regression_random_effects.yaml create mode 100644 align_system/configs/template/scenario_description/phase2_w_casualty_info.yaml diff --git a/align_system/algorithms/alignment_adm_component.py b/align_system/algorithms/alignment_adm_component.py index 14a4fd50..515312c0 100644 --- a/align_system/algorithms/alignment_adm_component.py +++ b/align_system/algorithms/alignment_adm_component.py @@ -572,7 +572,7 @@ def run( alignment_info = { "source": type(self).__name__, - "p_choices": p_choices, + "p_choices": p_choices.tolist(), } max_idx = np.argmax(p_choices) diff --git a/align_system/configs/adm_component/alignment/multinomial_random_effects_tuple.yaml b/align_system/configs/adm_component/alignment/multinomial_random_effects_tuple.yaml new file mode 100644 index 00000000..a0f7f139 --- /dev/null +++ b/align_system/configs/adm_component/alignment/multinomial_random_effects_tuple.yaml @@ -0,0 +1,3 @@ +_target_: align_system.algorithms.alignment_adm_component.MultinomialRandomEffectsModelAlignmentADMComponent + +attributes: ${ref:adm.attribute_definitions} diff --git a/align_system/configs/experiment/phase2_feb_openworld/phase2_pipeline_fewshot_comparative_regression_random_effects.yaml b/align_system/configs/experiment/phase2_feb_openworld/phase2_pipeline_fewshot_comparative_regression_random_effects.yaml new file mode 100644 index 00000000..054506c8 --- /dev/null +++ b/align_system/configs/experiment/phase2_feb_openworld/phase2_pipeline_fewshot_comparative_regression_random_effects.yaml @@ -0,0 +1,37 @@ +# @package _global_ +defaults: + - override /adm: phase2_pipeline_fewshot_comparative_regression_bert_relevance + - override /interface: ta3 + - override /adm_component/alignment@adm.step_definitions.scalar_alignment: multinomial_random_effects_tuple + - override /template/scenario_description@adm.scenario_description_template: phase2_w_casualty_info + - override /driver: itm_phase2_ow + +interface: + api_endpoint: 'http://127.0.0.1:8081' + session_type: eval + training_session: null + username: "testrun-ALIGN-ADM-Ph2-ComparativeRegression-BertRelevance-Mistral-7B-Instruct-v0.3" + domain: "p2triage" + +adm: + step_definitions: + regression_icl: + icl_generator_partial: + incontext_settings: + number: 20 + datasets: + medical: /data/shared/samba/phase2_icl/Feb2026-MU-train_20251218.json + affiliation: /data/shared/samba/phase2_icl/Feb2026-AF-train_20251218.json + merit: /data/shared/samba/phase2_icl/Feb2026-MF-train_20251218.json + personal_safety: /data/shared/samba/phase2_icl/Feb2026-PS-train_20251218.json + search: /data/shared/samba/phase2_icl/Feb2026-SS-train_20251218.json + enable_caching: true + comparative_regression: + enable_caching: true + +driver: + expand_actions: true + apply_action_filtering: true + +force_determinism: true +align_to_target: true diff --git a/align_system/configs/template/scenario_description/phase2_w_casualty_info.yaml b/align_system/configs/template/scenario_description/phase2_w_casualty_info.yaml new file mode 100644 index 00000000..6da00e66 --- /dev/null +++ b/align_system/configs/template/scenario_description/phase2_w_casualty_info.yaml @@ -0,0 +1 @@ +_target_: align_system.prompt_engineering.outlines_prompts.Phase2ScenarioDescriptionWCasualtyInfo diff --git a/align_system/prompt_engineering/outlines_prompts.py b/align_system/prompt_engineering/outlines_prompts.py index e5c59c0b..080b9889 100644 --- a/align_system/prompt_engineering/outlines_prompts.py +++ b/align_system/prompt_engineering/outlines_prompts.py @@ -1215,6 +1215,22 @@ def __call__(self, scenario_state): return phase2_scenario_state_description(scenario_state) +@compat_outlines_prompt +def phase2_scenario_state_description_w_casualty_info(scenario_state): + """ + {{ scenario_state.unstructured.rstrip() }} + + Casualties: + {% for c in scenario_state.characters %} + - {{ c.name }}: {{ c.unstructured }} + {% endfor %} + """ + +class Phase2ScenarioDescriptionWCasualtyInfo(): + def __call__(self, scenario_state): + return phase2_scenario_state_description_w_casualty_info(scenario_state) + + @compat_outlines_prompt def phase2_baseline_prompt(scenario_description, choices): """ @@ -1535,4 +1551,4 @@ def __call__(self): assert self.max_value == 100 return r'\{{"reasoning":\s"[^"]{{0,{}}}",\s"score":\s(\d|\d{{2}}|100)\}}'.format( - self.max_reasoning_length) \ No newline at end of file + self.max_reasoning_length)