From 14ec005acf8da429deacfc35327a1e2be344d9ba Mon Sep 17 00:00:00 2001 From: isty2e Date: Thu, 2 Jul 2026 22:58:49 +0900 Subject: [PATCH] fix(csa): preserve reference bank after adaptive growth --- .../population/csa/banking/reference.py | 29 ++++++- .../population/csa/engine/boundary.py | 6 +- tests/csa/test_csa_banking.py | 78 ++++++++++++++++++- 3 files changed, 106 insertions(+), 7 deletions(-) diff --git a/src/variopt/algorithms/population/csa/banking/reference.py b/src/variopt/algorithms/population/csa/banking/reference.py index 19d13b97..c227666a 100644 --- a/src/variopt/algorithms/population/csa/banking/reference.py +++ b/src/variopt/algorithms/population/csa/banking/reference.py @@ -9,7 +9,7 @@ from variopt.generic_runtime import FrozenGenericSlotsCompat from .....artifacts import Observation -from .....json_types import JSONDict, JSONValue +from .....json_types import JSONDict, JSONValue, require_json_bool from .....typevars import CandidateT from .bank import Bank, BankEntry @@ -24,10 +24,13 @@ class ReferenceBank(FrozenGenericSlotsCompat, Generic[CandidateT]): Maximum number of entries the reference bank may store. entries : tuple[BankEntry[CandidateT], ...], optional Entries currently stored in the reference bank. + initialized : bool, default=False + Whether the entries represent a completed reference snapshot. """ capacity: int entries: tuple[BankEntry[CandidateT], ...] = () + initialized: bool = False def __post_init__(self) -> None: """Validate reference-bank state. @@ -45,6 +48,13 @@ def __post_init__(self) -> None: msg = "entries must not exceed capacity" raise ValueError(msg) + if self.initialized and not self.is_full: + msg = "initialized reference banks must be full" + raise ValueError(msg) + + if not self.initialized and self.is_full: + object.__setattr__(self, "initialized", True) + @property def is_full(self) -> bool: """Return whether the reference bank has reached capacity. @@ -79,6 +89,7 @@ def to_dict( entry.to_dict(candidate_to_dict=candidate_to_dict) for entry in self.entries ], + "initialized": self.initialized, } @classmethod @@ -110,12 +121,21 @@ def from_dict( """ capacity = data.get("capacity") raw_entries = data.get("entries") + raw_initialized = data.get("initialized") if not isinstance(capacity, int): msg = "reference bank snapshot requires integer capacity" raise TypeError(msg) if not isinstance(raw_entries, list): msg = "reference bank snapshot requires entry list" raise TypeError(msg) + initialized = ( + None + if raw_initialized is None + else require_json_bool( + raw_initialized, + field_name="reference bank initialized", + ) + ) entries: list[BankEntry[CandidateT]] = [] for raw_entry in raw_entries: @@ -132,6 +152,11 @@ def from_dict( return cls( capacity=capacity, entries=tuple(entries), + initialized=( + len(entries) >= capacity + if initialized is None + else initialized + ), ) @@ -287,6 +312,7 @@ def build_reference_bank_from_bank( return ReferenceBank( capacity=bank.capacity, entries=sorted_entries, + initialized=True, ) @@ -351,6 +377,7 @@ def build_reference_bank_from_refresh_pool( return ReferenceBank( capacity=capacity, entries=preserved_entry_tuple + selected_entries, + initialized=True, ) diff --git a/src/variopt/algorithms/population/csa/engine/boundary.py b/src/variopt/algorithms/population/csa/engine/boundary.py index 57505d2d..7e7b8548 100644 --- a/src/variopt/algorithms/population/csa/engine/boundary.py +++ b/src/variopt/algorithms/population/csa/engine/boundary.py @@ -296,11 +296,7 @@ def sync_reference_bank_if_uninitialized( if not engine_state.banking_state.bank.is_full: return engine_state - if ( - engine_state.banking_state.reference_bank.is_full - and len(engine_state.banking_state.reference_bank.entries) - == len(engine_state.banking_state.bank.entries) - ): + if engine_state.banking_state.reference_bank.initialized: return engine_state bank = build_sorted_bank_from_bank(engine_state.banking_state.bank) diff --git a/tests/csa/test_csa_banking.py b/tests/csa/test_csa_banking.py index dde2a1a4..416a8e04 100644 --- a/tests/csa/test_csa_banking.py +++ b/tests/csa/test_csa_banking.py @@ -1,5 +1,6 @@ """Tests for CSA banking, score-model, and admission semantics.""" +from dataclasses import replace from typing import Literal, cast import pytest @@ -36,12 +37,18 @@ make_optimizer, significant_update_indices, ) +from variopt.algorithms.population.csa.banking.clustering import ( + CSAClusteringPolicy, + CSAClusteringState, +) from variopt.algorithms.population.csa.banking.queries import ( BankDistanceWorkspace, best_mean_niche_scores, crowding_aware_scores, ) from variopt.algorithms.population.csa.scoring.model_state import CSAScoreModelState +from variopt.algorithms.population.csa.selection.state import SeedSelectionState +from variopt.json_types import JSONValue class CountingDistance(DiversityMetric[int]): @@ -730,6 +737,36 @@ def test_admit_rejects_negative_diversity_distance(self) -> None: class CSABankingTests(CSAOptimizerTestCase): """White-box tests for CSA banking and score-model state transitions.""" + def test_reference_bank_legacy_snapshot_derives_initialized_state(self) -> None: + def candidate_from_dict(value: JSONValue) -> int: + if type(value) is not int: + msg = "candidate must be an integer" + raise TypeError(msg) + return value + + full_reference = ReferenceBank[int].from_dict( + { + "capacity": 2, + "entries": [ + {"candidate": 0, "value": 0.0, "proposal_id": "b-0"}, + {"candidate": 1, "value": 1.0, "proposal_id": "b-1"}, + ], + }, + candidate_from_dict=candidate_from_dict, + ) + partial_reference = ReferenceBank[int].from_dict( + { + "capacity": 2, + "entries": [ + {"candidate": 0, "value": 0.0, "proposal_id": "b-0"}, + ], + }, + candidate_from_dict=candidate_from_dict, + ) + + assert full_reference.initialized + assert not partial_reference.initialized + def test_significant_update_threshold_ignores_small_score_changes(self) -> None: optimizer = make_optimizer( space=IntegerSpace(low=0, high=100), @@ -1052,6 +1089,29 @@ def test_growth_policy_can_append_far_candidate(self) -> None: ), distance_cutoff=2.0, ) + optimizer.engine_state = replace( + optimizer.engine_state, + banking_state=replace( + optimizer.engine_state.banking_state, + clustering_state=CSAClusteringState[int]( + policy=CSAClusteringPolicy(enabled=True), + cluster_distance=2.0, + cluster_labels=(1, 2), + ), + ), + selection_state=SeedSelectionState( + used_entry_indices=frozenset({1}), + bank_status=(False, True), + ), + progression_state=replace( + optimizer.engine_state.progression_state, + stage_state=optimizer.engine_state.progression_state.stage_state.with_masks( + seed_mask=frozenset({1}), + partner_mask=frozenset({0}), + ), + ).with_refresh_mask(frozenset({1})), + ) + reference_entries = optimizer.reference_bank.entries proposal = Proposal(candidate=20, proposal_id="p-1") optimizer.pending_by_id = {"p-1": proposal} @@ -1068,8 +1128,23 @@ def test_growth_policy_can_append_far_candidate(self) -> None: assert optimizer.bank.capacity == 3 assert ( - tuple(entry.value for entry in optimizer.bank.entries) == (0.0, 5.0, 10.0) + tuple(entry.candidate for entry in optimizer.bank.entries) == (0, 10, 20) + ) + assert optimizer.reference_bank.entries == reference_entries + assert ( + tuple( + zip( + (entry.candidate for entry in optimizer.bank.entries), + optimizer.engine_state.banking_state.clustering_state.cluster_labels, + strict=True, + ) + ) + == ((0, 1), (10, 2), (20, 3)) ) + assert optimizer.selection_state.used_entry_indices == frozenset({1}) + assert optimizer.selection_state.bank_status == (False, True, False) + assert optimizer.engine_state.progression_state.seed_mask == frozenset({1}) + assert optimizer.engine_state.progression_state.partner_mask == frozenset({0, 1}) def test_growth_policy_reduces_oversized_bank_after_batch(self) -> None: optimizer = make_optimizer( @@ -1149,3 +1224,4 @@ def test_initial_fill_sorts_bank_and_reference_bank_by_score(self) -> None: tuple(entry.candidate for entry in optimizer.reference_bank.entries) == (1, 5, 9) ) + assert optimizer.reference_bank.initialized