From 3258113f145c7326d14b6d45e7a5b2f7af3ef76c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Mar 2026 14:55:26 +0000 Subject: [PATCH 1/5] Initial plan From baeb6222255ffb31272ff073687396bda8b5abc7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Mar 2026 15:02:35 +0000 Subject: [PATCH 2/5] Plan: add comprehensive parameter group stress tests Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- test_graph | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 test_graph diff --git a/test_graph b/test_graph new file mode 100644 index 0000000..f58b6f4 --- /dev/null +++ b/test_graph @@ -0,0 +1,19 @@ +strict digraph { + node [color=black shape=folder style=solid] + 139902456415568 [label="NodeList(NodeList)[5]"] + node [color=black fillcolor=grey90 shape=box style=filled] + 139902630298928 [label=ptest1] + node [color=black fillcolor=grey90 shape=box style=filled] + 139902592574736 [label=ptest2] + node [color=black shape=ellipse style=solid] + 139902592574976 [label="Module(mtest1)"] + node [color=black shape=ellipse style=solid] + 139902592576896 [label="Module(mtest2)"] + node [color=black shape=ellipse style=solid] + 139902592579104 [label="Module(mtest3)"] + 139902456415568 -> 139902630298928 + 139902456415568 -> 139902592574736 + 139902456415568 -> 139902592574976 + 139902456415568 -> 139902592576896 + 139902456415568 -> 139902592579104 +} From 6beaaf4b332fb792d85cf4bb6b045122f13d8e0a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Mar 2026 15:16:25 +0000 Subject: [PATCH 3/5] Add comprehensive parameter group stress tests and fix from_valid bug in fill_params Add 65 tests covering parameter groups across: - Basic group creation and graph tracking - Hierarchical models (top-level and within hierarchy) - get_values/set_values with all schemes (array, list, dict) - Varying batch dimensions (independent per group) - find_param/find_index - to_valid/from_valid and ValidContext - @forward decorator and fill_params - ActiveContext - active_cache - NodeList/NodeTuple collections - Pointer params - Dynamic/static transitions - OverrideParam - Edge cases (many groups, non-contiguous, reassignment) Fix bug in fill_params where from_valid was called without group= parameter, causing IndexError when using ValidContext with multi-group parameters. Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- src/caskade/module.py | 2 +- test_graph | 19 - tests/test_param_groups.py | 833 +++++++++++++++++++++++++++++++++++++ 3 files changed, 834 insertions(+), 20 deletions(-) delete mode 100644 test_graph create mode 100644 tests/test_param_groups.py diff --git a/src/caskade/module.py b/src/caskade/module.py index 774faa8..43f33b7 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -170,7 +170,7 @@ def fill_params(self, params: Union[ArrayLike, Sequence, Mapping], dynamic=True) for group, params_g in zip(self.dynamic_param_groups, params): param_list_g = tuple(p for p in param_list if p.group == group) if self.valid_context: - params_g = self.from_valid(params_g, param_list_g) + params_g = self.from_valid(params_g, param_list_g, group=group) self._set_values(params_g, param_list_g, attribute="_value") def clear_state(self): diff --git a/test_graph b/test_graph deleted file mode 100644 index f58b6f4..0000000 --- a/test_graph +++ /dev/null @@ -1,19 +0,0 @@ -strict digraph { - node [color=black shape=folder style=solid] - 139902456415568 [label="NodeList(NodeList)[5]"] - node [color=black fillcolor=grey90 shape=box style=filled] - 139902630298928 [label=ptest1] - node [color=black fillcolor=grey90 shape=box style=filled] - 139902592574736 [label=ptest2] - node [color=black shape=ellipse style=solid] - 139902592574976 [label="Module(mtest1)"] - node [color=black shape=ellipse style=solid] - 139902592576896 [label="Module(mtest2)"] - node [color=black shape=ellipse style=solid] - 139902592579104 [label="Module(mtest3)"] - 139902456415568 -> 139902630298928 - 139902456415568 -> 139902592574736 - 139902456415568 -> 139902592574976 - 139902456415568 -> 139902592576896 - 139902456415568 -> 139902592579104 -} diff --git a/tests/test_param_groups.py b/tests/test_param_groups.py new file mode 100644 index 0000000..5daafff --- /dev/null +++ b/tests/test_param_groups.py @@ -0,0 +1,833 @@ +"""Stress tests for parameter group functionality. + +Tests parameter groups across many caskade features including hierarchical +models, batch dimensions, get/set values, finders, valid transforms, +forward decorator, active context, active cache, collections, pointer +params, dynamic/static transitions, and override params. +""" + +import numpy as np +import pytest + +from caskade import ( + Module, + Param, + forward, + active_cache, + ActiveContext, + ValidContext, + OverrideParam, + ActiveStateError, + ParamConfigurationError, + NodeList, + NodeTuple, + backend, +) + + +# ────────────────────────────────────────────────────────────────────── +# Helper modules used across multiple tests +# ────────────────────────────────────────────────────────────────────── +class Inner(Module): + def __init__(self, name=None): + super().__init__(name) + self.a = Param("a", 1.0, shape=()) + self.b = Param("b", [2.0, 3.0], shape=(2,)) + + @forward + def compute(self, x, a, b): + return backend.sum(x + a + b) + + +class Outer(Module): + def __init__(self, inner, name=None): + super().__init__(name) + self.inner = inner + self.c = Param("c", 4.0, shape=()) + self.d = Param("d", [5.0, 6.0], shape=(2,)) + + @forward + def run(self, x, c, d): + return backend.sum(x + c + d) + self.inner.compute(x) + + +def _make_grouped_outer(): + """Create an Outer model with all params dynamic and group split.""" + inner = Inner() + outer = Outer(inner, name="outer") + outer.to_dynamic(False) + outer.c.group = 1 + outer.d.group = 1 + return outer + + +# ────────────────────────────────────────────────────────────────────── +# 1. Basic group creation and graph tracking +# ────────────────────────────────────────────────────────────────────── +class TestGroupBasics: + + def test_default_group(self): + """All dynamic params default to group 0.""" + m = Outer(Inner()) + m.to_dynamic(False) + assert all(p.group == 0 for p in m.dynamic_params) + assert m.dynamic_param_groups == (0,) + + def test_assign_group_at_creation(self): + """Groups assigned at Param creation.""" + p = Param("p", 1.0, group=2) + assert p.group == 2 + + def test_assign_group_later(self): + """Groups can be reassigned after creation.""" + m = Outer(Inner()) + m.to_dynamic(False) + m.c.group = 1 + assert m.c.group == 1 + assert 1 in m.dynamic_param_groups + + def test_multiple_groups(self): + """Multiple distinct groups tracked correctly.""" + m = Outer(Inner()) + m.to_dynamic(False) + m.inner.a.group = 0 + m.inner.b.group = 1 + m.c.group = 2 + m.d.group = 3 + assert m.dynamic_param_groups == (0, 1, 2, 3) + + def test_group_must_be_int(self): + """Group must be an integer.""" + p = Param("p", 1.0) + with pytest.raises(AssertionError): + p.group = 1.5 + with pytest.raises(AssertionError): + p.group = "bad" + + def test_reassign_group_updates_graph(self): + """Changing group triggers update_graph on parents.""" + m = Outer(Inner()) + m.to_dynamic(False) + assert m.dynamic_param_groups == (0,) + m.c.group = 5 + assert 5 in m.dynamic_param_groups + m.c.group = 0 + assert m.dynamic_param_groups == (0,) + + def test_param_order_by_group(self): + """param_order returns params organized by group.""" + m = _make_grouped_outer() + order = m.param_order() + assert isinstance(order, str) + lines = order.strip().split("\n") + assert len(lines) == 2 + + +# ────────────────────────────────────────────────────────────────────── +# 2. Groups in hierarchical models +# ────────────────────────────────────────────────────────────────────── +class TestGroupsHierarchical: + + def test_groups_at_top_level(self): + """Groups assigned at top-level module propagate correctly.""" + outer = _make_grouped_outer() + assert outer.dynamic_param_groups == (0, 1) + group0_params = [p for p in outer.dynamic_params if p.group == 0] + group1_params = [p for p in outer.dynamic_params if p.group == 1] + assert len(group0_params) == 2 # inner.a, inner.b + assert len(group1_params) == 2 # outer.c, outer.d + + def test_groups_within_hierarchy(self): + """Groups assigned to params within sub-modules.""" + inner = Inner() + outer = Outer(inner, name="outer") + outer.to_dynamic(False) + inner.a.group = 1 + assert outer.dynamic_param_groups == (0, 1) + group1 = [p for p in outer.dynamic_params if p.group == 1] + assert len(group1) == 1 + assert group1[0] is inner.a + + def test_groups_mixed_hierarchy(self): + """Groups assigned across different levels of hierarchy.""" + inner = Inner() + outer = Outer(inner, name="outer") + outer.to_dynamic(False) + inner.a.group = 1 + outer.c.group = 1 + assert outer.dynamic_param_groups == (0, 1) + group0 = [p for p in outer.dynamic_params if p.group == 0] + group1 = [p for p in outer.dynamic_params if p.group == 1] + assert len(group0) == 2 # inner.b, outer.d + assert len(group1) == 2 # inner.a, outer.c + + def test_deep_hierarchy(self): + """Groups work with deeply nested modules.""" + + class Deep(Module): + def __init__(self, child=None, name=None): + super().__init__(name) + self.p = Param("p", 1.0, shape=()) + if child is not None: + self.child = child + + @forward + def go(self, p): + if hasattr(self, "child"): + return p + self.child.go() + return p + + level3 = Deep(name="level3") + level2 = Deep(level3, name="level2") + level1 = Deep(level2, name="level1") + level1.to_dynamic(False) + + level1.p.group = 0 + level2.p.group = 1 + level3.p.group = 2 + + assert level1.dynamic_param_groups == (0, 1, 2) + + p0 = level1.get_values() + assert len(p0) == 3 + result = level1.go(p0) + assert backend.module.allclose(result, backend.make_array(3.0)) + + +# ────────────────────────────────────────────────────────────────────── +# 3. Groups with get_values / set_values (all schemes) +# ────────────────────────────────────────────────────────────────────── +class TestGroupsGetSetValues: + + @pytest.fixture + def grouped_model(self): + return _make_grouped_outer() + + @pytest.mark.parametrize("scheme", ["array", "list", "dict"]) + def test_get_values_multi_group(self, grouped_model, scheme): + """get_values returns list of per-group values when multiple groups.""" + vals = grouped_model.get_values(scheme) + assert isinstance(vals, list) + assert len(vals) == 2 + + @pytest.mark.parametrize("scheme", ["array", "list", "dict"]) + def test_get_values_single_group(self, scheme): + """get_values returns single value when only one group.""" + m = Outer(Inner()) + m.to_dynamic(False) + vals = m.get_values(scheme) + assert not isinstance(vals, list) or scheme == "list" + + def test_get_values_specific_group(self, grouped_model): + """get_values with group= returns only that group's values.""" + v0 = grouped_model.get_values("array", group=0) + v1 = grouped_model.get_values("array", group=1) + assert isinstance(v0, backend.array_type) + assert isinstance(v1, backend.array_type) + # Group 0 has inner.a (1) + inner.b (2) = 3 elements + assert v0.shape[-1] == 3 + # Group 1 has outer.c (1) + outer.d (2) = 3 elements + assert v1.shape[-1] == 3 + + def test_set_values_multi_group_array(self, grouped_model): + """set_values with multi-group using array scheme.""" + vals = grouped_model.get_values("array") + grouped_model.set_values(vals) + vals2 = grouped_model.get_values("array") + for v, v2 in zip(vals, vals2): + assert backend.module.allclose(v, v2) + + def test_set_values_multi_group_list(self, grouped_model): + """set_values with multi-group using list scheme.""" + vals = grouped_model.get_values("list") + grouped_model.set_values(vals) + vals2 = grouped_model.get_values("list") + assert len(vals) == len(vals2) + + def test_set_values_multi_group_dict(self, grouped_model): + """set_values with multi-group using dict scheme.""" + vals = grouped_model.get_values("dict") + grouped_model.set_values(vals) + vals2 = grouped_model.get_values("dict") + assert len(vals) == len(vals2) + + def test_round_trip_values_multi_group(self, grouped_model): + """Set then get values should produce identical results.""" + original = grouped_model.get_values("array") + scaled = [v * 2 for v in original] + grouped_model.set_values(scaled) + retrieved = grouped_model.get_values("array") + for s, r in zip(scaled, retrieved): + assert backend.module.allclose(s, r) + grouped_model.set_values(original) + restored = grouped_model.get_values("array") + for o, r in zip(original, restored): + assert backend.module.allclose(o, r) + + +# ────────────────────────────────────────────────────────────────────── +# 4. Groups with varying batch dimensions +# ────────────────────────────────────────────────────────────────────── +class TestGroupsBatchDims: + + def test_batch_dims_independent_per_group(self): + """Batch dimensions of separate groups are independent.""" + outer = _make_grouped_outer() + + outer.inner.a = np.ones((3,)) + outer.inner.b = np.ones((3, 2)) + + vals = outer.get_values("array") + assert len(vals) == 2 + assert vals[0].shape == (3, 3) + assert vals[1].shape == (3,) + + def test_batch_dims_both_groups_different_sizes(self): + """Both groups batched but with different batch sizes.""" + outer = _make_grouped_outer() + + outer.inner.a = np.ones((4,)) + outer.inner.b = np.ones((4, 2)) + + outer.c = np.ones((7,)) + outer.d = np.ones((7, 2)) + + vals = outer.get_values("array") + assert len(vals) == 2 + assert vals[0].shape == (4, 3) + assert vals[1].shape == (7, 3) + + def test_batch_dims_multi_dim_batch(self): + """Multi-dimensional batch shapes per group.""" + outer = _make_grouped_outer() + + outer.inner.a = np.ones((2, 3)) + outer.inner.b = np.ones((2, 3, 2)) + + vals = outer.get_values("array") + assert len(vals) == 2 + assert vals[0].shape == (2, 3, 3) + assert vals[1].shape == (3,) + + +# ────────────────────────────────────────────────────────────────────── +# 5. Groups with find_param / find_index +# ────────────────────────────────────────────────────────────────────── +class TestGroupsFinders: + + @pytest.fixture + def grouped_model(self): + return _make_grouped_outer() + + def test_find_param_with_group(self, grouped_model): + """find_param with explicit group argument.""" + p0, idx0 = grouped_model.find_param(0, group=0) + assert p0 is grouped_model.inner.a + p1, idx1 = grouped_model.find_param(0, group=1) + assert p1 is grouped_model.c + + def test_find_index_multi_group(self, grouped_model): + """find_index returns (group, index) tuple when multiple groups.""" + idx = grouped_model.find_index(grouped_model.inner.a) + assert isinstance(idx, tuple) + assert idx[0] == 0 # group + assert isinstance(idx[1], (int, slice)) + + def test_find_index_different_groups(self, grouped_model): + """find_index returns correct group for each param.""" + idx_a = grouped_model.find_index(grouped_model.inner.a) + idx_c = grouped_model.find_index(grouped_model.c) + assert idx_a[0] == 0 + assert idx_c[0] == 1 + + def test_find_index_list_scheme(self, grouped_model): + """find_index with list scheme returns (group, list_idx).""" + idx = grouped_model.find_index(grouped_model.inner.a, scheme="list") + assert isinstance(idx, tuple) + assert idx[0] == 0 + + def test_find_param_out_of_range(self, grouped_model): + """find_param raises IndexError for out-of-range index.""" + with pytest.raises(IndexError): + grouped_model.find_param(100, group=0) + + def test_find_index_unknown_param(self, grouped_model): + """find_index raises ValueError for param not in model.""" + with pytest.raises(ValueError): + grouped_model.find_index(Param("unknown", 1.0)) + + +# ────────────────────────────────────────────────────────────────────── +# 6. Groups with to_valid / from_valid and ValidContext +# ────────────────────────────────────────────────────────────────────── +class TestGroupsValid: + + @pytest.fixture + def bounded_grouped_model(self): + """Model with valid bounds and multiple groups.""" + inner = Inner() + inner.a.valid = (0, 10) + inner.b.valid = (0, 10) + outer = Outer(inner, name="outer") + outer.to_dynamic(False) + outer.c.valid = (0, 10) + outer.d.valid = (0, 10) + outer.c.group = 1 + outer.d.group = 1 + return outer + + def test_to_valid_multi_group(self, bounded_grouped_model): + """to_valid returns list of valid params per group.""" + params = bounded_grouped_model.get_values() + valid = bounded_grouped_model.to_valid(params) + assert isinstance(valid, list) + assert len(valid) == 2 + + def test_from_valid_multi_group(self, bounded_grouped_model): + """from_valid round-trips multi-group params.""" + params = bounded_grouped_model.get_values() + valid = bounded_grouped_model.to_valid(params) + recovered = bounded_grouped_model.from_valid(valid) + for p, r in zip(params, recovered): + assert backend.module.allclose(p, r) + + @pytest.mark.parametrize("scheme", ["array", "list", "dict"]) + def test_valid_context_multi_group(self, bounded_grouped_model, scheme): + """ValidContext with multi-group get/set round-trips.""" + init_params = bounded_grouped_model.get_values() + with ValidContext(bounded_grouped_model): + params = bounded_grouped_model.get_values(scheme) + bounded_grouped_model.set_values(params) + final_params = bounded_grouped_model.get_values() + for ip, fp in zip(init_params, final_params): + assert backend.module.allclose(ip, fp) + + def test_to_valid_with_group_arg(self, bounded_grouped_model): + """to_valid with explicit group argument.""" + params = bounded_grouped_model.get_values() + v0 = bounded_grouped_model.to_valid(params[0], group=0) + v1 = bounded_grouped_model.to_valid(params[1], group=1) + r0 = bounded_grouped_model.from_valid(v0, group=0) + r1 = bounded_grouped_model.from_valid(v1, group=1) + assert backend.module.allclose(params[0], r0) + assert backend.module.allclose(params[1], r1) + + +# ────────────────────────────────────────────────────────────────────── +# 7. Groups with @forward decorator and fill_params +# ────────────────────────────────────────────────────────────────────── +class TestGroupsForward: + + def test_forward_with_multi_group_params(self): + """@forward method works with multi-group params passed as arg.""" + outer = _make_grouped_outer() + params = outer.get_values() + result = outer.run(10.0, params) + assert result is not None + + def test_forward_with_params_kwarg(self): + """@forward with params= kwarg and multi-group.""" + outer = _make_grouped_outer() + params = outer.get_values() + result = outer.run(10.0, params=params) + assert result is not None + + def test_forward_consistency(self): + """Results should be same whether single or multi-group (same values).""" + inner1 = Inner() + outer1 = Outer(inner1, name="outer1") + outer1.to_dynamic(False) + result_single = outer1.run(10.0, outer1.get_values()) + + outer2 = _make_grouped_outer() + result_multi = outer2.run(10.0, outer2.get_values()) + + assert backend.module.allclose(result_single, result_multi) + + def test_fill_params_multi_group(self): + """fill_params works correctly with multi-group params.""" + outer = _make_grouped_outer() + params = outer.get_values() + with ActiveContext(outer): + outer.fill_params(params) + assert outer.inner.a._value is not None + assert outer.c._value is not None + + +# ────────────────────────────────────────────────────────────────────── +# 8. Groups with ActiveContext +# ────────────────────────────────────────────────────────────────────── +class TestGroupsActiveContext: + + def test_active_context_multi_group(self): + """ActiveContext manages state correctly with multi-group.""" + outer = _make_grouped_outer() + params = outer.get_values() + with ActiveContext(outer): + outer.fill_params(params) + for p in outer.dynamic_params: + assert p._value is not None + for p in outer.dynamic_params: + assert p._value is None + + def test_set_values_blocked_in_active_context(self): + """set_values raises when module is active.""" + outer = _make_grouped_outer() + with ActiveContext(outer): + with pytest.raises(ActiveStateError): + outer.set_values(outer.get_values()) + + def test_nested_active_context_multi_group(self): + """Nested ActiveContext on same module raises error.""" + outer = _make_grouped_outer() + with ActiveContext(outer): + with pytest.raises(ActiveStateError): + with ActiveContext(outer): + pass + + +# ────────────────────────────────────────────────────────────────────── +# 9. Groups with active_cache +# ────────────────────────────────────────────────────────────────────── +class TestGroupsActiveCache: + + def test_active_cache_with_multi_group(self): + """active_cache works correctly in multi-group scenarios.""" + + class CachedModule(Module): + def __init__(self): + super().__init__() + self.a = Param("a", 1.0, shape=()) + self.b = Param("b", 2.0, shape=()) + self.counter = 0 + + @active_cache + def cached_compute(self, x): + self.counter += 1 + return x * 2 + + @forward + def run(self, x, a, b): + c1 = self.cached_compute(x) + c2 = self.cached_compute(x) # Should use cache + return a + b + c1 + c2 + + m = CachedModule() + m.to_dynamic(False) + m.b.group = 1 + params = m.get_values() + m.counter = 0 + result = m.run(3.0, params) + assert m.counter == 1 # cached_compute called only once + + +# ────────────────────────────────────────────────────────────────────── +# 10. Groups with collections (NodeList / NodeTuple) +# ────────────────────────────────────────────────────────────────────── +class TestGroupsCollections: + + def test_groups_with_node_list(self): + """Parameter groups work with NodeList containers.""" + + class Listed(Module): + def __init__(self, workers, name=None): + super().__init__(name) + self.workers = workers + self.p = Param("p", 1.0, shape=()) + + @forward + def run(self, x, p): + return p + sum(w.compute(x) for w in self.workers) + + w1 = Inner(name="w1") + w2 = Inner(name="w2") + listed = Listed([w1, w2], name="listed") + listed.to_dynamic(False) + w1.a.group = 1 + assert listed.dynamic_param_groups == (0, 1) + params = listed.get_values() + assert isinstance(params, list) and len(params) == 2 + result = listed.run(1.0, params) + assert result is not None + + def test_groups_with_node_tuple(self): + """Parameter groups work with NodeTuple containers.""" + m = Module("m") + p1 = Param("p1", 1.0, dynamic=True, group=0) + p2 = Param("p2", 2.0, dynamic=True, group=1) + m.tup = NodeTuple((p1, p2), name="tup") + assert m.dynamic_param_groups == (0, 1) + + def test_collection_dynamic_param_groups(self): + """NodeList/NodeTuple expose dynamic_param_groups.""" + w1 = Inner(name="w1") + w2 = Inner(name="w2") + w1.to_dynamic(False) + w2.to_dynamic(False) + w1.a.group = 1 + nl = NodeList([w1, w2], name="nl") + assert 1 in nl.dynamic_param_groups + + +# ────────────────────────────────────────────────────────────────────── +# 11. Groups with pointer params +# ────────────────────────────────────────────────────────────────────── +class TestGroupsPointerParams: + + def test_pointer_param_group(self): + """Pointer and dynamic params tracked correctly with groups.""" + + class PtrModule(Module): + def __init__(self, name=None): + super().__init__(name) + self.a = Param("a", 1.0, shape=()) + self.b = Param("b", 2.0, shape=()) + + @forward + def run(self, a, b): + return a + b + + m1 = PtrModule("m1") + m2 = PtrModule("m2") + + class Top(Module): + def __init__(self, m1, m2, name=None): + super().__init__(name) + self.m1 = m1 + self.m2 = m2 + + @forward + def run(self, x): + return self.m1.run() + self.m2.run() + x + + m2.a = m1.a + top = Top(m1, m2, name="top") + top.to_dynamic(False) + m1.b.group = 1 + + assert m1.a in top.pointer_params or m1.a in top.dynamic_params + params = top.get_values() + result = top.run(0.0, params) + assert result is not None + + +# ────────────────────────────────────────────────────────────────────── +# 12. Groups with dynamic/static transitions +# ────────────────────────────────────────────────────────────────────── +class TestGroupsDynamicStatic: + + def test_static_params_not_in_groups(self): + """Static params are excluded from dynamic_param_groups.""" + outer = _make_grouped_outer() + assert outer.dynamic_param_groups == (0, 1) + outer.c.to_static() + outer.d.to_static() + assert outer.dynamic_param_groups == (0,) + + def test_to_dynamic_restores_groups(self): + """Converting back to dynamic preserves group membership.""" + outer = _make_grouped_outer() + outer.c.to_static() + assert 1 in outer.dynamic_param_groups # d still in group 1 + outer.c.to_dynamic() + assert outer.c.group == 1 + + def test_to_static_all_then_dynamic(self): + """to_static(False) then to_dynamic restores groups.""" + inner = Inner() + outer = Outer(inner, name="outer") + outer.to_dynamic(False) + inner.a.group = 1 + outer.c.group = 2 + + outer.to_static(False) + assert len(outer.dynamic_params) == 0 + assert len(outer.dynamic_param_groups) == 0 + + outer.to_dynamic(False) + assert inner.a.group == 1 + assert outer.c.group == 2 + assert outer.dynamic_param_groups == (0, 1, 2) + + def test_get_values_after_all_static(self): + """get_values returns empty when all params static.""" + outer = _make_grouped_outer() + outer.to_static(False) + vals = outer.get_values() + assert len(vals) == 0 + + +# ────────────────────────────────────────────────────────────────────── +# 13. Groups with OverrideParam +# ────────────────────────────────────────────────────────────────────── +class TestGroupsOverrideParam: + + def test_override_in_multi_group(self): + """OverrideParam works correctly when param is in a group.""" + + class TestSim(Module): + def __init__(self): + super().__init__() + self.a = Param("a", 3.0) + self.b = Param("b", 5.0) + + @forward + def run(self, a, b): + return a + b + + m = TestSim() + m.to_dynamic(False) + m.b.group = 1 + + params = m.get_values() + with ActiveContext(m): + m.fill_params(params) + assert m.a._value is not None + with OverrideParam(m.a, backend.make_array(10.0)): + assert m.a._value.item() == 10.0 + assert m.a._value.item() == 3.0 + + +# ────────────────────────────────────────────────────────────────────── +# 14. Edge cases +# ────────────────────────────────────────────────────────────────────── +class TestGroupsEdgeCases: + + def test_many_groups(self): + """Large number of groups works correctly.""" + + class ManyParams(Module): + def __init__(self, n): + super().__init__() + for i in range(n): + setattr( + self, f"p{i}", Param(f"p{i}", float(i), shape=(), dynamic=True, group=i) + ) + + m = ManyParams(10) + assert m.dynamic_param_groups == tuple(range(10)) + params = m.get_values() + assert len(params) == 10 + + def test_non_contiguous_groups(self): + """Groups with gaps (e.g. 0, 5, 10) work correctly.""" + m = Module("m") + m.a = Param("a", 1.0, dynamic=True, group=0) + m.b = Param("b", 2.0, dynamic=True, group=5) + m.c = Param("c", 3.0, dynamic=True, group=10) + assert m.dynamic_param_groups == (0, 5, 10) + params = m.get_values() + assert len(params) == 3 + + def test_reassign_all_to_same_group(self): + """Reassigning all params to same group collapses to single-group.""" + outer = _make_grouped_outer() + assert len(outer.dynamic_param_groups) == 2 + outer.c.group = 0 + outer.d.group = 0 + assert outer.dynamic_param_groups == (0,) + vals = outer.get_values("array") + assert isinstance(vals, backend.array_type) + + def test_empty_group_after_static(self): + """Group that becomes empty after to_static is removed from groups.""" + outer = _make_grouped_outer() + assert 1 in outer.dynamic_param_groups + outer.c.to_static() + outer.d.to_static() + assert 1 not in outer.dynamic_param_groups + + def test_single_param_per_group(self): + """Each group with a single param.""" + m = Module("m") + m.a = Param("a", 1.0, dynamic=True, group=0) + m.b = Param("b", 2.0, dynamic=True, group=1) + m.c = Param("c", 3.0, dynamic=True, group=2) + params = m.get_values("array") + assert len(params) == 3 + for p in params: + assert p.shape == (1,) + + def test_forward_no_params_multi_group(self): + """Forward with no dynamic params after all static, multi-group.""" + outer = _make_grouped_outer() + outer.to_static(False) + result = outer.run(10.0) + assert result is not None + + def test_group_with_none_value_param_raises(self): + """Param with None value in a group raises on get_values.""" + m = Module("m") + m.a = Param("a", 1.0, dynamic=True, group=0) + m.b = Param("b", None, dynamic=True, group=1) + with pytest.raises(ParamConfigurationError): + m.get_values() + + +# ────────────────────────────────────────────────────────────────────── +# 15. Groups with ValidContext in @forward +# ────────────────────────────────────────────────────────────────────── +class TestGroupsValidForward: + + def test_valid_context_forward_multi_group(self): + """ValidContext with @forward and multi-group params.""" + + class BoundedSim(Module): + def __init__(self): + super().__init__() + self.a = Param("a", 1.0, shape=(), valid=(0, 10)) + self.b = Param("b", 2.0, shape=(), valid=(0, 10)) + + @forward + def run(self, x, a, b): + return x + a + b + + m = BoundedSim() + m.to_dynamic(False) + m.b.group = 1 + + init_params = m.get_values() + valid = m.to_valid(init_params) + recovered = m.from_valid(valid) + for ip, rp in zip(init_params, recovered): + assert backend.module.allclose(ip, rp) + + with ValidContext(m): + valid_params = m.get_values() + result = m.run(5.0, valid_params) + assert result is not None + + def test_valid_context_preserves_state(self): + """ValidContext doesn't leak state after exit.""" + outer = _make_grouped_outer() + outer.inner.a.valid = (0, 10) + outer.inner.b.valid = (0, 10) + outer.c.valid = (0, 10) + outer.d.valid = (0, 10) + + assert not outer.valid_context + with ValidContext(outer): + assert outer.valid_context + assert not outer.valid_context + + +# ────────────────────────────────────────────────────────────────────── +# 16. Groups with string representation +# ────────────────────────────────────────────────────────────────────── +class TestGroupsStr: + + def test_str_multi_group(self): + """String representation includes all params across groups.""" + outer = _make_grouped_outer() + result = str(outer) + for node in outer.topological_ordering(): + assert node.name in result + + def test_param_order_str(self): + """param_order produces readable multi-group output.""" + outer = _make_grouped_outer() + order = outer.param_order() + assert "c" in order + assert "d" in order + assert "a" in order + assert "b" in order From 8a495edae689135e5230a8cf2c97f6bbd02a63bb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Mar 2026 15:18:43 +0000 Subject: [PATCH 4/5] Address review feedback: add match assertion to error test Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- tests/test_param_groups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_param_groups.py b/tests/test_param_groups.py index 5daafff..eb8c04e 100644 --- a/tests/test_param_groups.py +++ b/tests/test_param_groups.py @@ -760,7 +760,7 @@ def test_group_with_none_value_param_raises(self): m = Module("m") m.a = Param("a", 1.0, dynamic=True, group=0) m.b = Param("b", None, dynamic=True, group=1) - with pytest.raises(ParamConfigurationError): + with pytest.raises(ParamConfigurationError, match="b"): m.get_values() From d35d518bf4137b92222253e6d5c43482b76e3c6b Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 15:32:53 -0400 Subject: [PATCH 5/5] rewrite tests in standard format --- tests/test_param_groups.py | 1309 ++++++++++++++++++------------------ 1 file changed, 669 insertions(+), 640 deletions(-) diff --git a/tests/test_param_groups.py b/tests/test_param_groups.py index eb8c04e..54f15bc 100644 --- a/tests/test_param_groups.py +++ b/tests/test_param_groups.py @@ -51,7 +51,8 @@ def run(self, x, c, d): return backend.sum(x + c + d) + self.inner.compute(x) -def _make_grouped_outer(): +@pytest.fixture +def grouped_model(): """Create an Outer model with all params dynamic and group split.""" inner = Inner() outer = Outer(inner, name="outer") @@ -64,770 +65,798 @@ def _make_grouped_outer(): # ────────────────────────────────────────────────────────────────────── # 1. Basic group creation and graph tracking # ────────────────────────────────────────────────────────────────────── -class TestGroupBasics: - - def test_default_group(self): - """All dynamic params default to group 0.""" - m = Outer(Inner()) - m.to_dynamic(False) - assert all(p.group == 0 for p in m.dynamic_params) - assert m.dynamic_param_groups == (0,) - - def test_assign_group_at_creation(self): - """Groups assigned at Param creation.""" - p = Param("p", 1.0, group=2) - assert p.group == 2 - - def test_assign_group_later(self): - """Groups can be reassigned after creation.""" - m = Outer(Inner()) - m.to_dynamic(False) - m.c.group = 1 - assert m.c.group == 1 - assert 1 in m.dynamic_param_groups - - def test_multiple_groups(self): - """Multiple distinct groups tracked correctly.""" - m = Outer(Inner()) - m.to_dynamic(False) - m.inner.a.group = 0 - m.inner.b.group = 1 - m.c.group = 2 - m.d.group = 3 - assert m.dynamic_param_groups == (0, 1, 2, 3) - - def test_group_must_be_int(self): - """Group must be an integer.""" - p = Param("p", 1.0) - with pytest.raises(AssertionError): - p.group = 1.5 - with pytest.raises(AssertionError): - p.group = "bad" - - def test_reassign_group_updates_graph(self): - """Changing group triggers update_graph on parents.""" - m = Outer(Inner()) - m.to_dynamic(False) - assert m.dynamic_param_groups == (0,) - m.c.group = 5 - assert 5 in m.dynamic_param_groups - m.c.group = 0 - assert m.dynamic_param_groups == (0,) - - def test_param_order_by_group(self): - """param_order returns params organized by group.""" - m = _make_grouped_outer() - order = m.param_order() - assert isinstance(order, str) - lines = order.strip().split("\n") - assert len(lines) == 2 +def test_default_group(): + """All dynamic params default to group 0.""" + m = Outer(Inner()) + m.to_dynamic(False) + assert all(p.group == 0 for p in m.dynamic_params) + assert m.dynamic_param_groups == (0,) + + +def test_assign_group_at_creation(): + """Groups assigned at Param creation.""" + p = Param("p", 1.0, group=2) + assert p.group == 2 + + +def test_assign_group_later(): + """Groups can be reassigned after creation.""" + m = Outer(Inner()) + m.to_dynamic(False) + m.c.group = 1 + assert m.c.group == 1 + assert 1 in m.dynamic_param_groups + + +def test_multiple_groups(): + """Multiple distinct groups tracked correctly.""" + m = Outer(Inner()) + m.to_dynamic(False) + m.inner.a.group = 0 + m.inner.b.group = 1 + m.c.group = 2 + m.d.group = 3 + assert m.dynamic_param_groups == (0, 1, 2, 3) + + +def test_group_must_be_int(): + """Group must be an integer.""" + p = Param("p", 1.0) + with pytest.raises(AssertionError): + p.group = 1.5 + with pytest.raises(AssertionError): + p.group = "bad" + + +def test_reassign_group_updates_graph(): + """Changing group triggers update_graph on parents.""" + m = Outer(Inner()) + m.to_dynamic(False) + assert m.dynamic_param_groups == (0,) + m.c.group = 5 + assert 5 in m.dynamic_param_groups + m.c.group = 0 + assert m.dynamic_param_groups == (0,) + + +def test_param_order_by_group(grouped_model): + """param_order returns params organized by group.""" + m = grouped_model + order = m.param_order() + assert isinstance(order, str) + lines = order.strip().split("\n") + assert len(lines) == 2 # ────────────────────────────────────────────────────────────────────── # 2. Groups in hierarchical models # ────────────────────────────────────────────────────────────────────── -class TestGroupsHierarchical: - - def test_groups_at_top_level(self): - """Groups assigned at top-level module propagate correctly.""" - outer = _make_grouped_outer() - assert outer.dynamic_param_groups == (0, 1) - group0_params = [p for p in outer.dynamic_params if p.group == 0] - group1_params = [p for p in outer.dynamic_params if p.group == 1] - assert len(group0_params) == 2 # inner.a, inner.b - assert len(group1_params) == 2 # outer.c, outer.d - - def test_groups_within_hierarchy(self): - """Groups assigned to params within sub-modules.""" - inner = Inner() - outer = Outer(inner, name="outer") - outer.to_dynamic(False) - inner.a.group = 1 - assert outer.dynamic_param_groups == (0, 1) - group1 = [p for p in outer.dynamic_params if p.group == 1] - assert len(group1) == 1 - assert group1[0] is inner.a - - def test_groups_mixed_hierarchy(self): - """Groups assigned across different levels of hierarchy.""" - inner = Inner() - outer = Outer(inner, name="outer") - outer.to_dynamic(False) - inner.a.group = 1 - outer.c.group = 1 - assert outer.dynamic_param_groups == (0, 1) - group0 = [p for p in outer.dynamic_params if p.group == 0] - group1 = [p for p in outer.dynamic_params if p.group == 1] - assert len(group0) == 2 # inner.b, outer.d - assert len(group1) == 2 # inner.a, outer.c - - def test_deep_hierarchy(self): - """Groups work with deeply nested modules.""" - - class Deep(Module): - def __init__(self, child=None, name=None): - super().__init__(name) - self.p = Param("p", 1.0, shape=()) - if child is not None: - self.child = child - - @forward - def go(self, p): - if hasattr(self, "child"): - return p + self.child.go() - return p - - level3 = Deep(name="level3") - level2 = Deep(level3, name="level2") - level1 = Deep(level2, name="level1") - level1.to_dynamic(False) - - level1.p.group = 0 - level2.p.group = 1 - level3.p.group = 2 - - assert level1.dynamic_param_groups == (0, 1, 2) - - p0 = level1.get_values() - assert len(p0) == 3 - result = level1.go(p0) - assert backend.module.allclose(result, backend.make_array(3.0)) + + +def test_groups_at_top_level(grouped_model): + """Groups assigned at top-level module propagate correctly.""" + outer = grouped_model + assert outer.dynamic_param_groups == (0, 1) + group0_params = [p for p in outer.dynamic_params if p.group == 0] + group1_params = [p for p in outer.dynamic_params if p.group == 1] + assert len(group0_params) == 2 # inner.a, inner.b + assert len(group1_params) == 2 # outer.c, outer.d + + +def test_groups_within_hierarchy(): + """Groups assigned to params within sub-modules.""" + inner = Inner() + outer = Outer(inner, name="outer") + outer.to_dynamic(False) + inner.a.group = 1 + assert outer.dynamic_param_groups == (0, 1) + group1 = [p for p in outer.dynamic_params if p.group == 1] + assert len(group1) == 1 + assert group1[0] is inner.a + + +def test_groups_mixed_hierarchy(): + """Groups assigned across different levels of hierarchy.""" + inner = Inner() + outer = Outer(inner, name="outer") + outer.to_dynamic(False) + inner.a.group = 1 + outer.c.group = 1 + assert outer.dynamic_param_groups == (0, 1) + group0 = [p for p in outer.dynamic_params if p.group == 0] + group1 = [p for p in outer.dynamic_params if p.group == 1] + assert len(group0) == 2 # inner.b, outer.d + assert len(group1) == 2 # inner.a, outer.c + + +def test_deep_hierarchy(): + """Groups work with deeply nested modules.""" + + class Deep(Module): + def __init__(self, child=None, name=None): + super().__init__(name) + self.p = Param("p", 1.0, shape=()) + if child is not None: + self.child = child + + @forward + def go(self, p): + if hasattr(self, "child"): + return p + self.child.go() + return p + + level3 = Deep(name="level3") + level2 = Deep(level3, name="level2") + level1 = Deep(level2, name="level1") + level1.to_dynamic(False) + + level1.p.group = 0 + level2.p.group = 1 + level3.p.group = 2 + + assert level1.dynamic_param_groups == (0, 1, 2) + + p0 = level1.get_values() + assert len(p0) == 3 + result = level1.go(p0) + assert backend.module.allclose(result, backend.make_array(3.0)) # ────────────────────────────────────────────────────────────────────── # 3. Groups with get_values / set_values (all schemes) # ────────────────────────────────────────────────────────────────────── -class TestGroupsGetSetValues: - - @pytest.fixture - def grouped_model(self): - return _make_grouped_outer() - - @pytest.mark.parametrize("scheme", ["array", "list", "dict"]) - def test_get_values_multi_group(self, grouped_model, scheme): - """get_values returns list of per-group values when multiple groups.""" - vals = grouped_model.get_values(scheme) - assert isinstance(vals, list) - assert len(vals) == 2 - - @pytest.mark.parametrize("scheme", ["array", "list", "dict"]) - def test_get_values_single_group(self, scheme): - """get_values returns single value when only one group.""" - m = Outer(Inner()) - m.to_dynamic(False) - vals = m.get_values(scheme) - assert not isinstance(vals, list) or scheme == "list" - - def test_get_values_specific_group(self, grouped_model): - """get_values with group= returns only that group's values.""" - v0 = grouped_model.get_values("array", group=0) - v1 = grouped_model.get_values("array", group=1) - assert isinstance(v0, backend.array_type) - assert isinstance(v1, backend.array_type) - # Group 0 has inner.a (1) + inner.b (2) = 3 elements - assert v0.shape[-1] == 3 - # Group 1 has outer.c (1) + outer.d (2) = 3 elements - assert v1.shape[-1] == 3 - - def test_set_values_multi_group_array(self, grouped_model): - """set_values with multi-group using array scheme.""" - vals = grouped_model.get_values("array") - grouped_model.set_values(vals) - vals2 = grouped_model.get_values("array") - for v, v2 in zip(vals, vals2): - assert backend.module.allclose(v, v2) - - def test_set_values_multi_group_list(self, grouped_model): - """set_values with multi-group using list scheme.""" - vals = grouped_model.get_values("list") - grouped_model.set_values(vals) - vals2 = grouped_model.get_values("list") - assert len(vals) == len(vals2) - - def test_set_values_multi_group_dict(self, grouped_model): - """set_values with multi-group using dict scheme.""" - vals = grouped_model.get_values("dict") - grouped_model.set_values(vals) - vals2 = grouped_model.get_values("dict") - assert len(vals) == len(vals2) - - def test_round_trip_values_multi_group(self, grouped_model): - """Set then get values should produce identical results.""" - original = grouped_model.get_values("array") - scaled = [v * 2 for v in original] - grouped_model.set_values(scaled) - retrieved = grouped_model.get_values("array") - for s, r in zip(scaled, retrieved): - assert backend.module.allclose(s, r) - grouped_model.set_values(original) - restored = grouped_model.get_values("array") - for o, r in zip(original, restored): - assert backend.module.allclose(o, r) + + +@pytest.mark.parametrize("scheme", ["array", "list", "dict"]) +def test_get_values_multi_group(grouped_model, scheme): + """get_values returns list of per-group values when multiple groups.""" + vals = grouped_model.get_values(scheme) + assert isinstance(vals, list) + assert len(vals) == 2 + + +@pytest.mark.parametrize("scheme", ["array", "list", "dict"]) +def test_get_values_single_group(scheme): + """get_values returns single value when only one group.""" + m = Outer(Inner()) + m.to_dynamic(False) + vals = m.get_values(scheme) + assert not isinstance(vals, list) or scheme == "list" + + +def test_get_values_specific_group(grouped_model): + """get_values with group= returns only that group's values.""" + v0 = grouped_model.get_values("array", group=0) + v1 = grouped_model.get_values("array", group=1) + assert isinstance(v0, backend.array_type) + assert isinstance(v1, backend.array_type) + # Group 0 has inner.a (1) + inner.b (2) = 3 elements + assert v0.shape[-1] == 3 + # Group 1 has outer.c (1) + outer.d (2) = 3 elements + assert v1.shape[-1] == 3 + + +def test_set_values_multi_group_array(grouped_model): + """set_values with multi-group using array scheme.""" + vals = grouped_model.get_values("array") + grouped_model.set_values(vals) + vals2 = grouped_model.get_values("array") + for v, v2 in zip(vals, vals2): + assert backend.module.allclose(v, v2) + + +def test_set_values_multi_group_list(grouped_model): + """set_values with multi-group using list scheme.""" + vals = grouped_model.get_values("list") + grouped_model.set_values(vals) + vals2 = grouped_model.get_values("list") + assert len(vals) == len(vals2) + + +def test_set_values_multi_group_dict(grouped_model): + """set_values with multi-group using dict scheme.""" + vals = grouped_model.get_values("dict") + grouped_model.set_values(vals) + vals2 = grouped_model.get_values("dict") + assert len(vals) == len(vals2) + + +def test_round_trip_values_multi_group(grouped_model): + """Set then get values should produce identical results.""" + original = grouped_model.get_values("array") + scaled = [v * 2 for v in original] + grouped_model.set_values(scaled) + retrieved = grouped_model.get_values("array") + for s, r in zip(scaled, retrieved): + assert backend.module.allclose(s, r) + grouped_model.set_values(original) + restored = grouped_model.get_values("array") + for o, r in zip(original, restored): + assert backend.module.allclose(o, r) # ────────────────────────────────────────────────────────────────────── # 4. Groups with varying batch dimensions # ────────────────────────────────────────────────────────────────────── -class TestGroupsBatchDims: - def test_batch_dims_independent_per_group(self): - """Batch dimensions of separate groups are independent.""" - outer = _make_grouped_outer() - outer.inner.a = np.ones((3,)) - outer.inner.b = np.ones((3, 2)) +def test_batch_dims_independent_per_group(grouped_model): + """Batch dimensions of separate groups are independent.""" + outer = grouped_model + + outer.inner.a = np.ones((3,)) + outer.inner.b = np.ones((3, 2)) + + vals = outer.get_values("array") + assert len(vals) == 2 + assert vals[0].shape == (3, 3) + assert vals[1].shape == (3,) + - vals = outer.get_values("array") - assert len(vals) == 2 - assert vals[0].shape == (3, 3) - assert vals[1].shape == (3,) +def test_batch_dims_both_groups_different_sizes(grouped_model): + """Both groups batched but with different batch sizes.""" + outer = grouped_model - def test_batch_dims_both_groups_different_sizes(self): - """Both groups batched but with different batch sizes.""" - outer = _make_grouped_outer() + outer.inner.a = np.ones((4,)) + outer.inner.b = np.ones((4, 2)) - outer.inner.a = np.ones((4,)) - outer.inner.b = np.ones((4, 2)) + outer.c = np.ones((7,)) + outer.d = np.ones((7, 2)) - outer.c = np.ones((7,)) - outer.d = np.ones((7, 2)) + vals = outer.get_values("array") + assert len(vals) == 2 + assert vals[0].shape == (4, 3) + assert vals[1].shape == (7, 3) - vals = outer.get_values("array") - assert len(vals) == 2 - assert vals[0].shape == (4, 3) - assert vals[1].shape == (7, 3) - def test_batch_dims_multi_dim_batch(self): - """Multi-dimensional batch shapes per group.""" - outer = _make_grouped_outer() +def test_batch_dims_multi_dim_batch(grouped_model): + """Multi-dimensional batch shapes per group.""" + outer = grouped_model - outer.inner.a = np.ones((2, 3)) - outer.inner.b = np.ones((2, 3, 2)) + outer.inner.a = np.ones((2, 3)) + outer.inner.b = np.ones((2, 3, 2)) - vals = outer.get_values("array") - assert len(vals) == 2 - assert vals[0].shape == (2, 3, 3) - assert vals[1].shape == (3,) + vals = outer.get_values("array") + assert len(vals) == 2 + assert vals[0].shape == (2, 3, 3) + assert vals[1].shape == (3,) # ────────────────────────────────────────────────────────────────────── # 5. Groups with find_param / find_index # ────────────────────────────────────────────────────────────────────── -class TestGroupsFinders: - @pytest.fixture - def grouped_model(self): - return _make_grouped_outer() - def test_find_param_with_group(self, grouped_model): - """find_param with explicit group argument.""" - p0, idx0 = grouped_model.find_param(0, group=0) - assert p0 is grouped_model.inner.a - p1, idx1 = grouped_model.find_param(0, group=1) - assert p1 is grouped_model.c +def test_find_param_with_group(grouped_model): + """find_param with explicit group argument.""" + p0, idx0 = grouped_model.find_param(0, group=0) + assert p0 is grouped_model.inner.a + p1, idx1 = grouped_model.find_param(0, group=1) + assert p1 is grouped_model.c + + +def test_find_index_multi_group(grouped_model): + """find_index returns (group, index) tuple when multiple groups.""" + idx = grouped_model.find_index(grouped_model.inner.a) + assert isinstance(idx, tuple) + assert idx[0] == 0 # group + assert isinstance(idx[1], (int, slice)) + + +def test_find_index_different_groups(grouped_model): + """find_index returns correct group for each param.""" + idx_a = grouped_model.find_index(grouped_model.inner.a) + idx_c = grouped_model.find_index(grouped_model.c) + assert idx_a[0] == 0 + assert idx_c[0] == 1 - def test_find_index_multi_group(self, grouped_model): - """find_index returns (group, index) tuple when multiple groups.""" - idx = grouped_model.find_index(grouped_model.inner.a) - assert isinstance(idx, tuple) - assert idx[0] == 0 # group - assert isinstance(idx[1], (int, slice)) - def test_find_index_different_groups(self, grouped_model): - """find_index returns correct group for each param.""" - idx_a = grouped_model.find_index(grouped_model.inner.a) - idx_c = grouped_model.find_index(grouped_model.c) - assert idx_a[0] == 0 - assert idx_c[0] == 1 +def test_find_index_list_scheme(grouped_model): + """find_index with list scheme returns (group, list_idx).""" + idx = grouped_model.find_index(grouped_model.inner.a, scheme="list") + assert isinstance(idx, tuple) + assert idx[0] == 0 - def test_find_index_list_scheme(self, grouped_model): - """find_index with list scheme returns (group, list_idx).""" - idx = grouped_model.find_index(grouped_model.inner.a, scheme="list") - assert isinstance(idx, tuple) - assert idx[0] == 0 - def test_find_param_out_of_range(self, grouped_model): - """find_param raises IndexError for out-of-range index.""" - with pytest.raises(IndexError): - grouped_model.find_param(100, group=0) +def test_find_param_out_of_range(grouped_model): + """find_param raises IndexError for out-of-range index.""" + with pytest.raises(IndexError): + grouped_model.find_param(100, group=0) - def test_find_index_unknown_param(self, grouped_model): - """find_index raises ValueError for param not in model.""" - with pytest.raises(ValueError): - grouped_model.find_index(Param("unknown", 1.0)) + +def test_find_index_unknown_param(grouped_model): + """find_index raises ValueError for param not in model.""" + with pytest.raises(ValueError): + grouped_model.find_index(Param("unknown", 1.0)) # ────────────────────────────────────────────────────────────────────── # 6. Groups with to_valid / from_valid and ValidContext # ────────────────────────────────────────────────────────────────────── -class TestGroupsValid: - - @pytest.fixture - def bounded_grouped_model(self): - """Model with valid bounds and multiple groups.""" - inner = Inner() - inner.a.valid = (0, 10) - inner.b.valid = (0, 10) - outer = Outer(inner, name="outer") - outer.to_dynamic(False) - outer.c.valid = (0, 10) - outer.d.valid = (0, 10) - outer.c.group = 1 - outer.d.group = 1 - return outer - - def test_to_valid_multi_group(self, bounded_grouped_model): - """to_valid returns list of valid params per group.""" - params = bounded_grouped_model.get_values() - valid = bounded_grouped_model.to_valid(params) - assert isinstance(valid, list) - assert len(valid) == 2 - - def test_from_valid_multi_group(self, bounded_grouped_model): - """from_valid round-trips multi-group params.""" - params = bounded_grouped_model.get_values() - valid = bounded_grouped_model.to_valid(params) - recovered = bounded_grouped_model.from_valid(valid) - for p, r in zip(params, recovered): - assert backend.module.allclose(p, r) - - @pytest.mark.parametrize("scheme", ["array", "list", "dict"]) - def test_valid_context_multi_group(self, bounded_grouped_model, scheme): - """ValidContext with multi-group get/set round-trips.""" - init_params = bounded_grouped_model.get_values() - with ValidContext(bounded_grouped_model): - params = bounded_grouped_model.get_values(scheme) - bounded_grouped_model.set_values(params) - final_params = bounded_grouped_model.get_values() - for ip, fp in zip(init_params, final_params): - assert backend.module.allclose(ip, fp) - - def test_to_valid_with_group_arg(self, bounded_grouped_model): - """to_valid with explicit group argument.""" - params = bounded_grouped_model.get_values() - v0 = bounded_grouped_model.to_valid(params[0], group=0) - v1 = bounded_grouped_model.to_valid(params[1], group=1) - r0 = bounded_grouped_model.from_valid(v0, group=0) - r1 = bounded_grouped_model.from_valid(v1, group=1) - assert backend.module.allclose(params[0], r0) - assert backend.module.allclose(params[1], r1) +@pytest.fixture +def bounded_grouped_model(): + """Model with valid bounds and multiple groups.""" + inner = Inner() + inner.a.valid = (0, 10) + inner.b.valid = (0, 10) + outer = Outer(inner, name="outer") + outer.to_dynamic(False) + outer.c.valid = (0, 10) + outer.d.valid = (0, 10) + outer.c.group = 1 + outer.d.group = 1 + return outer + + +def test_to_valid_multi_group(bounded_grouped_model): + """to_valid returns list of valid params per group.""" + params = bounded_grouped_model.get_values() + valid = bounded_grouped_model.to_valid(params) + assert isinstance(valid, list) + assert len(valid) == 2 + + +def test_from_valid_multi_group(bounded_grouped_model): + """from_valid round-trips multi-group params.""" + params = bounded_grouped_model.get_values() + valid = bounded_grouped_model.to_valid(params) + recovered = bounded_grouped_model.from_valid(valid) + for p, r in zip(params, recovered): + assert backend.module.allclose(p, r) + + +@pytest.mark.parametrize("scheme", ["array", "list", "dict"]) +def test_valid_context_multi_group(bounded_grouped_model, scheme): + """ValidContext with multi-group get/set round-trips.""" + init_params = bounded_grouped_model.get_values() + with ValidContext(bounded_grouped_model): + params = bounded_grouped_model.get_values(scheme) + bounded_grouped_model.set_values(params) + final_params = bounded_grouped_model.get_values() + for ip, fp in zip(init_params, final_params): + assert backend.module.allclose(ip, fp) + + +def test_to_valid_with_group_arg(bounded_grouped_model): + """to_valid with explicit group argument.""" + params = bounded_grouped_model.get_values() + v0 = bounded_grouped_model.to_valid(params[0], group=0) + v1 = bounded_grouped_model.to_valid(params[1], group=1) + r0 = bounded_grouped_model.from_valid(v0, group=0) + r1 = bounded_grouped_model.from_valid(v1, group=1) + assert backend.module.allclose(params[0], r0) + assert backend.module.allclose(params[1], r1) # ────────────────────────────────────────────────────────────────────── # 7. Groups with @forward decorator and fill_params # ────────────────────────────────────────────────────────────────────── -class TestGroupsForward: +def test_forward_with_multi_group_params(grouped_model): + """@forward method works with multi-group params passed as arg.""" + outer = grouped_model + params = outer.get_values() + result = outer.run(10.0, params) + assert result is not None - def test_forward_with_multi_group_params(self): - """@forward method works with multi-group params passed as arg.""" - outer = _make_grouped_outer() - params = outer.get_values() - result = outer.run(10.0, params) - assert result is not None - def test_forward_with_params_kwarg(self): - """@forward with params= kwarg and multi-group.""" - outer = _make_grouped_outer() - params = outer.get_values() - result = outer.run(10.0, params=params) - assert result is not None +def test_forward_with_params_kwarg(grouped_model): + """@forward with params= kwarg and multi-group.""" + outer = grouped_model + params = outer.get_values() + result = outer.run(10.0, params=params) + assert result is not None - def test_forward_consistency(self): - """Results should be same whether single or multi-group (same values).""" - inner1 = Inner() - outer1 = Outer(inner1, name="outer1") - outer1.to_dynamic(False) - result_single = outer1.run(10.0, outer1.get_values()) - outer2 = _make_grouped_outer() - result_multi = outer2.run(10.0, outer2.get_values()) +def test_forward_consistency(grouped_model): + """Results should be same whether single or multi-group (same values).""" + inner1 = Inner() + outer1 = Outer(inner1, name="outer1") + outer1.to_dynamic(False) + result_single = outer1.run(10.0, outer1.get_values()) - assert backend.module.allclose(result_single, result_multi) + outer2 = grouped_model + result_multi = outer2.run(10.0, outer2.get_values()) - def test_fill_params_multi_group(self): - """fill_params works correctly with multi-group params.""" - outer = _make_grouped_outer() - params = outer.get_values() - with ActiveContext(outer): - outer.fill_params(params) - assert outer.inner.a._value is not None - assert outer.c._value is not None + assert backend.module.allclose(result_single, result_multi) + + +def test_fill_params_multi_group(grouped_model): + """fill_params works correctly with multi-group params.""" + outer = grouped_model + params = outer.get_values() + with ActiveContext(outer): + outer.fill_params(params) + assert outer.inner.a._value is not None + assert outer.c._value is not None # ────────────────────────────────────────────────────────────────────── # 8. Groups with ActiveContext # ────────────────────────────────────────────────────────────────────── -class TestGroupsActiveContext: - def test_active_context_multi_group(self): - """ActiveContext manages state correctly with multi-group.""" - outer = _make_grouped_outer() - params = outer.get_values() - with ActiveContext(outer): - outer.fill_params(params) - for p in outer.dynamic_params: - assert p._value is not None + +def test_active_context_multi_group(grouped_model): + """ActiveContext manages state correctly with multi-group.""" + outer = grouped_model + params = outer.get_values() + with ActiveContext(outer): + outer.fill_params(params) for p in outer.dynamic_params: - assert p._value is None + assert p._value is not None + for p in outer.dynamic_params: + assert p._value is None + - def test_set_values_blocked_in_active_context(self): - """set_values raises when module is active.""" - outer = _make_grouped_outer() - with ActiveContext(outer): - with pytest.raises(ActiveStateError): - outer.set_values(outer.get_values()) +def test_set_values_blocked_in_active_context(grouped_model): + """set_values raises when module is active.""" + outer = grouped_model + with ActiveContext(outer): + with pytest.raises(ActiveStateError): + outer.set_values(outer.get_values()) - def test_nested_active_context_multi_group(self): - """Nested ActiveContext on same module raises error.""" - outer = _make_grouped_outer() - with ActiveContext(outer): - with pytest.raises(ActiveStateError): - with ActiveContext(outer): - pass + +def test_nested_active_context_multi_group(grouped_model): + """Nested ActiveContext on same module raises error.""" + outer = grouped_model + with ActiveContext(outer): + with pytest.raises(ActiveStateError): + with ActiveContext(outer): + pass # ────────────────────────────────────────────────────────────────────── # 9. Groups with active_cache # ────────────────────────────────────────────────────────────────────── -class TestGroupsActiveCache: - def test_active_cache_with_multi_group(self): - """active_cache works correctly in multi-group scenarios.""" - class CachedModule(Module): - def __init__(self): - super().__init__() - self.a = Param("a", 1.0, shape=()) - self.b = Param("b", 2.0, shape=()) - self.counter = 0 +def test_active_cache_with_multi_group(): + """active_cache works correctly in multi-group scenarios.""" + + class CachedModule(Module): + def __init__(self): + super().__init__() + self.a = Param("a", 1.0, shape=()) + self.b = Param("b", 2.0, shape=()) + self.counter = 0 - @active_cache - def cached_compute(self, x): - self.counter += 1 - return x * 2 + @active_cache + def cached_compute(self, x): + self.counter += 1 + return x * 2 - @forward - def run(self, x, a, b): - c1 = self.cached_compute(x) - c2 = self.cached_compute(x) # Should use cache - return a + b + c1 + c2 + @forward + def run(self, x, a, b): + c1 = self.cached_compute(x) + c2 = self.cached_compute(x) # Should use cache + return a + b + c1 + c2 - m = CachedModule() - m.to_dynamic(False) - m.b.group = 1 - params = m.get_values() - m.counter = 0 - result = m.run(3.0, params) - assert m.counter == 1 # cached_compute called only once + m = CachedModule() + m.to_dynamic(False) + m.b.group = 1 + params = m.get_values() + m.counter = 0 + result = m.run(3.0, params) + assert m.counter == 1 # cached_compute called only once # ────────────────────────────────────────────────────────────────────── # 10. Groups with collections (NodeList / NodeTuple) # ────────────────────────────────────────────────────────────────────── -class TestGroupsCollections: - - def test_groups_with_node_list(self): - """Parameter groups work with NodeList containers.""" - - class Listed(Module): - def __init__(self, workers, name=None): - super().__init__(name) - self.workers = workers - self.p = Param("p", 1.0, shape=()) - - @forward - def run(self, x, p): - return p + sum(w.compute(x) for w in self.workers) - - w1 = Inner(name="w1") - w2 = Inner(name="w2") - listed = Listed([w1, w2], name="listed") - listed.to_dynamic(False) - w1.a.group = 1 - assert listed.dynamic_param_groups == (0, 1) - params = listed.get_values() - assert isinstance(params, list) and len(params) == 2 - result = listed.run(1.0, params) - assert result is not None - - def test_groups_with_node_tuple(self): - """Parameter groups work with NodeTuple containers.""" - m = Module("m") - p1 = Param("p1", 1.0, dynamic=True, group=0) - p2 = Param("p2", 2.0, dynamic=True, group=1) - m.tup = NodeTuple((p1, p2), name="tup") - assert m.dynamic_param_groups == (0, 1) - - def test_collection_dynamic_param_groups(self): - """NodeList/NodeTuple expose dynamic_param_groups.""" - w1 = Inner(name="w1") - w2 = Inner(name="w2") - w1.to_dynamic(False) - w2.to_dynamic(False) - w1.a.group = 1 - nl = NodeList([w1, w2], name="nl") - assert 1 in nl.dynamic_param_groups + + +def test_groups_with_node_list(): + """Parameter groups work with NodeList containers.""" + + class Listed(Module): + def __init__(self, workers, name=None): + super().__init__(name) + self.workers = workers + self.p = Param("p", 1.0, shape=()) + + @forward + def run(self, x, p): + return p + sum(w.compute(x) for w in self.workers) + + w1 = Inner(name="w1") + w2 = Inner(name="w2") + listed = Listed([w1, w2], name="listed") + listed.to_dynamic(False) + w1.a.group = 1 + assert listed.dynamic_param_groups == (0, 1) + params = listed.get_values() + assert isinstance(params, list) and len(params) == 2 + result = listed.run(1.0, params) + assert result is not None + + +def test_groups_with_node_tuple(): + """Parameter groups work with NodeTuple containers.""" + m = Module("m") + p1 = Param("p1", 1.0, dynamic=True, group=0) + p2 = Param("p2", 2.0, dynamic=True, group=1) + m.tup = NodeTuple((p1, p2), name="tup") + assert m.dynamic_param_groups == (0, 1) + + +def test_collection_dynamic_param_groups(): + """NodeList/NodeTuple expose dynamic_param_groups.""" + w1 = Inner(name="w1") + w2 = Inner(name="w2") + w1.to_dynamic(False) + w2.to_dynamic(False) + w1.a.group = 1 + nl = NodeList([w1, w2], name="nl") + assert 1 in nl.dynamic_param_groups # ────────────────────────────────────────────────────────────────────── # 11. Groups with pointer params # ────────────────────────────────────────────────────────────────────── -class TestGroupsPointerParams: - def test_pointer_param_group(self): - """Pointer and dynamic params tracked correctly with groups.""" - class PtrModule(Module): - def __init__(self, name=None): - super().__init__(name) - self.a = Param("a", 1.0, shape=()) - self.b = Param("b", 2.0, shape=()) +def test_pointer_param_group(): + """Pointer and dynamic params tracked correctly with groups.""" + + class PtrModule(Module): + def __init__(self, name=None): + super().__init__(name) + self.a = Param("a", 1.0, shape=()) + self.b = Param("b", 2.0, shape=()) - @forward - def run(self, a, b): - return a + b + @forward + def run(self, a, b): + return a + b - m1 = PtrModule("m1") - m2 = PtrModule("m2") + m1 = PtrModule("m1") + m2 = PtrModule("m2") - class Top(Module): - def __init__(self, m1, m2, name=None): - super().__init__(name) - self.m1 = m1 - self.m2 = m2 + class Top(Module): + def __init__(self, m1, m2, name=None): + super().__init__(name) + self.m1 = m1 + self.m2 = m2 - @forward - def run(self, x): - return self.m1.run() + self.m2.run() + x + @forward + def run(self, x): + return self.m1.run() + self.m2.run() + x - m2.a = m1.a - top = Top(m1, m2, name="top") - top.to_dynamic(False) - m1.b.group = 1 + m2.a = m1.a + top = Top(m1, m2, name="top") + top.to_dynamic(False) + m1.b.group = 1 - assert m1.a in top.pointer_params or m1.a in top.dynamic_params - params = top.get_values() - result = top.run(0.0, params) - assert result is not None + assert m1.a in top.pointer_params or m1.a in top.dynamic_params + params = top.get_values() + result = top.run(0.0, params) + assert result is not None # ────────────────────────────────────────────────────────────────────── # 12. Groups with dynamic/static transitions # ────────────────────────────────────────────────────────────────────── -class TestGroupsDynamicStatic: - - def test_static_params_not_in_groups(self): - """Static params are excluded from dynamic_param_groups.""" - outer = _make_grouped_outer() - assert outer.dynamic_param_groups == (0, 1) - outer.c.to_static() - outer.d.to_static() - assert outer.dynamic_param_groups == (0,) - - def test_to_dynamic_restores_groups(self): - """Converting back to dynamic preserves group membership.""" - outer = _make_grouped_outer() - outer.c.to_static() - assert 1 in outer.dynamic_param_groups # d still in group 1 - outer.c.to_dynamic() - assert outer.c.group == 1 - - def test_to_static_all_then_dynamic(self): - """to_static(False) then to_dynamic restores groups.""" - inner = Inner() - outer = Outer(inner, name="outer") - outer.to_dynamic(False) - inner.a.group = 1 - outer.c.group = 2 - - outer.to_static(False) - assert len(outer.dynamic_params) == 0 - assert len(outer.dynamic_param_groups) == 0 - - outer.to_dynamic(False) - assert inner.a.group == 1 - assert outer.c.group == 2 - assert outer.dynamic_param_groups == (0, 1, 2) - - def test_get_values_after_all_static(self): - """get_values returns empty when all params static.""" - outer = _make_grouped_outer() - outer.to_static(False) - vals = outer.get_values() - assert len(vals) == 0 + + +def test_static_params_not_in_groups(grouped_model): + """Static params are excluded from dynamic_param_groups.""" + outer = grouped_model + assert outer.dynamic_param_groups == (0, 1) + outer.c.to_static() + outer.d.to_static() + assert outer.dynamic_param_groups == (0,) + + +def test_to_dynamic_restores_groups(grouped_model): + """Converting back to dynamic preserves group membership.""" + outer = grouped_model + outer.c.to_static() + assert 1 in outer.dynamic_param_groups # d still in group 1 + outer.c.to_dynamic() + assert outer.c.group == 1 + + +def test_to_static_all_then_dynamic(): + """to_static(False) then to_dynamic restores groups.""" + inner = Inner() + outer = Outer(inner, name="outer") + outer.to_dynamic(False) + inner.a.group = 1 + outer.c.group = 2 + + outer.to_static(False) + assert len(outer.dynamic_params) == 0 + assert len(outer.dynamic_param_groups) == 0 + + outer.to_dynamic(False) + assert inner.a.group == 1 + assert outer.c.group == 2 + assert outer.dynamic_param_groups == (0, 1, 2) + + +def test_get_values_after_all_static(grouped_model): + """get_values returns empty when all params static.""" + outer = grouped_model + outer.to_static(False) + vals = outer.get_values() + assert len(vals) == 0 # ────────────────────────────────────────────────────────────────────── # 13. Groups with OverrideParam # ────────────────────────────────────────────────────────────────────── -class TestGroupsOverrideParam: - def test_override_in_multi_group(self): - """OverrideParam works correctly when param is in a group.""" - class TestSim(Module): - def __init__(self): - super().__init__() - self.a = Param("a", 3.0) - self.b = Param("b", 5.0) +def test_override_in_multi_group(): + """OverrideParam works correctly when param is in a group.""" + + class TestSim(Module): + def __init__(self): + super().__init__() + self.a = Param("a", 3.0) + self.b = Param("b", 5.0) - @forward - def run(self, a, b): - return a + b + @forward + def run(self, a, b): + return a + b - m = TestSim() - m.to_dynamic(False) - m.b.group = 1 + m = TestSim() + m.to_dynamic(False) + m.b.group = 1 - params = m.get_values() - with ActiveContext(m): - m.fill_params(params) - assert m.a._value is not None - with OverrideParam(m.a, backend.make_array(10.0)): - assert m.a._value.item() == 10.0 - assert m.a._value.item() == 3.0 + params = m.get_values() + with ActiveContext(m): + m.fill_params(params) + assert m.a._value is not None + with OverrideParam(m.a, backend.make_array(10.0)): + assert m.a._value.item() == 10.0 + assert m.a._value.item() == 3.0 # ────────────────────────────────────────────────────────────────────── # 14. Edge cases # ────────────────────────────────────────────────────────────────────── -class TestGroupsEdgeCases: - - def test_many_groups(self): - """Large number of groups works correctly.""" - - class ManyParams(Module): - def __init__(self, n): - super().__init__() - for i in range(n): - setattr( - self, f"p{i}", Param(f"p{i}", float(i), shape=(), dynamic=True, group=i) - ) - - m = ManyParams(10) - assert m.dynamic_param_groups == tuple(range(10)) - params = m.get_values() - assert len(params) == 10 - - def test_non_contiguous_groups(self): - """Groups with gaps (e.g. 0, 5, 10) work correctly.""" - m = Module("m") - m.a = Param("a", 1.0, dynamic=True, group=0) - m.b = Param("b", 2.0, dynamic=True, group=5) - m.c = Param("c", 3.0, dynamic=True, group=10) - assert m.dynamic_param_groups == (0, 5, 10) - params = m.get_values() - assert len(params) == 3 - - def test_reassign_all_to_same_group(self): - """Reassigning all params to same group collapses to single-group.""" - outer = _make_grouped_outer() - assert len(outer.dynamic_param_groups) == 2 - outer.c.group = 0 - outer.d.group = 0 - assert outer.dynamic_param_groups == (0,) - vals = outer.get_values("array") - assert isinstance(vals, backend.array_type) - - def test_empty_group_after_static(self): - """Group that becomes empty after to_static is removed from groups.""" - outer = _make_grouped_outer() - assert 1 in outer.dynamic_param_groups - outer.c.to_static() - outer.d.to_static() - assert 1 not in outer.dynamic_param_groups - - def test_single_param_per_group(self): - """Each group with a single param.""" - m = Module("m") - m.a = Param("a", 1.0, dynamic=True, group=0) - m.b = Param("b", 2.0, dynamic=True, group=1) - m.c = Param("c", 3.0, dynamic=True, group=2) - params = m.get_values("array") - assert len(params) == 3 - for p in params: - assert p.shape == (1,) - - def test_forward_no_params_multi_group(self): - """Forward with no dynamic params after all static, multi-group.""" - outer = _make_grouped_outer() - outer.to_static(False) - result = outer.run(10.0) - assert result is not None - - def test_group_with_none_value_param_raises(self): - """Param with None value in a group raises on get_values.""" - m = Module("m") - m.a = Param("a", 1.0, dynamic=True, group=0) - m.b = Param("b", None, dynamic=True, group=1) - with pytest.raises(ParamConfigurationError, match="b"): - m.get_values() + + +def test_many_groups(): + """Large number of groups works correctly.""" + + class ManyParams(Module): + def __init__(self, n): + super().__init__() + for i in range(n): + setattr(self, f"p{i}", Param(f"p{i}", float(i), shape=(), dynamic=True, group=i)) + + m = ManyParams(10) + assert m.dynamic_param_groups == tuple(range(10)) + params = m.get_values() + assert len(params) == 10 + + +def test_non_contiguous_groups(): + """Groups with gaps (e.g. 0, 5, 10) work correctly.""" + m = Module("m") + m.a = Param("a", 1.0, dynamic=True, group=0) + m.b = Param("b", 2.0, dynamic=True, group=5) + m.c = Param("c", 3.0, dynamic=True, group=10) + assert m.dynamic_param_groups == (0, 5, 10) + params = m.get_values() + assert len(params) == 3 + + +def test_reassign_all_to_same_group(grouped_model): + """Reassigning all params to same group collapses to single-group.""" + outer = grouped_model + assert len(outer.dynamic_param_groups) == 2 + outer.c.group = 0 + outer.d.group = 0 + assert outer.dynamic_param_groups == (0,) + vals = outer.get_values("array") + assert isinstance(vals, backend.array_type) + + +def test_empty_group_after_static(grouped_model): + """Group that becomes empty after to_static is removed from groups.""" + outer = grouped_model + assert 1 in outer.dynamic_param_groups + outer.c.to_static() + outer.d.to_static() + assert 1 not in outer.dynamic_param_groups + + +def test_single_param_per_group(): + """Each group with a single param.""" + m = Module("m") + m.a = Param("a", 1.0, dynamic=True, group=0) + m.b = Param("b", 2.0, dynamic=True, group=1) + m.c = Param("c", 3.0, dynamic=True, group=2) + params = m.get_values("array") + assert len(params) == 3 + for p in params: + assert p.shape == (1,) + + +def test_forward_no_params_multi_group(grouped_model): + """Forward with no dynamic params after all static, multi-group.""" + outer = grouped_model + outer.to_static(False) + result = outer.run(10.0) + assert result is not None + + +def test_group_with_none_value_param_raises(): + """Param with None value in a group raises on get_values.""" + m = Module("m") + m.a = Param("a", 1.0, dynamic=True, group=0) + m.b = Param("b", None, dynamic=True, group=1) + with pytest.raises(ParamConfigurationError, match="b"): + m.get_values() # ────────────────────────────────────────────────────────────────────── # 15. Groups with ValidContext in @forward # ────────────────────────────────────────────────────────────────────── -class TestGroupsValidForward: - def test_valid_context_forward_multi_group(self): - """ValidContext with @forward and multi-group params.""" - class BoundedSim(Module): - def __init__(self): - super().__init__() - self.a = Param("a", 1.0, shape=(), valid=(0, 10)) - self.b = Param("b", 2.0, shape=(), valid=(0, 10)) +def test_valid_context_forward_multi_group(): + """ValidContext with @forward and multi-group params.""" + + class BoundedSim(Module): + def __init__(self): + super().__init__() + self.a = Param("a", 1.0, shape=(), valid=(0, 10)) + self.b = Param("b", 2.0, shape=(), valid=(0, 10)) - @forward - def run(self, x, a, b): - return x + a + b + @forward + def run(self, x, a, b): + return x + a + b - m = BoundedSim() - m.to_dynamic(False) - m.b.group = 1 + m = BoundedSim() + m.to_dynamic(False) + m.b.group = 1 - init_params = m.get_values() - valid = m.to_valid(init_params) - recovered = m.from_valid(valid) - for ip, rp in zip(init_params, recovered): - assert backend.module.allclose(ip, rp) + init_params = m.get_values() + valid = m.to_valid(init_params) + recovered = m.from_valid(valid) + for ip, rp in zip(init_params, recovered): + assert backend.module.allclose(ip, rp) - with ValidContext(m): - valid_params = m.get_values() - result = m.run(5.0, valid_params) - assert result is not None + with ValidContext(m): + valid_params = m.get_values() + result = m.run(5.0, valid_params) + assert result is not None - def test_valid_context_preserves_state(self): - """ValidContext doesn't leak state after exit.""" - outer = _make_grouped_outer() - outer.inner.a.valid = (0, 10) - outer.inner.b.valid = (0, 10) - outer.c.valid = (0, 10) - outer.d.valid = (0, 10) - assert not outer.valid_context - with ValidContext(outer): - assert outer.valid_context - assert not outer.valid_context +def test_valid_context_preserves_state(grouped_model): + """ValidContext doesn't leak state after exit.""" + outer = grouped_model + outer.inner.a.valid = (0, 10) + outer.inner.b.valid = (0, 10) + outer.c.valid = (0, 10) + outer.d.valid = (0, 10) + + assert not outer.valid_context + with ValidContext(outer): + assert outer.valid_context + assert not outer.valid_context # ────────────────────────────────────────────────────────────────────── # 16. Groups with string representation # ────────────────────────────────────────────────────────────────────── -class TestGroupsStr: - - def test_str_multi_group(self): - """String representation includes all params across groups.""" - outer = _make_grouped_outer() - result = str(outer) - for node in outer.topological_ordering(): - assert node.name in result - - def test_param_order_str(self): - """param_order produces readable multi-group output.""" - outer = _make_grouped_outer() - order = outer.param_order() - assert "c" in order - assert "d" in order - assert "a" in order - assert "b" in order + + +def test_str_multi_group(grouped_model): + """String representation includes all params across groups.""" + outer = grouped_model + result = str(outer) + for node in outer.topological_ordering(): + assert node.name in result + + +def test_param_order_str(grouped_model): + """param_order produces readable multi-group output.""" + outer = grouped_model + order = outer.param_order() + assert "c" in order + assert "d" in order + assert "a" in order + assert "b" in order