Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def register_admission(
elif appended:
admitted_label = 1 if not labels else labels[nearest_index]
else:
admitted_label = labels[admitted_index]
admitted_label = labels[nearest_index]

if appended:
labels.append(admitted_label)
Expand Down
18 changes: 9 additions & 9 deletions src/variopt/algorithms/population/csa/banking/update/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ def apply_bank_update_batch(
significant_update_indices=batch_significant_update_indices,
)

final_changed_indices = changed_indices(
previous_bank=previous_bank,
next_bank=shadow_bank,
)
final_significant_update_indices = significant_update_indices(
previous_bank=previous_bank,
next_bank=shadow_bank,
minimum_significant_score_gap=update_policy.minimum_significant_score_gap,
)
removed_indices: frozenset[int] = frozenset()
(
shadow_bank,
Expand All @@ -200,15 +209,6 @@ def apply_bank_update_batch(
entries=shadow_bank.entries,
diversity_metric=diversity_metric,
)
final_changed_indices = changed_indices(
previous_bank=previous_bank,
next_bank=shadow_bank,
)
final_significant_update_indices = significant_update_indices(
previous_bank=previous_bank,
next_bank=shadow_bank,
minimum_significant_score_gap=update_policy.minimum_significant_score_gap,
)
return BankUpdateResult(
bank=shadow_bank,
state=shadow_state,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ class BankUpdateResult(FrozenGenericSlotsCompat, Generic[CandidateT]):
trace_state : CSAEventTraceState[CandidateT] | None
Optional updated trace reducer state.
changed_indices : frozenset[int]
Bank indices whose entries changed at all.
Bank indices whose entries changed before post-batch removals are applied.
significant_update_indices : frozenset[int]
Bank indices whose score change exceeded the significance floor.
Changed bank indices whose score gap exceeded the significance floor
before post-batch removals are applied.
removed_indices : frozenset[int]
Bank indices removed from the previous bank snapshot.
"""
Expand Down
49 changes: 25 additions & 24 deletions src/variopt/algorithms/population/csa/engine/tell.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,30 @@ def apply_tell(
)
updated_indices = batch_result.changed_indices
significant_update_indices = batch_result.significant_update_indices
entry_count_before_removal = (
len(batch_result.bank.entries) + len(batch_result.removed_indices)
)
progression_state = batch_result.state
if updated_indices:
engine_state = replace(
engine_state,
progression_state=engine_state.progression_state.without_updated_seed_mask(
updated_indices,
),
progression_state = progression_state.without_updated_seed_mask(
updated_indices,
)
if batch_result.removed_indices:
progression_state = progression_state.remove_indices(
removed_indices=batch_result.removed_indices,
entry_count=len(batch_result.bank.entries),
)

selection_state = engine_state.selection_state
if significant_update_indices:
selection_state = selection_state.invalidate_for_bank_update(
updated_indices=significant_update_indices,
entry_count=entry_count_before_removal,
)
if batch_result.removed_indices:
selection_state = selection_state.remove_indices(
removed_indices=batch_result.removed_indices,
entry_count=len(batch_result.bank.entries),
)

engine_state = replace(
Expand All @@ -204,7 +222,8 @@ def apply_tell(
growth_state=batch_result.growth_state,
clustering_state=batch_result.clustering_state,
),
progression_state=batch_result.state,
progression_state=progression_state,
selection_state=selection_state,
scoring_state=replace(
engine_state.scoring_state,
model_state=batch_result.score_model_state,
Expand All @@ -216,24 +235,6 @@ def apply_tell(
),
)

if batch_result.removed_indices:
engine_state = replace(
engine_state,
selection_state=engine_state.selection_state.remove_indices(
removed_indices=batch_result.removed_indices,
entry_count=len(engine_state.banking_state.bank.entries),
),
)

if significant_update_indices:
engine_state = replace(
engine_state,
selection_state=engine_state.selection_state.invalidate_for_bank_update(
updated_indices=significant_update_indices,
entry_count=len(engine_state.banking_state.bank.entries),
),
)

engine_state = sync_reference_bank_if_uninitialized(
engine_state,
diversity_metric=diversity_metric,
Expand Down
57 changes: 57 additions & 0 deletions src/variopt/algorithms/population/csa/indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Bank-index algebra shared by CSA state projections."""

from bisect import bisect_left
from collections.abc import Set as AbstractSet


def remap_indices_after_removal(
indices: AbstractSet[int],
*,
removed_indices: AbstractSet[int],
entry_count: int,
) -> frozenset[int]:
"""Return indices remapped after removing entries from a bank snapshot.

Parameters
----------
indices : collections.abc.Set[int]
Indices aligned to the previous bank snapshot.
removed_indices : collections.abc.Set[int]
Bank indices removed from the previous bank snapshot. Negative values
are ignored.
entry_count : int
Current bank size after removal.

Returns
-------
frozenset[int]
Indices aligned to the current bank snapshot. Removed indices and
indices outside the current bank size are dropped.

Raises
------
ValueError
If ``entry_count`` is negative.
"""
if entry_count < 0:
msg = "entry_count must be non-negative"
raise ValueError(msg)

if not indices:
return frozenset()

ordered_removed_indices = tuple(
sorted(index for index in removed_indices if index >= 0)
)
removed_index_set = frozenset(ordered_removed_indices)
remapped_indices: set[int] = set()
for index in indices:
if index < 0 or index in removed_index_set:
continue

removed_before_count = bisect_left(ordered_removed_indices, index)
remapped_index = index - removed_before_count
if remapped_index < entry_count:
remapped_indices.add(remapped_index)

return frozenset(remapped_indices)
46 changes: 46 additions & 0 deletions src/variopt/algorithms/population/csa/progression/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing_extensions import Self

from .....json_types import JSONDict, JSONValue, require_json_int, require_json_list
from ..indexing import remap_indices_after_removal


@dataclass(frozen=True, slots=True)
Expand Down Expand Up @@ -187,6 +188,51 @@ def without_updated_seed_mask(self, updated_indices: AbstractSet[int]) -> Self:
partner_mask=self.partner_mask,
)

def remove_indices(
self,
*,
removed_indices: AbstractSet[int],
entry_count: int,
) -> Self:
"""Return a copy remapped after bank indices have been removed.

Parameters
----------
removed_indices : collections.abc.Set[int]
Bank indices removed from the previous bank snapshot.
entry_count : int
Current bank size after removal.

Returns
-------
Self
Stage state whose masks are aligned with the current bank indexing.

Raises
------
ValueError
If ``entry_count`` is negative.
"""
if entry_count < 0:
msg = "entry_count must be non-negative"
raise ValueError(msg)

if not removed_indices:
return self

return self.with_masks(
seed_mask=remap_indices_after_removal(
self.seed_mask,
removed_indices=removed_indices,
entry_count=entry_count,
),
partner_mask=remap_indices_after_removal(
self.partner_mask,
removed_indices=removed_indices,
entry_count=entry_count,
),
)

def with_masks(
self,
*,
Expand Down
47 changes: 47 additions & 0 deletions src/variopt/algorithms/population/csa/progression/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
require_json_list,
require_json_mapping,
)
from ..indexing import remap_indices_after_removal
from .cutoff.state import CSACutoffState
from .stage import CSAStageState

Expand Down Expand Up @@ -552,6 +553,52 @@ def without_updated_seed_mask(self, updated_indices: AbstractSet[int]) -> Self:
),
)

