From 837cc82921c1d5d2c267612b0884fda8be771f5b Mon Sep 17 00:00:00 2001 From: isty2e Date: Fri, 3 Jul 2026 05:08:23 +0900 Subject: [PATCH] fix(csa): keep index-aligned state synced --- .../csa/banking/clustering/state.py | 2 +- .../population/csa/banking/update/logic.py | 18 +- .../population/csa/banking/update/result.py | 5 +- .../algorithms/population/csa/engine/tell.py | 49 +++--- .../algorithms/population/csa/indexing.py | 57 +++++++ .../population/csa/progression/stage.py | 46 +++++ .../population/csa/progression/state.py | 47 ++++++ .../population/csa/selection/state.py | 23 +-- tests/csa/test_csa_banking.py | 157 ++++++++++++++++++ tests/csa/test_csa_clustering.py | 49 ++++++ tests/csa/test_csa_engine_state.py | 19 +++ 11 files changed, 419 insertions(+), 53 deletions(-) create mode 100644 src/variopt/algorithms/population/csa/indexing.py diff --git a/src/variopt/algorithms/population/csa/banking/clustering/state.py b/src/variopt/algorithms/population/csa/banking/clustering/state.py index 2789654..ddb50f7 100644 --- a/src/variopt/algorithms/population/csa/banking/clustering/state.py +++ b/src/variopt/algorithms/population/csa/banking/clustering/state.py @@ -418,7 +418,7 @@ def register_admission( elif appended: admitted_label = 1 if not labels else labels[nearest_index] else: - admitted_label = labels[admitted_index] + admitted_label = labels[nearest_index] if appended: labels.append(admitted_label) diff --git a/src/variopt/algorithms/population/csa/banking/update/logic.py b/src/variopt/algorithms/population/csa/banking/update/logic.py index 0ad4b76..44b370b 100644 --- a/src/variopt/algorithms/population/csa/banking/update/logic.py +++ b/src/variopt/algorithms/population/csa/banking/update/logic.py @@ -180,6 +180,15 @@ def apply_bank_update_batch( significant_update_indices=batch_significant_update_indices, ) + final_changed_indices = changed_indices( + previous_bank=previous_bank, + next_bank=shadow_bank, + ) + final_significant_update_indices = significant_update_indices( + previous_bank=previous_bank, + next_bank=shadow_bank, + minimum_significant_score_gap=update_policy.minimum_significant_score_gap, + ) removed_indices: frozenset[int] = frozenset() ( shadow_bank, @@ -200,15 +209,6 @@ def apply_bank_update_batch( entries=shadow_bank.entries, diversity_metric=diversity_metric, ) - final_changed_indices = changed_indices( - previous_bank=previous_bank, - next_bank=shadow_bank, - ) - final_significant_update_indices = significant_update_indices( - previous_bank=previous_bank, - next_bank=shadow_bank, - minimum_significant_score_gap=update_policy.minimum_significant_score_gap, - ) return BankUpdateResult( bank=shadow_bank, state=shadow_state, diff --git a/src/variopt/algorithms/population/csa/banking/update/result.py b/src/variopt/algorithms/population/csa/banking/update/result.py index 1686362..c72e428 100644 --- a/src/variopt/algorithms/population/csa/banking/update/result.py +++ b/src/variopt/algorithms/population/csa/banking/update/result.py @@ -33,9 +33,10 @@ class BankUpdateResult(FrozenGenericSlotsCompat, Generic[CandidateT]): trace_state : CSAEventTraceState[CandidateT] | None Optional updated trace reducer state. changed_indices : frozenset[int] - Bank indices whose entries changed at all. + Bank indices whose entries changed before post-batch removals are applied. significant_update_indices : frozenset[int] - Bank indices whose score change exceeded the significance floor. + Changed bank indices whose score gap exceeded the significance floor + before post-batch removals are applied. removed_indices : frozenset[int] Bank indices removed from the previous bank snapshot. """ diff --git a/src/variopt/algorithms/population/csa/engine/tell.py b/src/variopt/algorithms/population/csa/engine/tell.py index eb8d565..dbc1261 100644 --- a/src/variopt/algorithms/population/csa/engine/tell.py +++ b/src/variopt/algorithms/population/csa/engine/tell.py @@ -188,12 +188,30 @@ def apply_tell( ) updated_indices = batch_result.changed_indices significant_update_indices = batch_result.significant_update_indices + entry_count_before_removal = ( + len(batch_result.bank.entries) + len(batch_result.removed_indices) + ) + progression_state = batch_result.state if updated_indices: - engine_state = replace( - engine_state, - progression_state=engine_state.progression_state.without_updated_seed_mask( - updated_indices, - ), + progression_state = progression_state.without_updated_seed_mask( + updated_indices, + ) + if batch_result.removed_indices: + progression_state = progression_state.remove_indices( + removed_indices=batch_result.removed_indices, + entry_count=len(batch_result.bank.entries), + ) + + selection_state = engine_state.selection_state + if significant_update_indices: + selection_state = selection_state.invalidate_for_bank_update( + updated_indices=significant_update_indices, + entry_count=entry_count_before_removal, + ) + if batch_result.removed_indices: + selection_state = selection_state.remove_indices( + removed_indices=batch_result.removed_indices, + entry_count=len(batch_result.bank.entries), ) engine_state = replace( @@ -204,7 +222,8 @@ def apply_tell( growth_state=batch_result.growth_state, clustering_state=batch_result.clustering_state, ), - progression_state=batch_result.state, + progression_state=progression_state, + selection_state=selection_state, scoring_state=replace( engine_state.scoring_state, model_state=batch_result.score_model_state, @@ -216,24 +235,6 @@ def apply_tell( ), ) - if batch_result.removed_indices: - engine_state = replace( - engine_state, - selection_state=engine_state.selection_state.remove_indices( - removed_indices=batch_result.removed_indices, - entry_count=len(engine_state.banking_state.bank.entries), - ), - ) - - if significant_update_indices: - engine_state = replace( - engine_state, - selection_state=engine_state.selection_state.invalidate_for_bank_update( - updated_indices=significant_update_indices, - entry_count=len(engine_state.banking_state.bank.entries), - ), - ) - engine_state = sync_reference_bank_if_uninitialized( engine_state, diversity_metric=diversity_metric, diff --git a/src/variopt/algorithms/population/csa/indexing.py b/src/variopt/algorithms/population/csa/indexing.py new file mode 100644 index 0000000..3f9ca0f --- /dev/null +++ b/src/variopt/algorithms/population/csa/indexing.py @@ -0,0 +1,57 @@ +"""Bank-index algebra shared by CSA state projections.""" + +from bisect import bisect_left +from collections.abc import Set as AbstractSet + + +def remap_indices_after_removal( + indices: AbstractSet[int], + *, + removed_indices: AbstractSet[int], + entry_count: int, +) -> frozenset[int]: + """Return indices remapped after removing entries from a bank snapshot. + + Parameters + ---------- + indices : collections.abc.Set[int] + Indices aligned to the previous bank snapshot. + removed_indices : collections.abc.Set[int] + Bank indices removed from the previous bank snapshot. Negative values + are ignored. + entry_count : int + Current bank size after removal. + + Returns + ------- + frozenset[int] + Indices aligned to the current bank snapshot. Removed indices and + indices outside the current bank size are dropped. + + Raises + ------ + ValueError + If ``entry_count`` is negative. + """ + if entry_count < 0: + msg = "entry_count must be non-negative" + raise ValueError(msg) + + if not indices: + return frozenset() + + ordered_removed_indices = tuple( + sorted(index for index in removed_indices if index >= 0) + ) + removed_index_set = frozenset(ordered_removed_indices) + remapped_indices: set[int] = set() + for index in indices: + if index < 0 or index in removed_index_set: + continue + + removed_before_count = bisect_left(ordered_removed_indices, index) + remapped_index = index - removed_before_count + if remapped_index < entry_count: + remapped_indices.add(remapped_index) + + return frozenset(remapped_indices) diff --git a/src/variopt/algorithms/population/csa/progression/stage.py b/src/variopt/algorithms/population/csa/progression/stage.py index 9333aad..0e4a9dd 100644 --- a/src/variopt/algorithms/population/csa/progression/stage.py +++ b/src/variopt/algorithms/population/csa/progression/stage.py @@ -7,6 +7,7 @@ from typing_extensions import Self from .....json_types import JSONDict, JSONValue, require_json_int, require_json_list +from ..indexing import remap_indices_after_removal @dataclass(frozen=True, slots=True) @@ -187,6 +188,51 @@ def without_updated_seed_mask(self, updated_indices: AbstractSet[int]) -> Self: partner_mask=self.partner_mask, ) + def remove_indices( + self, + *, + removed_indices: AbstractSet[int], + entry_count: int, + ) -> Self: + """Return a copy remapped after bank indices have been removed. + + Parameters + ---------- + removed_indices : collections.abc.Set[int] + Bank indices removed from the previous bank snapshot. + entry_count : int + Current bank size after removal. + + Returns + ------- + Self + Stage state whose masks are aligned with the current bank indexing. + + Raises + ------ + ValueError + If ``entry_count`` is negative. + """ + if entry_count < 0: + msg = "entry_count must be non-negative" + raise ValueError(msg) + + if not removed_indices: + return self + + return self.with_masks( + seed_mask=remap_indices_after_removal( + self.seed_mask, + removed_indices=removed_indices, + entry_count=entry_count, + ), + partner_mask=remap_indices_after_removal( + self.partner_mask, + removed_indices=removed_indices, + entry_count=entry_count, + ), + ) + def with_masks( self, *, diff --git a/src/variopt/algorithms/population/csa/progression/state.py b/src/variopt/algorithms/population/csa/progression/state.py index f603424..2f7eb77 100644 --- a/src/variopt/algorithms/population/csa/progression/state.py +++ b/src/variopt/algorithms/population/csa/progression/state.py @@ -15,6 +15,7 @@ require_json_list, require_json_mapping, ) +from ..indexing import remap_indices_after_removal from .cutoff.state import CSACutoffState from .stage import CSAStageState @@ -552,6 +553,52 @@ def without_updated_seed_mask(self, updated_indices: AbstractSet[int]) -> Self: ), ) + def remove_indices( + self, + *, + removed_indices: AbstractSet[int], + entry_count: int, + ) -> Self: + """Return a copy remapped after bank indices have been removed. + + Parameters + ---------- + removed_indices : collections.abc.Set[int] + Bank indices removed from the previous bank snapshot. + entry_count : int + Current bank size after removal. + + Returns + ------- + Self + Progression state whose stage and refresh masks are aligned with the + current bank indexing. + + Raises + ------ + ValueError + If ``entry_count`` is negative. + """ + if entry_count < 0: + msg = "entry_count must be non-negative" + raise ValueError(msg) + + if not removed_indices: + return self + + return replace( + self, + stage_state=self.stage_state.remove_indices( + removed_indices=removed_indices, + entry_count=entry_count, + ), + refresh_mask=remap_indices_after_removal( + self.refresh_mask, + removed_indices=removed_indices, + entry_count=entry_count, + ), + ) + def with_refresh_mask(self, refresh_mask: frozenset[int]) -> Self: """Return a copy with one replacement refresh newcomer mask. diff --git a/src/variopt/algorithms/population/csa/selection/state.py b/src/variopt/algorithms/population/csa/selection/state.py index e9136fc..58ddadd 100644 --- a/src/variopt/algorithms/population/csa/selection/state.py +++ b/src/variopt/algorithms/population/csa/selection/state.py @@ -7,6 +7,7 @@ from typing_extensions import Self from .....json_types import JSONDict, JSONValue +from ..indexing import remap_indices_after_removal EMPTY_IGNORED_INDICES: frozenset[int] = frozenset() @@ -280,24 +281,12 @@ def remove_indices( if not removed_indices: return self - removed_index_set = frozenset(index for index in removed_indices if index >= 0) - - def _remap_index(index: int) -> int | None: - if index in removed_index_set: - return None - - return index - sum( - 1 - for removed_index in removed_index_set - if removed_index < index - ) - - remapped_used_indices = frozenset( - remapped_index - for index in self.used_entry_indices - if (remapped_index := _remap_index(index)) is not None - and 0 <= remapped_index < entry_count + remapped_used_indices = remap_indices_after_removal( + self.used_entry_indices, + removed_indices=removed_indices, + entry_count=entry_count, ) + removed_index_set = frozenset(index for index in removed_indices if index >= 0) resized_status = [ is_used for index, is_used in enumerate(self.resize_bank_status(entry_count=entry_count + len(removed_index_set))) diff --git a/tests/csa/test_csa_banking.py b/tests/csa/test_csa_banking.py index 416a8e0..6967509 100644 --- a/tests/csa/test_csa_banking.py +++ b/tests/csa/test_csa_banking.py @@ -16,10 +16,12 @@ CSABankGrowthPolicy, CSABankUpdatePolicy, CSABiasedPotential, + CSACutoffSchedule, CSACutoffState, CSANicheQualityPolicy, CSAOptimizerTestCase, CSAScoreModel, + CSAStageState, DiversityMetric, IntegerSpace, NaNDistance, @@ -41,11 +43,19 @@ CSAClusteringPolicy, CSAClusteringState, ) +from variopt.algorithms.population.csa.banking.growth import CSABankGrowthState from variopt.algorithms.population.csa.banking.queries import ( BankDistanceWorkspace, best_mean_niche_scores, crowding_aware_scores, ) +from variopt.algorithms.population.csa.banking.update.logic import ( + apply_bank_update_batch, +) +from variopt.algorithms.population.csa.progression.state import CSAProgressionState +from variopt.algorithms.population.csa.scoring.acceptance_state import ( + CSAAcceptanceState, +) from variopt.algorithms.population.csa.scoring.model_state import CSAScoreModelState from variopt.algorithms.population.csa.selection.state import SeedSelectionState from variopt.json_types import JSONValue @@ -1176,6 +1186,20 @@ def test_growth_policy_reduces_oversized_bank_after_batch(self) -> None: minimum_distance_cutoff=2.0, cutoff_recover_limit=2.0, ) + optimizer.engine_state = replace( + optimizer.engine_state, + selection_state=SeedSelectionState( + used_entry_indices=frozenset({0, 2}), + bank_status=(True, False, True), + ), + progression_state=replace( + optimizer.engine_state.progression_state, + stage_state=optimizer.engine_state.progression_state.stage_state.with_masks( + seed_mask=frozenset({0, 2}), + partner_mask=frozenset({1, 2}), + ), + ).with_refresh_mask(frozenset({2})), + ) proposal = Proposal(candidate=30, proposal_id="p-1") optimizer.pending_by_id = {"p-1": proposal} @@ -1194,6 +1218,139 @@ def test_growth_policy_reduces_oversized_bank_after_batch(self) -> None: assert ( tuple(entry.value for entry in optimizer.bank.entries) == (0.0, 0.5) ) + assert optimizer.selection_state.used_entry_indices == frozenset({0}) + assert optimizer.selection_state.bank_status == (True, False) + assert optimizer.engine_state.progression_state.stage_state.seed_mask == frozenset( + {0}, + ) + assert optimizer.engine_state.progression_state.stage_state.partner_mask == frozenset( + {1}, + ) + assert optimizer.engine_state.progression_state.refresh_mask == frozenset() + + def test_update_indices_are_reported_before_energy_cut_removal(self) -> None: + growth_policy = CSABankGrowthPolicy( + enabled=True, + maximum_capacity=3, + initial_energy_gap_limit=1.0, + ) + + result = apply_bank_update_batch( + bank=Bank( + capacity=3, + entries=( + BankEntry(candidate=0, value=10.0, proposal_id="b-0"), + BankEntry(candidate=10, value=0.0, proposal_id="b-1"), + BankEntry(candidate=20, value=5.0, proposal_id="b-2"), + ), + ), + state=CSAProgressionState( + cutoff_state=CSACutoffState( + distance_cutoff=2.0, + minimum_distance_cutoff=2.0, + cutoff_recover_limit=2.0, + ), + stage_state=CSAStageState(base_capacity=2, max_capacity=3), + ), + observations=( + Observation( + proposal=Proposal(candidate=21, proposal_id="p-1"), + candidate=21, + value=-1.0, + score=-1.0, + ), + ), + diversity_metric=AbsoluteDistance(), + infer_average_distance=lambda entries: 2.0, + infer_score_gap=lambda entries: 11.0, + cutoff_schedule=CSACutoffSchedule( + initial_distance_cutoff=2.0, + minimum_distance_cutoff=2.0, + ), + update_policy=CSABankUpdatePolicy(), + acceptance_state=CSAAcceptanceState.from_policy(CSAAcceptancePolicy()), + score_model_state=CSAScoreModelState(score_model=CSAScoreModel()), + growth_state=CSABankGrowthState[int]( + policy=growth_policy, + active_energy_gap_limit=growth_policy.initial_energy_gap_limit, + ), + clustering_state=CSAClusteringState(policy=CSAClusteringPolicy(enabled=False)), + base_bank_capacity=2, + masked_seed_indices=frozenset(), + random_state=None, + ) + + assert result.removed_indices == frozenset({0}) + assert result.changed_indices == frozenset({2}) + assert result.significant_update_indices == frozenset({2}) + assert tuple(entry.candidate for entry in result.bank.entries) == (10, 21) + + def test_update_then_energy_cut_remaps_selection_and_progression_masks(self) -> None: + optimizer = make_optimizer( + space=IntegerSpace(low=0, high=100), + diversity_metric=AbsoluteDistance(), + variation_operator=RepeatParent(), + bank_capacity=2, + growth_policy=CSABankGrowthPolicy( + enabled=True, + maximum_capacity=3, + initial_energy_gap_limit=1.0, + ), + random_state=0, + ) + optimizer.bank = Bank( + capacity=3, + entries=( + BankEntry(candidate=0, value=10.0, proposal_id="b-0"), + BankEntry(candidate=10, value=0.0, proposal_id="b-1"), + BankEntry(candidate=20, value=5.0, proposal_id="b-2"), + ), + ) + optimizer.reference_bank = ReferenceBank( + capacity=3, + entries=optimizer.bank.entries, + ) + optimizer.cutoff_state = CSACutoffState( + distance_cutoff=2.0, + minimum_distance_cutoff=2.0, + cutoff_recover_limit=2.0, + ) + optimizer.engine_state = replace( + optimizer.engine_state, + selection_state=SeedSelectionState( + used_entry_indices=frozenset({1, 2}), + bank_status=(True, True, True), + ), + progression_state=replace( + optimizer.engine_state.progression_state, + stage_state=optimizer.engine_state.progression_state.stage_state.with_masks( + seed_mask=frozenset({2}), + partner_mask=frozenset({1, 2}), + ), + ).with_refresh_mask(frozenset({2})), + ) + proposal = Proposal(candidate=21, proposal_id="p-1") + optimizer.pending_by_id = {"p-1": proposal} + + optimizer.tell( + ( + Observation( + proposal=proposal, + candidate=21, + value=-1.0, + score=-1.0, + ), + ) + ) + + assert tuple(entry.candidate for entry in optimizer.bank.entries) == (10, 21) + assert optimizer.selection_state.used_entry_indices == frozenset({0}) + assert optimizer.selection_state.bank_status == (True, False) + assert optimizer.engine_state.progression_state.stage_state.seed_mask == frozenset() + assert optimizer.engine_state.progression_state.stage_state.partner_mask == frozenset( + {0, 1}, + ) + assert optimizer.engine_state.progression_state.refresh_mask == frozenset() def test_initial_fill_sorts_bank_and_reference_bank_by_score(self) -> None: problem = Problem( diff --git a/tests/csa/test_csa_clustering.py b/tests/csa/test_csa_clustering.py index b54282e..d28b7a8 100644 --- a/tests/csa/test_csa_clustering.py +++ b/tests/csa/test_csa_clustering.py @@ -345,6 +345,54 @@ def test_appended_close_candidate_inherits_nearest_cluster(self) -> None: assert next_runtime.cluster_labels == (1, 2, 1) + def test_close_replacement_inherits_nearest_cluster_when_target_differs(self) -> None: + runtime: CSAClusteringState[int] = CSAClusteringState( + policy=CSAClusteringPolicy(enabled=True), + cluster_distance=3.0, + cluster_labels=(1, 1, 2, 2, 2), + ) + + next_runtime = runtime.register_admission( + admitted_index=4, + nearest_index=1, + nearest_distance=1.0, + appended=False, + ) + + assert next_runtime.cluster_labels == (1, 1, 2, 2, 1) + + def test_cutoff_distance_replacement_inherits_nearest_cluster(self) -> None: + runtime: CSAClusteringState[int] = CSAClusteringState( + policy=CSAClusteringPolicy(enabled=True), + cluster_distance=3.0, + cluster_labels=(1, 1, 2), + ) + + next_runtime = runtime.register_admission( + admitted_index=2, + nearest_index=0, + nearest_distance=3.0, + appended=False, + ) + + assert next_runtime.cluster_labels == (1, 1, 1) + + def test_far_replacement_opens_new_cluster(self) -> None: + runtime: CSAClusteringState[int] = CSAClusteringState( + policy=CSAClusteringPolicy(enabled=True), + cluster_distance=3.0, + cluster_labels=(1, 1, 2), + ) + + next_runtime = runtime.register_admission( + admitted_index=1, + nearest_index=0, + nearest_distance=4.0, + appended=False, + ) + + assert next_runtime.cluster_labels == (1, 3, 2) + def test_largest_cluster_mode_separates_comparison_and_removal_targets(self) -> None: runtime: CSAClusteringState[int] = CSAClusteringState( policy=CSAClusteringPolicy( @@ -397,6 +445,7 @@ def test_cluster_update_largest_cluster_mode_replaces_largest_cluster_worst(self assert batch_result.bank.entries[4].candidate == 2 assert batch_result.bank.entries[4].value == 8.0 assert batch_result.bank.entries[1].candidate == 1 + assert batch_result.clustering_state.cluster_labels == (1, 1, 2, 2, 1) def test_cluster_update_current_cluster_mode_replaces_current_cluster_worst(self) -> None: batch_result = run_cluster_batch( diff --git a/tests/csa/test_csa_engine_state.py b/tests/csa/test_csa_engine_state.py index 82c1a61..aa04a9d 100644 --- a/tests/csa/test_csa_engine_state.py +++ b/tests/csa/test_csa_engine_state.py @@ -237,6 +237,25 @@ def test_without_updated_seed_mask_removes_refresh_mask_entries(self) -> None: assert next_state.refresh_mask == frozenset({0}) + def test_progression_masks_remap_after_bank_removal(self) -> None: + state = build_engine_state() + progression_state = replace( + state.progression_state, + stage_state=state.progression_state.stage_state.with_masks( + seed_mask=frozenset({0, 2, 4}), + partner_mask=frozenset({1, 3, 4}), + ), + ).with_refresh_mask(frozenset({2, 4})) + + next_state = progression_state.remove_indices( + removed_indices=frozenset({1, 4}), + entry_count=3, + ) + + assert next_state.stage_state.seed_mask == frozenset({0, 1}) + assert next_state.stage_state.partner_mask == frozenset({2}) + assert next_state.refresh_mask == frozenset({1}) + class CSAAskEngineTests: """Regression tests for extracted ask-side engine planning."""