From f6f6af6d367dc43e5890313aa8c936751fed7e62 Mon Sep 17 00:00:00 2001 From: Vincenzo DiMatteo <47278634+Vman11@users.noreply.github.com> Date: Wed, 10 Jun 2026 14:45:24 -0400 Subject: [PATCH] able to config bert scorer for gpu/cpu, as well as cache it --- .../incontext.yaml | 2 + .../incontext.yaml | 2 + .../incontext_phase1.yaml | 2 + .../incontext.yaml | 4 ++ align_system/utils/incontext_utils.py | 68 +++++++++++++++---- 5 files changed, 66 insertions(+), 12 deletions(-) diff --git a/align_system/configs/adm/outlines_regression_aligned/incontext.yaml b/align_system/configs/adm/outlines_regression_aligned/incontext.yaml index 4408eeda..3093f4af 100644 --- a/align_system/configs/adm/outlines_regression_aligned/incontext.yaml +++ b/align_system/configs/adm/outlines_regression_aligned/incontext.yaml @@ -6,6 +6,8 @@ inference_kwargs: incontext: number: 5 method: bert_similarity + bert_scorer_device: auto + cache_bert_scorer: true datasets: MoralDesert: /data/shared/samba/integrated_results_metrics_eval/captured_dataset_for_chris/baseline_adept_high-1715105775-input-output.json maximization: /data/shared/samba/integrated_results_metrics_eval/captured_dataset_for_chris/baseline_soartech_high-1716581856-input-output.json diff --git a/align_system/configs/adm/outlines_regression_aligned_comparative/incontext.yaml b/align_system/configs/adm/outlines_regression_aligned_comparative/incontext.yaml index 56c07ab8..4ad1a62e 100644 --- a/align_system/configs/adm/outlines_regression_aligned_comparative/incontext.yaml +++ b/align_system/configs/adm/outlines_regression_aligned_comparative/incontext.yaml @@ -6,6 +6,8 @@ inference_kwargs: incontext: number: 5 method: prompt_bert_similarity + bert_scorer_device: auto + cache_bert_scorer: true leave_one_out_strategy: null normalization: globalnorm datasets: diff --git a/align_system/configs/adm/outlines_regression_aligned_comparative/incontext_phase1.yaml b/align_system/configs/adm/outlines_regression_aligned_comparative/incontext_phase1.yaml index 627b38ef..c416224a 100644 --- a/align_system/configs/adm/outlines_regression_aligned_comparative/incontext_phase1.yaml +++ b/align_system/configs/adm/outlines_regression_aligned_comparative/incontext_phase1.yaml @@ -6,6 +6,8 @@ inference_kwargs: incontext: number: 5 method: prompt_bert_similarity + bert_scorer_device: auto + cache_bert_scorer: true leave_one_out_strategy: null normalization: null sort_actions: true diff --git a/align_system/configs/adm/outlines_transformers_structured_aligned/incontext.yaml b/align_system/configs/adm/outlines_transformers_structured_aligned/incontext.yaml index bf9c0566..119fa05f 100644 --- a/align_system/configs/adm/outlines_transformers_structured_aligned/incontext.yaml +++ b/align_system/configs/adm/outlines_transformers_structured_aligned/incontext.yaml @@ -5,6 +5,10 @@ inference_kwargs: incontext: number: 5 method: scenario_bert_similarity + # Device for the BERT scorer: 'auto' (cuda if available), 'cpu', or 'cuda' + bert_scorer_device: auto + # Reuse a single BERT scorer instance across calls (avoids reloading the model) + cache_bert_scorer: true datasets: MoralDesert: /data/shared/samba/integrated_results_metrics_eval/captured_dataset_for_chris/baseline_adept_high-1715105775-input-output.json maximization: /data/shared/samba/integrated_results_metrics_eval/captured_dataset_for_chris/baseline_soartech_high-1716581856-input-output.json diff --git a/align_system/utils/incontext_utils.py b/align_system/utils/incontext_utils.py index 965b83d4..7912ecef 100644 --- a/align_system/utils/incontext_utils.py +++ b/align_system/utils/incontext_utils.py @@ -4,7 +4,7 @@ import random import numpy as np from abc import ABCMeta, abstractmethod -from bert_score import score as bert_score +from bert_score import BERTScorer from omegaconf import ListConfig, OmegaConf from align_system.utils import adm_utils @@ -25,7 +25,38 @@ ) -def bert_similarity_selection(candidates, texts_to_compare, reference_text, n_examples, score_adjustments=None, least_similar_examples=False): +_bert_scorer = None +_bert_scorer_device = None + + +def get_bert_scorer(device="auto", cache_scorer=True): + """Build (or fetch the cached) BERTScorer instance. + + Args: + device: Device to run the scorer on ('cpu', 'cuda', etc.). + 'auto' (or None) selects 'cuda' when available, else 'cpu' + cache_scorer: If True, reuse a single scorer instance across calls + to avoid reloading the model + + Returns: + BERTScorer instance on the requested device + """ + global _bert_scorer, _bert_scorer_device + + if device is None or device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + + if not cache_scorer: + return BERTScorer(lang="en", device=device) + + if _bert_scorer is None or _bert_scorer_device != device: + _bert_scorer = BERTScorer(lang="en", device=device) + _bert_scorer_device = device + + return _bert_scorer + + +def bert_similarity_selection(candidates, texts_to_compare, reference_text, n_examples, score_adjustments=None, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True): """Common BERT similarity selection logic for all strategies. Args: @@ -36,11 +67,14 @@ def bert_similarity_selection(candidates, texts_to_compare, reference_text, n_ex score_adjustments: Optional list of score adjustments (same length as candidates) least_similar_examples: If True, selects least similar examples to approximate domain shift between train and eval on train data only + bert_scorer_device: Device to run the BERT scorer on ('auto', 'cpu', 'cuda', etc.) + cache_bert_scorer: If True, reuse a single BERTScorer instance across calls Returns: List of selected candidates with 'similarity_score' field added """ - _, _, scores = bert_score([reference_text] * len(texts_to_compare), texts_to_compare, lang="en") + scorer = get_bert_scorer(device=bert_scorer_device, cache_scorer=cache_bert_scorer) + _, _, scores = scorer.score([reference_text] * len(texts_to_compare), texts_to_compare) if score_adjustments is not None: for i, adjustment in enumerate(score_adjustments): @@ -72,7 +106,7 @@ def select_random_strategy(possible_examples, n_examples, **kwargs): return selected_with_scores -def select_scenario_bert_similarity_strategy(possible_examples, n_examples, scenario_to_match, least_similar_examples=False, **kwargs): +def select_scenario_bert_similarity_strategy(possible_examples, n_examples, scenario_to_match, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs): """Scenario-based BERT similarity selection strategy""" final_candidates = list({ex['scenario_description']: ex for ex in possible_examples}.values()) possible_scenarios = [icl_sample["scenario_description"] for icl_sample in final_candidates] @@ -82,11 +116,13 @@ def select_scenario_bert_similarity_strategy(possible_examples, n_examples, scen possible_scenarios, scenario_to_match, n_examples, - least_similar_examples=least_similar_examples + least_similar_examples=least_similar_examples, + bert_scorer_device=bert_scorer_device, + cache_bert_scorer=cache_bert_scorer ) -def select_prompt_bert_similarity_strategy(possible_examples, n_examples, prompt_to_match, least_similar_examples=False, **kwargs): +def select_prompt_bert_similarity_strategy(possible_examples, n_examples, prompt_to_match, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs): """Prompt-based BERT similarity selection strategy""" final_candidates = list({ex['prompt']: ex for ex in possible_examples}.values()) possible_prompts = [icl_sample["prompt"] for icl_sample in final_candidates] @@ -96,11 +132,13 @@ def select_prompt_bert_similarity_strategy(possible_examples, n_examples, prompt possible_prompts, prompt_to_match, n_examples, - least_similar_examples=least_similar_examples + least_similar_examples=least_similar_examples, + bert_scorer_device=bert_scorer_device, + cache_bert_scorer=cache_bert_scorer ) -def select_matching_actions_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, **kwargs): +def select_matching_actions_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs): """Action-matching with BERT similarity selection strategy""" action_types = set([action.action_type for action in actions]) possible_prompts = [icl_sample["prompt"] for icl_sample in possible_examples] @@ -119,11 +157,13 @@ def select_matching_actions_strategy(possible_examples, n_examples, prompt_to_ma prompt_to_match, n_examples, score_adjustments, - least_similar_examples=least_similar_examples + least_similar_examples=least_similar_examples, + bert_scorer_device=bert_scorer_device, + cache_bert_scorer=cache_bert_scorer ) -def select_matching_characters_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, **kwargs): +def select_matching_characters_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs): """Character-matching with BERT similarity selection strategy""" action_chars = set([action.character_id for action in actions]) possible_prompts = [icl_sample["prompt"] for icl_sample in possible_examples] @@ -142,7 +182,9 @@ def select_matching_characters_strategy(possible_examples, n_examples, prompt_to prompt_to_match, n_examples, score_adjustments, - least_similar_examples=least_similar_examples + least_similar_examples=least_similar_examples, + bert_scorer_device=bert_scorer_device, + cache_bert_scorer=cache_bert_scorer ) @@ -460,7 +502,9 @@ def select_icl_examples(self, sys_kdma_name, scenario_description_to_match, prom scenario_to_match=scenario_description_to_match, prompt_to_match=prompt_to_match, actions=actions, - least_similar_examples=least_similar_examples + least_similar_examples=least_similar_examples, + bert_scorer_device=self.incontext_settings.get("bert_scorer_device", "auto"), + cache_bert_scorer=self.incontext_settings.get("cache_bert_scorer", True) ) if self.incontext_settings.get("most_similar_first", True):