Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/variopt/algorithms/population/csa/banking/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from variopt.generic_runtime import FrozenGenericSlotsCompat

from .....artifacts import Observation
from .....json_types import JSONDict, JSONValue
from .....json_types import JSONDict, JSONValue, require_json_bool
from .....typevars import CandidateT
from .bank import Bank, BankEntry

Expand All @@ -24,10 +24,13 @@ class ReferenceBank(FrozenGenericSlotsCompat, Generic[CandidateT]):
Maximum number of entries the reference bank may store.
entries : tuple[BankEntry[CandidateT], ...], optional
Entries currently stored in the reference bank.
initialized : bool, default=False
Whether the entries represent a completed reference snapshot.
"""

capacity: int
entries: tuple[BankEntry[CandidateT], ...] = ()
initialized: bool = False

def __post_init__(self) -> None:
"""Validate reference-bank state.
Expand All @@ -45,6 +48,13 @@ def __post_init__(self) -> None:
msg = "entries must not exceed capacity"
raise ValueError(msg)

if self.initialized and not self.is_full:
msg = "initialized reference banks must be full"
raise ValueError(msg)

if not self.initialized and self.is_full:
object.__setattr__(self, "initialized", True)

@property
def is_full(self) -> bool:
"""Return whether the reference bank has reached capacity.
Expand Down Expand Up @@ -79,6 +89,7 @@ def to_dict(
entry.to_dict(candidate_to_dict=candidate_to_dict)
for entry in self.entries
],
"initialized": self.initialized,
}

@classmethod
Expand Down Expand Up @@ -110,12 +121,21 @@ def from_dict(
"""
capacity = data.get("capacity")
raw_entries = data.get("entries")
raw_initialized = data.get("initialized")
if not isinstance(capacity, int):
msg = "reference bank snapshot requires integer capacity"
raise TypeError(msg)
if not isinstance(raw_entries, list):
msg = "reference bank snapshot requires entry list"
raise TypeError(msg)
initialized = (
None
if raw_initialized is None
else require_json_bool(
raw_initialized,
field_name="reference bank initialized",
)
)

