Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions dimos/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,14 +477,16 @@ def process_observable(
self,
observable: "Observable[Any]",
async_cb: Callable[[Any], Any],
on_drop: Callable[[], None] | None = None,
) -> "DisposableBase":
"""Subscribe `async_cb` (an async function) to `observable`, dispatching
each emitted value onto self._loop. Invocations are serialized through a
per-subscription dispatcher task with LATEST coalescing. The subscription
per-subscription dispatcher task with LATEST coalescing. `on_drop`, if
given, fires once per message dropped by that coalescing. The subscription
is registered for cleanup on stop()."""
if not inspect.iscoroutinefunction(async_cb):
raise TypeError("process_observable requires an `async def` callback")
on_msg, dispatcher_disp = self._make_async_dispatch(async_cb)
on_msg, dispatcher_disp = self._make_async_dispatch(async_cb, on_drop)
sub = observable.subscribe(on_msg)
return self.register_disposable(CompositeDisposable(sub, dispatcher_disp))

Expand Down Expand Up @@ -635,7 +637,9 @@ def _auto_bind_handlers(self) -> None:
self.process_observable(in_stream.pure_observable(), handler)

def _make_async_dispatch(
self, async_handler: Callable[[Any], Any]
self,
async_handler: Callable[[Any], Any],
on_drop: Callable[[], None] | None = None,
) -> tuple[Callable[[Any], None], "DisposableBase"]:
"""Build a sync callback that delivers `msg` into a single-slot LATEST
mailbox drained by a dedicated dispatcher task on `self._loop`.
Expand All @@ -645,7 +649,9 @@ def _make_async_dispatch(
awaits).
- If messages arrive faster than the handler can process them,
intermediate messages are dropped and only the most recent unprocessed
message is kept (LATEST policy).
message is kept (LATEST policy). `on_drop`, if given, is called once
per dropped message (on the loop thread) so callers that need every
message can surface the loss.
- The returned Disposable cancels the dispatcher task.
"""
loop = self._loop
Expand Down Expand Up @@ -685,6 +691,10 @@ def on_msg(msg: Any) -> None:
return

def _set() -> None:
# A slot that still holds an unconsumed value is about to be
# overwritten — that queued message is being dropped (LATEST).
if slot["has_value"] and on_drop is not None:
on_drop()
slot["value"] = msg
slot["has_value"] = True
event.set()
Expand Down
4 changes: 2 additions & 2 deletions dimos/hardware/sensors/lidar/fastlio2/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ class FastLio2Recorder(Recorder):
_last_odom_pose: Pose | None = None

@pose_setter_for("fastlio_odometry")
def _odom_pose(self, msg: Odometry) -> Pose | None:
async def _odom_pose(self, msg: Odometry) -> Pose | None:
pose = getattr(msg, "pose", None)
self._last_odom_pose = getattr(pose, "pose", None) if pose is not None else None
return self._last_odom_pose

@pose_setter_for("fastlio_lidar")
def _lidar_pose(self, msg: PointCloud2) -> Pose | None:
async def _lidar_pose(self, msg: PointCloud2) -> Pose | None:
# Most-recent odometry pose, stamped directly (no tf). None before the
# first odometry -> frame stored unposed, map-skipped.
return self._last_odom_pose
56 changes: 42 additions & 14 deletions dimos/memory2/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from collections.abc import Callable
from collections.abc import Awaitable, Callable
import enum
import inspect
import os
Expand Down Expand Up @@ -274,15 +274,20 @@ class RecorderConfig(MemoryModuleConfig):
stream_remapping: dict[str, str] = Field(default_factory=dict)


PoseSetter = Callable[[Any], "Pose | None"]
PoseSetter = Callable[[Any], "Awaitable[Pose | None]"]


def pose_setter_for(*stream_names: str) -> Callable[[Any], Any]:
"""Mark a method ``(self, msg) -> Pose | None`` as the pose setter for the
given recorded stream(s). Streams without a setter fall back to the tf-based
``world <- frame_id`` lookup."""
"""Mark an ``async def`` method ``(self, msg) -> Pose | None`` as the pose
setter for the given recorded stream(s). Streams without a setter fall back
to the tf-based ``world <- frame_id`` lookup."""

def decorate(fn: Any) -> Any:
if not inspect.iscoroutinefunction(fn):
raise TypeError(
f"@pose_setter_for must decorate an `async def` method; "
f"{getattr(fn, '__qualname__', fn)} is not async"
)
fn._pose_setter_for = tuple(stream_names)
return fn

Expand All @@ -302,16 +307,20 @@ class MyRecorder(Recorder):

