From 184b6003f24434a0c27514340ee91fe9c125f58b Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Thu, 21 May 2026 18:38:19 +0200 Subject: [PATCH 1/3] chore: remove double buffering simplify the accumulator by removing double buffering, it did nothing for us because .tobytes in the client is a copy and isn't async. If we do both of these in future then it may be worth adding back --- docs/src/guides/trajectory-accumulator.md | 16 ++++++-------- python/echo/trajectory_accumulator.py | 23 ++++++++++----------- python/tests/test_trajectory_accumulator.py | 23 +++++---------------- 3 files changed, 22 insertions(+), 40 deletions(-) diff --git a/docs/src/guides/trajectory-accumulator.md b/docs/src/guides/trajectory-accumulator.md index 6c8de6e..4b90a05 100644 --- a/docs/src/guides/trajectory-accumulator.md +++ b/docs/src/guides/trajectory-accumulator.md @@ -55,21 +55,17 @@ for _ in range(num_rollouts): `TrajectoryAccumulator` is just pre-allocated numpy arrays plus per-timescale slot counters. `add()` writes into the next free slot for that timescale via -slice assignment; `build()` returns the filled pytree, flips the active -buffer (so the next round writes into a fresh one without allocating), and -resets the counters. +slice assignment; `build()` returns the filled pytree and resets the +counters. The buffer is reused — no allocation per rollout. There's no network or Rust involvement; it's purely a way to amortise the flatten-and-send cost across many environment steps. The pytree it returns goes through the normal `client.send` path. -## Two buffers, no allocation per rollout - -`TrajectoryAccumulator` double-buffers internally: two copies of the pytree, with -`build()` swapping the active one. So while one buffer is being serialised -and sent, the next rollout can start filling the other without any -allocation. This matters when rollouts are short relative to flatten + -network latency. +The tree returned by `build()` aliases the accumulator's internal buffers, +so the next `add()` will overwrite it. The usual `client.send(buf.build())` +pattern is safe because `send` is synchronous and copies the bytes before +returning — don't hold onto the returned tree across further `add()` calls. ## Common pitfalls diff --git a/python/echo/trajectory_accumulator.py b/python/echo/trajectory_accumulator.py index 371364c..f6a7ece 100644 --- a/python/echo/trajectory_accumulator.py +++ b/python/echo/trajectory_accumulator.py @@ -5,7 +5,7 @@ class TrajectoryAccumulator: - """Multi-timescale accumulator: fixed-size pytree buffer with double-buffering. + """Multi-timescale accumulator: fixed-size pytree buffer. Args: example: Dict with timescale names as top-level keys. The leading @@ -29,12 +29,9 @@ def __init__(self, example: dict[str, Any]): ) self._counts[name] = leading - # Two copies of the pytree for double-buffering. - self._trees: list[dict[str, Any]] = [ - {n: optree.tree_map(np.zeros_like, sub) for n, sub in example.items()}, - {n: optree.tree_map(np.zeros_like, sub) for n, sub in example.items()}, - ] - self._active = 0 + self._tree: dict[str, Any] = { + n: optree.tree_map(np.zeros_like, sub) for n, sub in example.items() + } self._slot: dict[str, int] = {name: 0 for name in example} def add(self, name: str, data: Any) -> None: @@ -52,15 +49,17 @@ def _write_slot(stored, incoming): stored[s:s + 1] = incoming return stored - optree.tree_map_(_write_slot, self._trees[self._active][name], data) + optree.tree_map_(_write_slot, self._tree[name], data) self._slot[name] += 1 def build(self) -> dict[str, Any]: - """Return the filled pytree and flip the active buffer.""" - tree = self._trees[self._active] - self._active = 1 - self._active + """Return the filled pytree and reset slot counters. + + The returned tree aliases internal buffers; callers must finish using + it (e.g. complete the synchronous send) before the next ``add()``. + """ self._slot = {name: 0 for name in self._slot} - return tree + return self._tree def reset(self) -> None: """Reset slot counters without sending (e.g. on episode abort).""" diff --git a/python/tests/test_trajectory_accumulator.py b/python/tests/test_trajectory_accumulator.py index 36d20a6..14cab35 100644 --- a/python/tests/test_trajectory_accumulator.py +++ b/python/tests/test_trajectory_accumulator.py @@ -44,7 +44,7 @@ def test_add_writes_correct_values(self): obs = np.arange(12, dtype=np.float32) buf.add("transition", {"obs": obs, "rew": np.array(7.0, dtype=np.float32)}) - tree = buf._trees[buf._active] + tree = buf._tree np.testing.assert_array_equal(tree["transition"]["obs"][0], obs) np.testing.assert_array_equal(tree["transition"]["rew"][0], 7.0) @@ -58,7 +58,7 @@ def test_add_dtype_cast(self): obs64 = np.ones(12, dtype=np.float64) * 1.5 buf.add("transition", {"obs": obs64, "rew": np.zeros((), dtype=np.float64)}) - tree = buf._trees[buf._active] + tree = buf._tree np.testing.assert_allclose(tree["transition"]["obs"][0], 1.5) def test_add_multiple_slots(self): @@ -69,7 +69,7 @@ def test_add_multiple_slots(self): "rew": np.array(float(i * 10), dtype=np.float32), }) - tree = buf._trees[buf._active] + tree = buf._tree for i in range(N): np.testing.assert_array_equal(tree["transition"]["obs"][i], float(i)) np.testing.assert_array_equal(tree["transition"]["rew"][i], float(i * 10)) @@ -99,24 +99,11 @@ def test_build_contains_correct_data(self): np.testing.assert_array_equal(tree["transition"]["obs"][0], 0.0) np.testing.assert_array_equal(tree["summary"]["ret"], [99.0]) - def test_build_flips_active_buffer(self): + def test_build_resets_slot_counters(self): buf = TrajectoryAccumulator(_EXAMPLE) - active_before = buf._active self._fill(buf) buf.build() - assert buf._active != active_before - - def test_build_data_not_clobbered_by_next_add(self): - buf = TrajectoryAccumulator(_EXAMPLE) - self._fill(buf) - tree = buf.build() - snapshot = tree["transition"]["obs"].copy() - for i in range(N): - buf.add("transition", { - "obs": np.full(12, 999.0, dtype=np.float32), - "rew": np.array(999.0, dtype=np.float32), - }) - np.testing.assert_array_equal(tree["transition"]["obs"], snapshot) + assert all(s == 0 for s in buf._slot.values()) def test_reset_clears_slot_counters(self): buf = TrajectoryAccumulator(_EXAMPLE) From 1d6421631a9ea5034eb5e0d612bf9c0fa3c2a37c Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Thu, 21 May 2026 18:48:01 +0200 Subject: [PATCH 2/3] feat: allow zero dimensional leaves with trajectory accumulator --- python/echo/trajectory_accumulator.py | 2 +- python/tests/test_trajectory_accumulator.py | 68 +++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/python/echo/trajectory_accumulator.py b/python/echo/trajectory_accumulator.py index f6a7ece..46c00dd 100644 --- a/python/echo/trajectory_accumulator.py +++ b/python/echo/trajectory_accumulator.py @@ -46,7 +46,7 @@ def add(self, name: str, data: Any) -> None: ) def _write_slot(stored, incoming): - stored[s:s + 1] = incoming + np.atleast_1d(stored)[s:s + 1] = incoming return stored optree.tree_map_(_write_slot, self._tree[name], data) diff --git a/python/tests/test_trajectory_accumulator.py b/python/tests/test_trajectory_accumulator.py index 14cab35..07c3216 100644 --- a/python/tests/test_trajectory_accumulator.py +++ b/python/tests/test_trajectory_accumulator.py @@ -75,6 +75,74 @@ def test_add_multiple_slots(self): np.testing.assert_array_equal(tree["transition"]["rew"][i], float(i * 10)) +class TestTrajectoryAccumulatorScalarLeaves: + """0-d leaves should be writable without padding to (1,).""" + + def _example(self): + return { + "step": { + "obs": np.empty((3, 4), dtype=np.float32), + "reward": np.empty((3,), dtype=np.float32), + }, + "episode": { + "ret": np.empty((), dtype=np.float32), + "gen": np.empty((), dtype=np.int32), + }, + } + + def test_capacity_inferred_for_zero_d_timescale(self): + buf = TrajectoryAccumulator(self._example()) + assert buf._counts == {"step": 3, "episode": 1} + + def test_write_zero_d_leaves(self): + buf = TrajectoryAccumulator(self._example()) + buf.add("episode", { + "ret": np.array(7.5, dtype=np.float32), + "gen": np.array(42, dtype=np.int32), + }) + tree = buf._tree + assert tree["episode"]["ret"].shape == () + assert tree["episode"]["gen"].shape == () + np.testing.assert_array_equal(tree["episode"]["ret"], 7.5) + np.testing.assert_array_equal(tree["episode"]["gen"], 42) + + def test_mixed_zero_d_and_nd_in_same_build(self): + buf = TrajectoryAccumulator(self._example()) + for i in range(3): + buf.add("step", { + "obs": np.full(4, float(i), dtype=np.float32), + "reward": np.array(float(i * 10), dtype=np.float32), + }) + buf.add("episode", { + "ret": np.array(99.0, dtype=np.float32), + "gen": np.array(5, dtype=np.int32), + }) + tree = buf.build() + + assert tree["step"]["obs"].shape == (3, 4) + assert tree["step"]["reward"].shape == (3,) + assert tree["episode"]["ret"].shape == () + assert tree["episode"]["gen"].shape == () + + for i in range(3): + np.testing.assert_array_equal(tree["step"]["obs"][i], float(i)) + np.testing.assert_array_equal(tree["step"]["reward"][i], float(i * 10)) + np.testing.assert_array_equal(tree["episode"]["ret"], 99.0) + np.testing.assert_array_equal(tree["episode"]["gen"], 5) + + def test_zero_d_timescale_full_after_one_add(self): + buf = TrajectoryAccumulator(self._example()) + buf.add("episode", { + "ret": np.array(1.0, dtype=np.float32), + "gen": np.array(1, dtype=np.int32), + }) + with pytest.raises(IndexError): + buf.add("episode", { + "ret": np.array(2.0, dtype=np.float32), + "gen": np.array(2, dtype=np.int32), + }) + + class TestTrajectoryAccumulatorBuild: def _fill(self, buf): for i in range(N): From f576d60f0036143680547865410b869def638ec9 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Thu, 21 May 2026 23:12:10 +0200 Subject: [PATCH 3/3] feat: distinguish single-item from buffered timescales in TrajectoryAccumulator A timescale is now classified as either: - Buffered: every leaf shares leading dim N > 1; capacity = N, writes go to stored[s:s+1]. - Single-item: every leaf is 0-d or has shape[0] == 1. Capacity = 1, add() replaces the whole leaf via stored[...] = incoming. This makes the accumulator accept natural per-item shapes for one-shot trailing context (bootstrap step, episode return, param generation) without forcing callers to pad a leading 1 onto every leaf and squeeze it back out downstream. Other changes: - IndexError on add() past capacity now reports the timescale name, capacity, and offending index. - Buffered-timescale leading-dim mismatch raises with a hint that the caller can opt into single-item mode by making any leaf 0-d or all leaves shape (1, ...). - _write_slot simplified to a single stored[key] = incoming line, with key = Ellipsis for single-item and slice(s, s+1) for buffered. - Tests cover both detection paths, the mismatch error, the over-index error, and the single-item write semantics. - Docs (guide + API) describe both modes and the detection rule. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/src/api/trajectory-accumulator.md | 10 ++ docs/src/guides/trajectory-accumulator.md | 58 +++++-- python/echo/trajectory_accumulator.py | 60 +++++--- python/tests/test_trajectory_accumulator.py | 161 ++++++++++++++++---- 4 files changed, 223 insertions(+), 66 deletions(-) diff --git a/docs/src/api/trajectory-accumulator.md b/docs/src/api/trajectory-accumulator.md index b9da867..472b647 100644 --- a/docs/src/api/trajectory-accumulator.md +++ b/docs/src/api/trajectory-accumulator.md @@ -5,4 +5,14 @@ client side before calling `client.send`. Useful when a single rollout step produces several arrays at different rates (e.g. one transition per env step, plus a single episode-level statistic per episode). +Each timescale is either: + +- **Buffered** — every leaf shares the same leading dim `N` (`N > 1`); + that's the capacity, filled by `N` `add()` calls. +- **Single-item** — capacity 1; detected when every leaf is 1-d or has + at least one leaf that is 0-d. `add()` replaces the whole leaf. + +See the [guide](../guides/trajectory-accumulator.md) for the rationale and +worked examples. + ::: echo.trajectory_accumulator.TrajectoryAccumulator diff --git a/docs/src/guides/trajectory-accumulator.md b/docs/src/guides/trajectory-accumulator.md index 4b90a05..0d8aca1 100644 --- a/docs/src/guides/trajectory-accumulator.md +++ b/docs/src/guides/trajectory-accumulator.md @@ -17,8 +17,18 @@ samples. ## How it works Construct it with a dict whose top-level keys are timescale names. Each -leaf's leading dimension is the number of `add()` calls expected before the -buffer is ready to send. +timescale is one of two kinds, inferred from the example pytree: + +- **Buffered** — every leaf shares the same leading dim `N` (with `N > 1`). + That leading dim is the timescale's capacity: the buffer fills after `N` + `add()` calls and each call writes into `stored[s:s+1]`. + +- **Single-item** — the timescale holds one trailing piece of context + (e.g. an episode return, a bootstrap step) rather than a buffer of + steps. Detected when at least one leaf is 0-d, *or* all leaves have + `shape[0] == 1`. Capacity is `1` and `add()` replaces the whole leaf, + so non-0-d leaves can carry any per-item shape (apart from the + optional leading 1). ```python import numpy as np @@ -27,12 +37,21 @@ from echo import TrajectoryAccumulator, TcpClient T = 64 # rollout length example = { + # Buffered timescale: leading dim T across every leaf. "step": { "obs": np.zeros((T, 4), dtype=np.float32), - "reward": np.zeros((T, 1), dtype=np.float32), + "reward": np.zeros((T,), dtype=np.float32), }, + # Single-item timescale: all leaves have shape (1, ...) — capacity 1, + # `add()` replaces the whole leaf. "episode": { "return": np.zeros((1,), dtype=np.float32), + "length": np.zeros((1,), dtype=np.float32) + }, + # Single-item timescale: 0-d reward means add() replaces the whole leaf + "final_step": { + "obs": np.zeros((4,), dtype=np.float32), + "reward": np.zeros((), dtype=np.int32), }, } @@ -42,21 +61,26 @@ buf = TrajectoryAccumulator(example) for _ in range(num_rollouts): episode_return = 0.0 + reward = 0.0 + obs = env.reset() for _ in range(T): - obs, reward = env.step(...) buf.add("step", {"obs": obs, "reward": reward}) + obs, reward, ... = env.step(...) episode_return += float(reward) - buf.add("episode", {"return": np.array([episode_return], dtype=np.float32)}) + buf.add("episode", {"return": np.array([episode_return]), "length": np.array([length])}) + buf.add("final_step", {"obs": obs, "reward": reward}) client.send(buf.build()) ``` ## Mental model -`TrajectoryAccumulator` is just pre-allocated numpy arrays plus per-timescale slot -counters. `add()` writes into the next free slot for that timescale via -slice assignment; `build()` returns the filled pytree and resets the -counters. The buffer is reused — no allocation per rollout. +`TrajectoryAccumulator` is just pre-allocated numpy arrays plus per-timescale +slot counters. For buffered timescales, `add()` writes into the next free +slot via slice assignment (`stored[s:s+1] = incoming`); for single-item +timescales it replaces the whole leaf (`stored[...] = incoming`). +`build()` returns the filled pytree and resets the counters. The buffer is +reused — no allocation per rollout. There's no network or Rust involvement; it's purely a way to amortise the flatten-and-send cost across many environment steps. The pytree it returns @@ -69,13 +93,15 @@ returning — don't hold onto the returned tree across further `add()` calls. ## Common pitfalls -- **Leading-dimension mismatch within a timescale.** All leaves under one - timescale key must share the same first axis size. That's what defines - "how many `add` calls before the buffer is full". The constructor checks - this and raises. -- **Adding past capacity.** If you call `add` more than the leading - dimension allows, you get `IndexError`. Call `reset()` if you want to - abort a partial rollout. +- **Leading-dimension mismatch in a *buffered* timescale.** Inside a + buffered timescale, all leaves must share the same first axis size — + that's what defines "how many `add` calls before the buffer is full". + The constructor checks this and raises. If you actually meant a + single-item timescale, make one leaf 0-d or add a leading dim to all leaves. +- **Adding past capacity.** If you call `add` more times than the + timescale's capacity, you get `IndexError` with the timescale name, + capacity, and offending index. Call `reset()` to abort a partial + rollout. - **Dict-only at the top level.** The top-level pytree must be a `dict` with timescale names. Below that, leaves can be any pytree shape that `optree` understands. diff --git a/python/echo/trajectory_accumulator.py b/python/echo/trajectory_accumulator.py index 46c00dd..ae09a64 100644 --- a/python/echo/trajectory_accumulator.py +++ b/python/echo/trajectory_accumulator.py @@ -7,10 +7,25 @@ class TrajectoryAccumulator: """Multi-timescale accumulator: fixed-size pytree buffer. + Per timescale, the example pytree determines how many ``add()`` calls + fit before the buffer is full: + + * **Buffered timescale** — every leaf shares the same leading dim ``N`` + (``N > 1``); the accumulator stores ``N`` per-add items into + ``stored[s:s+1] = incoming`` slot-by-slot. ``N`` becomes the + timescale's capacity. + + * **Single-item timescale** — the timescale holds one trailing piece of + context (e.g. a bootstrap step, an episode return) rather than a + buffer. Detected when at least one leaf is 0-d, or all leaves + have ``shape[0] == 1``. Capacity is ``1``; ``add()`` replaces the + whole leaf, so non-0-d leaves may have any per-item shape (apart + from the optional leading 1). + Args: - example: Dict with timescale names as top-level keys. The leading - dimension of each leaf array is the number of ``add()`` calls - expected before the buffer is ready to send. + example: Dict with timescale names as top-level keys. Each value is + a pytree whose leaves declare the per-timescale layout per the + rule above. """ def __init__(self, example: dict[str, Any]): @@ -18,20 +33,28 @@ def __init__(self, example: dict[str, Any]): raise TypeError("example must be a dict with timescale names as top-level keys") self._counts: dict[str, int] = {} + self._single_item: dict[str, bool] = {} for name, subtree in example.items(): leaves = optree.tree_leaves(subtree) if not leaves: raise ValueError(f"Timescale '{name}' has no array leaves") - leading = leaves[0].shape[0] if leaves[0].ndim > 0 else 1 - if not all((leaf.shape[0] if leaf.ndim > 0 else 1) == leading for leaf in leaves): - raise ValueError( - f"All leaves in timescale '{name}' must share the same leading dimension" - ) - self._counts[name] = leading - - self._tree: dict[str, Any] = { - n: optree.tree_map(np.zeros_like, sub) for n, sub in example.items() - } + + # Single-item: any 0-d leaf OR every leaf with leading dim 1. + if any(leaf.ndim == 0 for leaf in leaves) or all(leaf.shape[0] == 1 for leaf in leaves): + self._counts[name] = 1 + self._single_item[name] = True + else: + leading = [leaf.shape[0] for leaf in leaves] + if not all(s == leading[0] for s in leading): + raise ValueError( + f"All leaves in buffered timescale '{name}' must share the same " + f"leading dimension (got {leading}); make any leaf 0-d or " + f"all shape (1, ...) to mark the timescale single-item instead" + ) + self._counts[name] = leading[0] + self._single_item[name] = False + + self._tree: dict[str, Any] = {n: optree.tree_map(np.zeros_like, sub) for n, sub in example.items()} self._slot: dict[str, int] = {name: 0 for name in example} def add(self, name: str, data: Any) -> None: @@ -40,13 +63,14 @@ def add(self, name: str, data: Any) -> None: raise KeyError(f"Unknown timescale '{name}'. Known: {list(self._counts)}") s = self._slot[name] if s >= self._counts[name]: - raise IndexError( - f"Timescale '{name}' is already full ({self._counts[name]} slots). " - "Call reset() or build() before adding more." - ) + raise IndexError(f"Timescale '{name}' has {self._counts[name]} slots, but you tried to add at index {s}") + + # Single-item: replace the whole leaf + # Buffered: write into the next slot of the leading dim. + key = Ellipsis if self._single_item[name] else slice(s, s + 1) def _write_slot(stored, incoming): - np.atleast_1d(stored)[s:s + 1] = incoming + stored[key] = incoming return stored optree.tree_map_(_write_slot, self._tree[name], data) diff --git a/python/tests/test_trajectory_accumulator.py b/python/tests/test_trajectory_accumulator.py index 07c3216..fba3f7f 100644 --- a/python/tests/test_trajectory_accumulator.py +++ b/python/tests/test_trajectory_accumulator.py @@ -1,11 +1,9 @@ import numpy as np import pytest - +from conftest import free_port, wait_for_listen from echo import Server, TcpClient, TcpTransport from echo.trajectory_accumulator import TrajectoryAccumulator -from conftest import free_port, wait_for_listen - _EXAMPLE_SMALL = { "transition": {"obs": np.empty((2, 4), dtype=np.float32)}, } @@ -24,7 +22,7 @@ _EXAMPLE = { "transition": { "obs": np.empty((N, 12), dtype=np.float32), - "rew": np.empty((N,), dtype=np.float32), + "rew": np.empty((N,), dtype=np.float32), }, "summary": { "ret": np.empty((1,), dtype=np.float32), @@ -37,6 +35,84 @@ def test_counts_inferred_from_leading_dim(self): buf = TrajectoryAccumulator(_EXAMPLE) assert buf._counts == {"transition": N, "summary": 1} + def test_rejects_mismatched_leading_dim_in_buffered_timescale(self): + # Buffered timescales (no 0-d leaves) must have all leaves share + # leading dim. + bad = { + "timescale": { + "a": np.empty((4, 8), dtype=np.float32), # leading=4 + "b": np.empty((3,), dtype=np.float32), # leading=3 + }, + } + with pytest.raises(ValueError, match="same leading dimension"): + TrajectoryAccumulator(bad) + + def test_zero_d_leaf_marks_timescale_single_item(self): + # A 0-d leaf signals "single-item timescale": capacity is 1 and + # other leaves may have any per-item shape. + spec = { + "timescale": { + "obs": np.empty((4, 8), dtype=np.float32), # per-item shape + "flag": np.empty((), dtype=np.bool_), # 0-d marker + }, + } + buf = TrajectoryAccumulator(spec) + assert buf._counts == {"timescale": 1} + + def test_all_unit_leading_marks_timescale_single_item(self): + # Every leaf with shape (1, ...) → single-item, no 0-d marker needed. + spec = { + "timescale": { + "obs": np.empty((1, 8), dtype=np.float32), + "head": np.empty((1, 4, 2), dtype=np.float32), + }, + } + buf = TrajectoryAccumulator(spec) + assert buf._counts == {"timescale": 1} + + def test_single_unit_leading_leaf_is_single_item(self): + spec = {"timescale": {"a": np.empty((1,), dtype=np.float32)}} + buf = TrajectoryAccumulator(spec) + assert buf._counts == {"timescale": 1} + + def test_mixed_unit_and_non_unit_leading_rejected(self): + # (1, *) and (5,) — neither all-unit nor matching leading-dim. Falls + # through to the buffered branch and fails the invariant. + bad = { + "timescale": { + "a": np.empty((1, 8), dtype=np.float32), # leading=1 + "b": np.empty((5,), dtype=np.float32), # leading=5 + }, + } + with pytest.raises(ValueError, match="same leading dimension"): + TrajectoryAccumulator(bad) + + def test_unit_leading_write_replaces_whole_leaf(self): + spec = {"timescale": {"obs": np.empty((1, 4), dtype=np.float32)}} + buf = TrajectoryAccumulator(spec) + buf.add("timescale", {"obs": np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)}) + np.testing.assert_array_equal(buf._tree["timescale"]["obs"], [[1.0, 2.0, 3.0, 4.0]]) + + def test_over_index_raises_clear_indexerror(self): + spec = { + "transition": { + "obs": np.empty((N, 12), dtype=np.float32), + "rew": np.empty((N,), dtype=np.float32), + }, + } + buf = TrajectoryAccumulator(spec) + item = { + "obs": np.zeros(12, dtype=np.float32), + "rew": np.zeros((), dtype=np.float32), + } + for _ in range(N): + buf.add("transition", item) + with pytest.raises( + IndexError, + match=rf"Timescale 'transition' has {N} slots, but you tried to add at index {N}", + ): + buf.add("transition", item) + class TestTrajectoryAccumulatorAdd: def test_add_writes_correct_values(self): @@ -64,10 +140,13 @@ def test_add_dtype_cast(self): def test_add_multiple_slots(self): buf = TrajectoryAccumulator(_EXAMPLE) for i in range(N): - buf.add("transition", { - "obs": np.full(12, float(i), dtype=np.float32), - "rew": np.array(float(i * 10), dtype=np.float32), - }) + buf.add( + "transition", + { + "obs": np.full(12, float(i), dtype=np.float32), + "rew": np.array(float(i * 10), dtype=np.float32), + }, + ) tree = buf._tree for i in range(N): @@ -96,10 +175,13 @@ def test_capacity_inferred_for_zero_d_timescale(self): def test_write_zero_d_leaves(self): buf = TrajectoryAccumulator(self._example()) - buf.add("episode", { - "ret": np.array(7.5, dtype=np.float32), - "gen": np.array(42, dtype=np.int32), - }) + buf.add( + "episode", + { + "ret": np.array(7.5, dtype=np.float32), + "gen": np.array(42, dtype=np.int32), + }, + ) tree = buf._tree assert tree["episode"]["ret"].shape == () assert tree["episode"]["gen"].shape == () @@ -109,14 +191,20 @@ def test_write_zero_d_leaves(self): def test_mixed_zero_d_and_nd_in_same_build(self): buf = TrajectoryAccumulator(self._example()) for i in range(3): - buf.add("step", { - "obs": np.full(4, float(i), dtype=np.float32), - "reward": np.array(float(i * 10), dtype=np.float32), - }) - buf.add("episode", { - "ret": np.array(99.0, dtype=np.float32), - "gen": np.array(5, dtype=np.int32), - }) + buf.add( + "step", + { + "obs": np.full(4, float(i), dtype=np.float32), + "reward": np.array(float(i * 10), dtype=np.float32), + }, + ) + buf.add( + "episode", + { + "ret": np.array(99.0, dtype=np.float32), + "gen": np.array(5, dtype=np.int32), + }, + ) tree = buf.build() assert tree["step"]["obs"].shape == (3, 4) @@ -132,24 +220,33 @@ def test_mixed_zero_d_and_nd_in_same_build(self): def test_zero_d_timescale_full_after_one_add(self): buf = TrajectoryAccumulator(self._example()) - buf.add("episode", { - "ret": np.array(1.0, dtype=np.float32), - "gen": np.array(1, dtype=np.int32), - }) + buf.add( + "episode", + { + "ret": np.array(1.0, dtype=np.float32), + "gen": np.array(1, dtype=np.int32), + }, + ) with pytest.raises(IndexError): - buf.add("episode", { - "ret": np.array(2.0, dtype=np.float32), - "gen": np.array(2, dtype=np.int32), - }) + buf.add( + "episode", + { + "ret": np.array(2.0, dtype=np.float32), + "gen": np.array(2, dtype=np.int32), + }, + ) class TestTrajectoryAccumulatorBuild: def _fill(self, buf): for i in range(N): - buf.add("transition", { - "obs": np.full(12, float(i), dtype=np.float32), - "rew": np.array(float(i), dtype=np.float32), - }) + buf.add( + "transition", + { + "obs": np.full(12, float(i), dtype=np.float32), + "rew": np.array(float(i), dtype=np.float32), + }, + ) buf.add("summary", {"ret": np.array([99.0], dtype=np.float32)}) def test_build_returns_dict(self):