entries: list[BankEntry[CandidateT]] = []
for raw_entry in raw_entries:
Expand All @@ -132,6 +152,11 @@ def from_dict(
return cls(
capacity=capacity,
entries=tuple(entries),
initialized=(
len(entries) >= capacity
if initialized is None
else initialized
),
)


Expand Down Expand Up @@ -287,6 +312,7 @@ def build_reference_bank_from_bank(
return ReferenceBank(
capacity=bank.capacity,
entries=sorted_entries,
initialized=True,
)


Expand Down Expand Up @@ -351,6 +377,7 @@ def build_reference_bank_from_refresh_pool(
return ReferenceBank(
capacity=capacity,
entries=preserved_entry_tuple + selected_entries,
initialized=True,
)


Expand Down
6 changes: 1 addition & 5 deletions src/variopt/algorithms/population/csa/engine/boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,7 @@ def sync_reference_bank_if_uninitialized(
if not engine_state.banking_state.bank.is_full:
return engine_state

if (
engine_state.banking_state.reference_bank.is_full
and len(engine_state.banking_state.reference_bank.entries)
== len(engine_state.banking_state.bank.entries)
):
if engine_state.banking_state.reference_bank.initialized:
return engine_state

bank = build_sorted_bank_from_bank(engine_state.banking_state.bank)
Expand Down
78 changes: 77 additions & 1 deletion tests/csa/test_csa_banking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for CSA banking, score-model, and admission semantics."""

from dataclasses import replace
from typing import Literal, cast

import pytest
Expand Down Expand Up @@ -36,12 +37,18 @@
make_optimizer,
significant_update_indices,
)
from variopt.algorithms.population.csa.banking.clustering import (
CSAClusteringPolicy,
CSAClusteringState,
)
from variopt.algorithms.population.csa.banking.queries import (
BankDistanceWorkspace,
best_mean_niche_scores,
crowding_aware_scores,
)
from variopt.algorithms.population.csa.scoring.model_state import CSAScoreModelState
from variopt.algorithms.population.csa.selection.state import SeedSelectionState
from variopt.json_types import JSONValue


class CountingDistance(DiversityMetric[int]):
Expand Down Expand Up @@ -730,6 +737,36 @@ def test_admit_rejects_negative_diversity_distance(self) -> None:
class CSABankingTests(CSAOptimizerTestCase):
"""White-box tests for CSA banking and score-model state transitions."""

def test_reference_bank_legacy_snapshot_derives_initialized_state(self) -> None:
def candidate_from_dict(value: JSONValue) -> int:
if type(value) is not int:
msg = "candidate must be an integer"
raise TypeError(msg)
return value

full_reference = ReferenceBank[int].from_dict(
{
"capacity": 2,
"entries": [
{"candidate": 0, "value": 0.0, "proposal_id": "b-0"},
{"candidate": 1, "value": 1.0, "proposal_id": "b-1"},
],
},
candidate_from_dict=candidate_from_dict,
)
partial_reference = ReferenceBank[int].from_dict(
{
"capacity": 2,
"entries": [
{"candidate": 0, "value": 0.0, "proposal_id": "b-0"},
],
},
candidate_from_dict=candidate_from_dict,
)

assert full_reference.initialized
assert not partial_reference.initialized

def test_significant_update_threshold_ignores_small_score_changes(self) -> None:
optimizer = make_optimizer(
space=IntegerSpace(low=0, high=100),
Expand Down Expand Up @@ -1052,6 +1089,29 @@ def test_growth_policy_can_append_far_candidate(self) -> None:
),
distance_cutoff=2.0,
)
optimizer.engine_state = replace(
optimizer.engine_state,
banking_state=replace(
optimizer.engine_state.banking_state,
clustering_state=CSAClusteringState[int](
policy=CSAClusteringPolicy(enabled=True),
cluster_distance=2.0,
cluster_labels=(1, 2),
),
),
selection_state=SeedSelectionState(
used_entry_indices=frozenset({1}),
bank_status=(False, True),
),
progression_state=replace(
optimizer.engine_state.progression_state,
stage_state=optimizer.engine_state.progression_state.stage_state.with_masks(
seed_mask=frozenset({1}),
partner_mask=frozenset({0}),
),
).with_refresh_mask(frozenset({1})),
)
reference_entries = optimizer.reference_bank.entries
proposal = Proposal(candidate=20, proposal_id="p-1")
optimizer.pending_by_id = {"p-1": proposal}

Expand All @@ -1068,8 +1128,23 @@ def test_growth_policy_can_append_far_candidate(self) -> None:

assert optimizer.bank.capacity == 3
assert (
tuple(entry.value for entry in optimizer.bank.entries) == (0.0, 5.0, 10.0)
tuple(entry.candidate for entry in optimizer.bank.entries) == (0, 10, 20)
)
assert optimizer.reference_bank.entries == reference_entries
assert (
tuple(
zip(
(entry.candidate for entry in optimizer.bank.entries),
optimizer.engine_state.banking_state.clustering_state.cluster_labels,
strict=True,
)
)
== ((0, 1), (10, 2), (20, 3))
)
assert optimizer.selection_state.used_entry_indices == frozenset({1})
assert optimizer.selection_state.bank_status == (False, True, False)
assert optimizer.engine_state.progression_state.seed_mask == frozenset({1})
assert optimizer.engine_state.progression_state.partner_mask == frozenset({0, 1})

def test_growth_policy_reduces_oversized_bank_after_batch(self) -> None:
optimizer = make_optimizer(
Expand Down Expand Up @@ -1149,3 +1224,4 @@ def test_initial_fill_sorts_bank_and_reference_bank_by_score(self) -> None:
tuple(entry.candidate for entry in optimizer.reference_bank.entries)
== (1, 5, 9)
)
assert optimizer.reference_bank.initialized