Each stream's pose defaults to a ``world <- frame_id`` tf lookup; decorate a
method with ``@pose_setter_for("stream")`` to source it elsewhere (e.g. from
an odometry stream)::
an odometry stream). Setters run on the module's event loop and may be
``async def``::

@pose_setter_for("lidar")
def _lidar_pose(self, msg):
async def _lidar_pose(self, msg):
return self._last_odom_pose
"""

config: RecorderConfig

_pose_setters: dict[str, Any] = {}
# Per-stream count of frames lost to the dispatcher's LATEST coalescing
# (sink slower than input). Populated lazily as drops happen.
_dropped_frames: dict[str, int] = {}

@rpc
def start(self) -> None:
Expand All @@ -324,6 +333,7 @@ def start(self) -> None:
return

self._pose_setters = self._collect_pose_setters()
self._dropped_frames = {}

# TODO: store reset API/logic is not implemented yet. This module
# shouldn't need to know about files (SqliteStore specific), and
Expand Down Expand Up @@ -368,12 +378,14 @@ def _port_to_stream(self, name: str, input_topic: In[Any], stream: Stream[Any])
already in world coords) fall back to ``config.default_frame_id`` —
so every observation gets a robot-pose anchor when tf is publishing.

Registers the subscription as a disposable on this module.
Each port is recorded by an async callback dispatched on the module's
event loop via :meth:`process_observable`, which serialises invocations
and registers the subscription for cleanup on stop().
"""

def on_msg(msg: Any) -> None:
async def on_msg(msg: Any) -> None:
ts = self._resolve_ts(name, msg)
pose = self._resolve_pose(name, msg, ts)
pose = await self._resolve_pose(name, msg, ts)
if not pose:
logger.warning(
"[%s] No pose for time %s (msg ts: %s), storing without pose",
Expand All @@ -383,7 +395,23 @@ def on_msg(msg: Any) -> None:
)
stream.append(msg, ts=ts, pose=pose)

self.register_disposable(Disposable(input_topic.subscribe(on_msg)))
self.process_observable(
input_topic.pure_observable(), on_msg, on_drop=lambda: self._on_frame_dropped(name)
)

def _on_frame_dropped(self, name: str) -> None:
"""A frame for *name* was dropped because the sink couldn't keep up with
the input rate (dispatcher LATEST coalescing). Count it and warn — once,
then on each power-of-ten — so silent data loss is visible without
flooding the log."""
count = self._dropped_frames.get(name, 0) + 1
self._dropped_frames[name] = count
if count == 1 or count % 1000 == 0:
logger.warning(
"[%s] Recorder dropped %d frame(s) — sink slower than input; recording is lossy",
name,
count,
)

def _prepare_streams(self) -> None:
"""On APPEND, drop the streams this recorder is about to (re)write — the
Expand All @@ -401,13 +429,13 @@ def _resolve_ts(self, name: str, msg: Any) -> float:
"""Timestamp to record *msg* at. Override to re-base onto another clock."""
return getattr(msg, "ts", None) or time.time()

def _resolve_pose(self, name: str, msg: Any, ts: float) -> Pose | None:
"""Pose to anchor *msg* with. Dispatches to the stream's
async def _resolve_pose(self, name: str, msg: Any, ts: float) -> Pose | None:
"""Pose to anchor *msg* with. Dispatches to the stream's (async)
``@pose_setter_for`` if one is defined, else falls back to a
``world <- frame_id`` tf lookup."""
setter = self._pose_setters.get(name)
if setter is not None:
return cast("Pose | None", setter(msg))
return cast("Pose | None", await setter(msg))
frame_id = getattr(msg, "frame_id", None) or self.config.default_frame_id
transform = self.tf.get(
self.config.root_frame, frame_id, time_point=ts, time_tolerance=self.config.tf_tolerance
Expand Down
10 changes: 7 additions & 3 deletions dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,17 @@ class Go2Memory(Recorder):
_last_odom_pose: Pose | None = None

@pose_setter_for("odom")
def _odom_pose(self, msg: PoseStamped) -> Pose | None:
async def _odom_pose(self, msg: PoseStamped) -> Pose | None:
self._last_odom_pose = msg
return self._last_odom_pose

@pose_setter_for("lidar")
def _lidar_pose(self, msg: PointCloud2) -> Pose | None:
return self._last_odom_pose # should always exist (odom alwyas wins the race)
async def _lidar_pose(self, msg: PointCloud2) -> Pose | None:
# go2 lidar (currently) is in world-frame
# so it doesn't make sense to register lidar at the odom pose
# but we do it anyways because map.py (for now) requires it
# TODO: fix map.py to use a transform frame
return getattr(self, "_last_odom_pose", None)


unitree_go2_markers = (
Expand Down
Loading