Merged
10 changes: 10 additions & 0 deletions docs/src/api/trajectory-accumulator.md
@@ -5,4 +5,14 @@ client side before calling `client.send`. Useful when a single rollout step
produces several arrays at different rates (e.g. one transition per env
step, plus a single episode-level statistic per episode).

Each timescale is either:

- **Buffered**: every leaf shares the same leading dim `N` (`N > 1`);
that's the capacity, filled by `N` `add()` calls.
- **Single-item**: capacity 1; detected when at least one leaf is 0-d, or
every leaf has leading dim 1. `add()` replaces the whole leaf.
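
The detection rule, as the constructor in this change applies it, can be sketched in plain Python over leaf shapes. This is an illustrative sketch only: `classify_timescale` is a hypothetical helper, not part of the `echo` API, and it works on shape tuples rather than arrays.

```python
def classify_timescale(shapes):
    """Return (kind, capacity) for a list of leaf shapes (tuples of ints).

    Hypothetical helper mirroring the constructor's rule; not part of echo.
    """
    if not shapes:
        raise ValueError("timescale has no array leaves")
    # Single-item: any 0-d leaf, or every leaf has leading dim 1.
    if any(len(s) == 0 for s in shapes) or all(s[0] == 1 for s in shapes):
        return "single-item", 1
    # Buffered: every leaf must share the same leading dim.
    leading = [s[0] for s in shapes]
    if any(n != leading[0] for n in leading):
        raise ValueError(f"leading dims differ: {leading}")
    return "buffered", leading[0]


step = classify_timescale([(64, 4), (64,)])      # obs and reward, T = 64
episode = classify_timescale([(1,), (1,)])       # return and length
final_step = classify_timescale([(4,), ()])      # obs plus a 0-d reward
```

Note that the 0-d check comes first, so a timescale mixing a `(4,)` leaf with a 0-d leaf is single-item rather than a leading-dimension error.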

See the [guide](../guides/trajectory-accumulator.md) for the rationale and
worked examples.

::: echo.trajectory_accumulator.TrajectoryAccumulator
70 changes: 46 additions & 24 deletions docs/src/guides/trajectory-accumulator.md
@@ -17,8 +17,18 @@ samples.
## How it works

Construct it with a dict whose top-level keys are timescale names. Each
leaf's leading dimension is the number of `add()` calls expected before the
buffer is ready to send.
timescale is one of two kinds, inferred from the example pytree:

- **Buffered** — every leaf shares the same leading dim `N` (with `N > 1`).
That leading dim is the timescale's capacity: the buffer fills after `N`
`add()` calls and each call writes into `stored[s:s+1]`.

- **Single-item** — the timescale holds one trailing piece of context
(e.g. an episode return, a bootstrap step) rather than a buffer of
steps. Detected when at least one leaf is 0-d, *or* all leaves have
`shape[0] == 1`. Capacity is `1` and `add()` replaces the whole leaf,
so non-0-d leaves can carry any per-item shape (apart from the
optional leading 1).

```python
import numpy as np
@@ -27,12 +37,21 @@ from echo import TrajectoryAccumulator, TcpClient
T = 64 # rollout length

example = {
# Buffered timescale: leading dim T across every leaf.
"step": {
"obs": np.zeros((T, 4), dtype=np.float32),
"reward": np.zeros((T, 1), dtype=np.float32),
"reward": np.zeros((T,), dtype=np.float32),
},
# Single-item timescale: all leaves have shape (1, ...) — capacity 1,
# `add()` replaces the whole leaf.
"episode": {
"return": np.zeros((1,), dtype=np.float32),
"length": np.zeros((1,), dtype=np.float32)
},
# Single-item timescale: 0-d reward means add() replaces the whole leaf
"final_step": {
"obs": np.zeros((4,), dtype=np.float32),
"reward": np.zeros((), dtype=np.float32),
},
}

@@ -42,44 +61,47 @@ buf = TrajectoryAccumulator(example)

for _ in range(num_rollouts):
episode_return = 0.0
reward = 0.0
obs = env.reset()
for _ in range(T):
obs, reward = env.step(...)
buf.add("step", {"obs": obs, "reward": reward})
obs, reward, *_ = env.step(...)
episode_return += float(reward)

buf.add("episode", {"return": np.array([episode_return], dtype=np.float32)})
buf.add("episode", {"return": np.array([episode_return]), "length": np.array([T])})
buf.add("final_step", {"obs": obs, "reward": reward})
client.send(buf.build())
```

## Mental model

`TrajectoryAccumulator` is just pre-allocated numpy arrays plus per-timescale slot
counters. `add()` writes into the next free slot for that timescale via
slice assignment; `build()` returns the filled pytree, flips the active
buffer (so the next round writes into a fresh one without allocating), and
resets the counters.
`TrajectoryAccumulator` is just pre-allocated numpy arrays plus per-timescale
slot counters. For buffered timescales, `add()` writes into the next free
slot via slice assignment (`stored[s:s+1] = incoming`); for single-item
timescales it replaces the whole leaf (`stored[...] = incoming`).
`build()` returns the filled pytree and resets the counters. The buffer is
reused — no allocation per rollout.
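
The two write paths can be reproduced with plain NumPy arrays. The sketch below is not the accumulator itself; the array names and shapes are made up for illustration, and only the indexing expressions come from the description above.

```python
import numpy as np

# Buffered leaf: capacity 3, per-item shape (2,).
stored = np.zeros((3, 2), dtype=np.float32)
for s in range(3):
    incoming = np.full((2,), s + 1, dtype=np.float32)
    stored[s:s + 1] = incoming  # broadcasts (2,) into the (1, 2) slot

# Single-item leaf: the whole array is overwritten on each add.
single = np.zeros((4,), dtype=np.float32)
single[...] = np.arange(4, dtype=np.float32)

# A 0-d leaf has no leading axis to slice, so it is written through Ellipsis.
scalar = np.zeros((), dtype=np.float32)
scalar[...] = 7.5
```

Because both paths assign into existing arrays, the incoming values are cast to the stored dtype and no new buffer is allocated.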

There's no network or Rust involvement; it's purely a way to amortise the
flatten-and-send cost across many environment steps. The pytree it returns
goes through the normal `client.send` path.

## Two buffers, no allocation per rollout

`TrajectoryAccumulator` double-buffers internally: two copies of the pytree, with
`build()` swapping the active one. So while one buffer is being serialised
and sent, the next rollout can start filling the other without any
allocation. This matters when rollouts are short relative to flatten +
network latency.
The tree returned by `build()` aliases the accumulator's internal buffers,
so the next `add()` will overwrite it. The usual `client.send(buf.build())`
pattern is safe because `send` is synchronous and copies the bytes before
returning — don't hold onto the returned tree across further `add()` calls.
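
The aliasing hazard is easy to see with a stripped-down stand-in. `TinyAccumulator` below is hypothetical and much simpler than `TrajectoryAccumulator` (one leaf, no pytrees); it only assumes, as stated above, that `build()` returns the internal buffer without copying.

```python
import numpy as np


class TinyAccumulator:
    """Hypothetical stand-in: one buffered leaf, reused across rollouts."""

    def __init__(self, capacity):
        self._buf = np.zeros((capacity,), dtype=np.float32)
        self._slot = 0

    def add(self, value):
        self._buf[self._slot:self._slot + 1] = value
        self._slot += 1

    def build(self):
        self._slot = 0
        return self._buf  # aliases the internal buffer, no copy


acc = TinyAccumulator(2)
acc.add(1.0)
acc.add(2.0)
held = acc.build()      # still points at the accumulator's buffer
snapshot = held.copy()  # independent copy, safe to keep around

acc.add(9.0)            # the next rollout overwrites slot 0 of `held` too
```

If a built tree has to outlive the next `add()`, copying its leaves first (as with `snapshot` here) is one way to keep the old values.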

## Common pitfalls

- **Leading-dimension mismatch within a timescale.** All leaves under one
timescale key must share the same first axis size. That's what defines
"how many `add` calls before the buffer is full". The constructor checks
this and raises.
- **Adding past capacity.** If you call `add` more than the leading
dimension allows, you get `IndexError`. Call `reset()` if you want to
abort a partial rollout.
- **Leading-dimension mismatch in a *buffered* timescale.** Inside a
buffered timescale, all leaves must share the same first axis size —
that's what defines "how many `add` calls before the buffer is full".
The constructor checks this and raises. If you actually meant a
single-item timescale, make at least one leaf 0-d or give every leaf a
leading dim of 1.
- **Adding past capacity.** If you call `add` more times than the
timescale's capacity, you get `IndexError` with the timescale name,
capacity, and offending index. Call `reset()` to abort a partial
rollout.
- **Dict-only at the top level.** The top-level pytree must be a `dict`
with timescale names. Below that, leaves can be any pytree shape that
`optree` understands.
77 changes: 50 additions & 27 deletions python/echo/trajectory_accumulator.py
@@ -5,36 +5,56 @@


class TrajectoryAccumulator:
"""Multi-timescale accumulator: fixed-size pytree buffer with double-buffering.
"""Multi-timescale accumulator: fixed-size pytree buffer.

Per timescale, the example pytree determines how many ``add()`` calls
fit before the buffer is full:

* **Buffered timescale** — every leaf shares the same leading dim ``N``
(``N > 1``); each ``add()`` writes one item into the next slot via
``stored[s:s+1] = incoming``. ``N`` becomes the timescale's
capacity.

* **Single-item timescale** — the timescale holds one trailing piece of
context (e.g. a bootstrap step, an episode return) rather than a
buffer. Detected when at least one leaf is 0-d, or all leaves
have ``shape[0] == 1``. Capacity is ``1``; ``add()`` replaces the
whole leaf, so non-0-d leaves may have any per-item shape (apart
from the optional leading 1).

Args:
example: Dict with timescale names as top-level keys. The leading
dimension of each leaf array is the number of ``add()`` calls
expected before the buffer is ready to send.
example: Dict with timescale names as top-level keys. Each value is
a pytree whose leaves declare the per-timescale layout per the
rule above.
"""

def __init__(self, example: dict[str, Any]):
if not isinstance(example, dict):
raise TypeError("example must be a dict with timescale names as top-level keys")

self._counts: dict[str, int] = {}
self._single_item: dict[str, bool] = {}
for name, subtree in example.items():
leaves = optree.tree_leaves(subtree)
if not leaves:
raise ValueError(f"Timescale '{name}' has no array leaves")
leading = leaves[0].shape[0] if leaves[0].ndim > 0 else 1
if not all((leaf.shape[0] if leaf.ndim > 0 else 1) == leading for leaf in leaves):
raise ValueError(
f"All leaves in timescale '{name}' must share the same leading dimension"
)
self._counts[name] = leading

# Two copies of the pytree for double-buffering.
self._trees: list[dict[str, Any]] = [
{n: optree.tree_map(np.zeros_like, sub) for n, sub in example.items()},
{n: optree.tree_map(np.zeros_like, sub) for n, sub in example.items()},
]
self._active = 0

# Single-item: any 0-d leaf OR every leaf with leading dim 1.
if any(leaf.ndim == 0 for leaf in leaves) or all(leaf.shape[0] == 1 for leaf in leaves):
self._counts[name] = 1
self._single_item[name] = True
else:
leading = [leaf.shape[0] for leaf in leaves]
if not all(s == leading[0] for s in leading):
raise ValueError(
f"All leaves in buffered timescale '{name}' must share the same "
f"leading dimension (got {leading}); make any leaf 0-d or "
f"all shape (1, ...) to mark the timescale single-item instead"
)
self._counts[name] = leading[0]
self._single_item[name] = False

self._tree: dict[str, Any] = {n: optree.tree_map(np.zeros_like, sub) for n, sub in example.items()}
self._slot: dict[str, int] = {name: 0 for name in example}

def add(self, name: str, data: Any) -> None:
@@ -43,24 +63,27 @@ def add(self, name: str, data: Any) -> None:
raise KeyError(f"Unknown timescale '{name}'. Known: {list(self._counts)}")
s = self._slot[name]
if s >= self._counts[name]:
raise IndexError(
f"Timescale '{name}' is already full ({self._counts[name]} slots). "
"Call reset() or build() before adding more."
)
raise IndexError(f"Timescale '{name}' has {self._counts[name]} slots, but you tried to add at index {s}")

# Single-item: replace the whole leaf
# Buffered: write into the next slot of the leading dim.
key = Ellipsis if self._single_item[name] else slice(s, s + 1)

def _write_slot(stored, incoming):
stored[s:s + 1] = incoming
stored[key] = incoming
return stored

optree.tree_map_(_write_slot, self._trees[self._active][name], data)
optree.tree_map_(_write_slot, self._tree[name], data)
self._slot[name] += 1

def build(self) -> dict[str, Any]:
"""Return the filled pytree and flip the active buffer."""
tree = self._trees[self._active]
self._active = 1 - self._active
"""Return the filled pytree and reset slot counters.

The returned tree aliases internal buffers; callers must finish using
it (e.g. complete the synchronous send) before the next ``add()``.
"""
self._slot = {name: 0 for name in self._slot}
return tree
return self._tree

def reset(self) -> None:
"""Reset slot counters without sending (e.g. on episode abort)."""