def remove_indices(
self,
*,
removed_indices: AbstractSet[int],
entry_count: int,
) -> Self:
"""Return a copy remapped after bank indices have been removed.

Parameters
----------
removed_indices : collections.abc.Set[int]
Bank indices removed from the previous bank snapshot.
entry_count : int
Current bank size after removal.

Returns
-------
Self
Progression state whose stage and refresh masks are aligned with the
current bank indexing.

Raises
------
ValueError
If ``entry_count`` is negative.
"""
if entry_count < 0:
msg = "entry_count must be non-negative"
raise ValueError(msg)

if not removed_indices:
return self

return replace(
self,
stage_state=self.stage_state.remove_indices(
removed_indices=removed_indices,
entry_count=entry_count,
),
refresh_mask=remap_indices_after_removal(
self.refresh_mask,
removed_indices=removed_indices,
entry_count=entry_count,
),
)

def with_refresh_mask(self, refresh_mask: frozenset[int]) -> Self:
"""Return a copy with one replacement refresh newcomer mask.

Expand Down
23 changes: 6 additions & 17 deletions src/variopt/algorithms/population/csa/selection/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing_extensions import Self

from .....json_types import JSONDict, JSONValue
from ..indexing import remap_indices_after_removal

EMPTY_IGNORED_INDICES: frozenset[int] = frozenset()

Expand Down Expand Up @@ -280,24 +281,12 @@ def remove_indices(
if not removed_indices:
return self

removed_index_set = frozenset(index for index in removed_indices if index >= 0)

def _remap_index(index: int) -> int | None:
if index in removed_index_set:
return None

return index - sum(
1
for removed_index in removed_index_set
if removed_index < index
)

remapped_used_indices = frozenset(
remapped_index
for index in self.used_entry_indices
if (remapped_index := _remap_index(index)) is not None
and 0 <= remapped_index < entry_count
remapped_used_indices = remap_indices_after_removal(
self.used_entry_indices,
removed_indices=removed_indices,
entry_count=entry_count,
)
removed_index_set = frozenset(index for index in removed_indices if index >= 0)
resized_status = [
is_used
for index, is_used in enumerate(self.resize_bank_status(entry_count=entry_count + len(removed_index_set)))
Expand Down
Loading