diff --git a/AGENTS.md b/AGENTS.md index 4da2b37404..b2b83d7d67 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -44,7 +44,7 @@ dimos restart # stop + re-run with same original args | `xarm-perception-sim-agent` | xArm | sim | GPT-4o | — | Manipulation + perception + agent, sim | | `xarm7-planner-coordinator` | xArm7 | real | — | — | Trajectory planner coordinator | | `teleop-quest-xarm7` | xArm7 | real | — | — | Quest VR teleop | -| `dual-xarm6-planner` | xArm6×2 | real | — | — | Dual-arm motion planner | +| `dual-xarm6-planner-coordinator` | xArm6×2 | real | — | — | Dual-arm motion planner + coordinator | Run `dimos list` for the full list. diff --git a/dimos/control/README.md b/dimos/control/README.md index 858af0b6c4..56694a2c38 100644 --- a/dimos/control/README.md +++ b/dimos/control/README.md @@ -27,7 +27,6 @@ Centralized control system for multi-arm robots with per-joint arbitration. ```bash # Terminal 1: Run coordinator dimos run coordinator-mock # Single 7-DOF mock arm -dimos run coordinator-dual-mock # Dual arms (7+6 DOF) dimos run coordinator-piper-xarm # Real hardware # Terminal 2: Control via CLI diff --git a/dimos/control/blueprints/test_dual.py b/dimos/control/blueprints/test_dual.py new file mode 100644 index 0000000000..1dc855b7d7 --- /dev/null +++ b/dimos/control/blueprints/test_dual.py @@ -0,0 +1,79 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for dual-arm control blueprints.""" + +from dimos.control.coordinator import ControlCoordinator +from dimos.core.coordination.blueprints import Blueprint +from dimos.manipulation.manipulation_module import ManipulationModule +from dimos.robot.manipulators.xarm.blueprints.basic import coordinator_dual_xarm + + +def _dual_xarm6_planner_coordinator() -> Blueprint: + from dimos.robot.manipulators.xarm.blueprints.basic import dual_xarm6_planner_coordinator + + return dual_xarm6_planner_coordinator + + +def _coordinator_task_names(blueprint) -> list[str]: + atom = next(atom for atom in blueprint.blueprints if atom.module is ControlCoordinator) + return [task.name for task in atom.kwargs["tasks"]] + + +def _coordinator_tasks(blueprint): + atom = next(atom for atom in blueprint.blueprints if atom.module is ControlCoordinator) + return atom.kwargs["tasks"] + + +def _manipulation_robots(blueprint): + atom = next(atom for atom in blueprint.blueprints if atom.module is ManipulationModule) + return atom.kwargs["robots"] + + +def _manipulation_visualization(blueprint): + atom = next(atom for atom in blueprint.blueprints if atom.module is ManipulationModule) + return atom.kwargs["visualization"] + + +def test_dual_xarm6_planner_coordinator_blueprint_has_planner_and_coordinator() -> None: + dual_xarm6_planner_coordinator = _dual_xarm6_planner_coordinator() + modules = [atom.module for atom in dual_xarm6_planner_coordinator.blueprints] + + assert ManipulationModule in modules + assert ControlCoordinator in modules + + +def test_dual_xarm6_planner_coordinator_uses_viser_execution_ui() -> None: + dual_xarm6_planner_coordinator = _dual_xarm6_planner_coordinator() + visualization = _manipulation_visualization(dual_xarm6_planner_coordinator) + + assert visualization == {"backend": "viser", "allow_plan_execute": True} + + +def test_dual_xarm6_planner_coordinator_task_names_match_robot_defaults() -> None: + dual_xarm6_planner_coordinator = _dual_xarm6_planner_coordinator() + + for robot in _manipulation_robots(dual_xarm6_planner_coordinator): + assert robot.coordinator_task_name == f"traj_{robot.name}" + assert _coordinator_task_names(dual_xarm6_planner_coordinator) == [ + "traj_left_arm", + "traj_right_arm", + ] + + +def test_dual_coordinator_xarm_task_names_match_manipulation_robot_defaults() -> None: + assert _coordinator_task_names(coordinator_dual_xarm) == [ + "traj_left", + "traj_right", + ] diff --git a/dimos/control/tasks/trajectory_task/trajectory_task.py b/dimos/control/tasks/trajectory_task/trajectory_task.py index a3eb23dec3..f26ca72ec6 100644 --- a/dimos/control/tasks/trajectory_task/trajectory_task.py +++ b/dimos/control/tasks/trajectory_task/trajectory_task.py @@ -191,6 +191,23 @@ def execute(self, trajectory: JointTrajectory) -> bool: logger.warning(f"Empty trajectory for {self._name}") return False + if trajectory.joint_names and trajectory.joint_names != self._joint_names_list: + logger.warning( + f"Joint name mismatch for {self._name}: " + f"expected={self._joint_names_list}, received={trajectory.joint_names}" + ) + return False + + if not trajectory.joint_names: + expected_joint_count = len(self._joint_names_list) + for point in trajectory.points: + if len(point.positions) != expected_joint_count: + logger.warning( + f"Trajectory point dimension mismatch for {self._name}: " + f"expected={expected_joint_count}, received={len(point.positions)}" + ) + return False + # Preempt any active trajectory if self._state == TrajectoryState.EXECUTING: logger.info(f"Preempting active trajectory on {self._name}") diff --git a/dimos/control/test_control.py b/dimos/control/test_control.py index ae6bc1e9de..13f0579cb3 100644 --- a/dimos/control/test_control.py +++ b/dimos/control/test_control.py @@ -274,6 +274,29 @@ def test_execute_trajectory(self, trajectory_task, simple_trajectory): assert trajectory_task.is_active() assert trajectory_task.get_state() == TrajectoryState.EXECUTING + def test_execute_rejects_mismatched_joint_names(self, trajectory_task): + trajectory = JointTrajectory( + joint_names=["other/joint1", "other/joint2", "other/joint3"], + points=[ + TrajectoryPoint( + positions=[0.0, 0.0, 0.0], + velocities=[0.0, 0.0, 0.0], + time_from_start=0.0, + ), + TrajectoryPoint( + positions=[1.0, 0.5, 0.25], + velocities=[0.0, 0.0, 0.0], + time_from_start=1.0, + ), + ], + ) + + result = trajectory_task.execute(trajectory) + + assert result is False + assert not trajectory_task.is_active() + assert trajectory_task.get_state() == TrajectoryState.IDLE + def test_compute_during_trajectory(self, trajectory_task, simple_trajectory, coordinator_state): t_start = time.perf_counter() trajectory_task.execute(simple_trajectory) @@ -530,6 +553,28 @@ def test_tick_loop_calls_compute(self, mock_adapter): assert mock_task.compute.call_count > 0 + def test_write_all_hardware_logs_rejected_command(self, mocker): + hardware = {"arm": MagicMock()} + hardware["arm"].write_command.return_value = False + log_error = mocker.patch("dimos.control.tick_loop.logger.error") + tick_loop = TickLoop( + tick_rate=100.0, + hardware=hardware, + hardware_lock=threading.Lock(), + tasks={}, + task_lock=threading.Lock(), + joint_to_hardware={}, + ) + + tick_loop._write_all_hardware({"arm": ({"arm/joint1": 0.5}, ControlMode.SERVO_POSITION)}) + + hardware["arm"].write_command.assert_called_once_with( + {"arm/joint1": 0.5}, ControlMode.SERVO_POSITION + ) + log_error.assert_called_once_with( + "Hardware %s rejected %d %s command(s)", "arm", 1, "SERVO_POSITION" + ) + class TestIntegration: def test_full_trajectory_execution(self, mock_adapter): diff --git a/dimos/control/tick_loop.py b/dimos/control/tick_loop.py index 975dfa9333..6184591116 100644 --- a/dimos/control/tick_loop.py +++ b/dimos/control/tick_loop.py @@ -397,7 +397,13 @@ def _write_all_hardware( for hw_id, (positions, mode) in hw_commands.items(): if hw_id in self._hardware: try: - self._hardware[hw_id].write_command(positions, mode) + if not self._hardware[hw_id].write_command(positions, mode): + logger.error( + "Hardware %s rejected %d %s command(s)", + hw_id, + len(positions), + mode.name, + ) except Exception as e: logger.error(f"Failed to write to {hw_id}: {e}") diff --git a/dimos/e2e_tests/test_manipulation_planning_groups.py b/dimos/e2e_tests/test_manipulation_planning_groups.py new file mode 100644 index 0000000000..7ccdf3f4ce --- /dev/null +++ b/dimos/e2e_tests/test_manipulation_planning_groups.py @@ -0,0 +1,216 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Large E2E tests for manipulation planning groups with a coordinator. + +These tests launch a real ManipulationModule + ControlCoordinator blueprint and +exercise the public planning RPCs over LCM, matching the self-hosted large-test +style used by the navigation stack. +""" + +from __future__ import annotations + +from collections.abc import Callable +import time +from typing import Any + +import pytest + +from dimos.control.coordinator import ControlCoordinator +from dimos.core.rpc_client import RPCClient +from dimos.e2e_tests.dimos_cli_call import DimosCliCall +from dimos.e2e_tests.lcm_spy import LcmSpy +from dimos.manipulation.manipulation_module import ManipulationModule +from dimos.manipulation.planning.groups.models import PlanningGroup +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState + +pytestmark = [pytest.mark.self_hosted_large] + +JOINT_STATE_TOPIC = "/coordinator/joint_state#sensor_msgs.JointState" +BLUEPRINT = "openarm-mock-planner-coordinator" + + +def _wait_for_robot_info( + client: RPCClient, + robot_name: str, + *, + timeout: float = 120.0, +) -> dict[str, Any]: + deadline = time.time() + timeout + last_error: BaseException | None = None + while time.time() < deadline: + try: + info = client.get_robot_info(robot_name) + if info and info.get("planning_groups"): + return info + except BaseException as exc: + last_error = exc + time.sleep(0.5) + raise TimeoutError(f"Timed out waiting for {robot_name!r} robot info") from last_error + + +def _wait_for_trajectory_completion( + client: RPCClient, + robot_name: str, + *, + timeout: float = 10.0, +) -> None: + deadline = time.time() + timeout + last_status: dict[str, Any] | None = None + while time.time() < deadline: + last_status = client.get_trajectory_status(robot_name) + if last_status is not None and last_status.get("state") == TrajectoryState.COMPLETED: + return + time.sleep(0.1) + raise TimeoutError(f"{robot_name!r} trajectory did not complete; last={last_status}") + + +def _wait_for_manipulation_state( + client: RPCClient, + state_name: str, + *, + timeout: float = 10.0, +) -> None: + deadline = time.time() + timeout + last_state: str | None = None + while time.time() < deadline: + last_state = client.get_state() + if last_state == state_name: + return + time.sleep(0.1) + raise TimeoutError(f"ManipulationModule did not reach {state_name}; last={last_state}") + + +def _wait_for_current_joints( + client: RPCClient, + robot_names: tuple[str, ...], + *, + timeout: float = 10.0, +) -> None: + deadline = time.time() + timeout + missing = robot_names + while time.time() < deadline: + missing = tuple( + robot_name + for robot_name in robot_names + if client.get_current_joints(robot_name) is None + ) + if not missing: + return + time.sleep(0.1) + raise TimeoutError(f"Timed out waiting for current joints from {missing}") + + +def _prepare_for_planning(client: RPCClient, robot_names: tuple[str, ...]) -> None: + client.reset() + _wait_for_manipulation_state(client, "IDLE") + _wait_for_current_joints(client, robot_names) + # Robot info and joint-state topics can become available just before the + # manipulation module finishes finalizing world monitors. Require a stable + # ready state after joint state is flowing to avoid command-readiness flakes. + time.sleep(0.25) + _wait_for_manipulation_state(client, "IDLE") + + +def _planning_group_id(info: dict[str, Any]) -> str: + groups = info["planning_groups"] + assert len(groups) == 1 + group = groups[0] + if isinstance(group, PlanningGroup): + return group.id + group_id = group["id"] + assert isinstance(group_id, str) + return group_id + + +def _offset_target(client: RPCClient, robot_name: str, delta: float) -> JointState: + current = client.get_current_joints(robot_name) + assert current is not None + return JointState(position=[position + delta for position in current]) + + +def _start_openarm_mock_planner( + start_blueprint: Callable[..., DimosCliCall], lcm_spy: LcmSpy +) -> None: + lcm_spy.save_topic(JOINT_STATE_TOPIC) + start_blueprint(BLUEPRINT) + lcm_spy.wait_for_saved_topic(JOINT_STATE_TOPIC, timeout=120.0) + + +def test_single_arm_plans_and_executes_through_control_coordinator( + lcm_spy: LcmSpy, + start_blueprint: Callable[..., DimosCliCall], +) -> None: + """Plan with one arm and execute through its trajectory task.""" + _start_openarm_mock_planner(start_blueprint, lcm_spy) + + client = RPCClient(None, ManipulationModule) + coordinator_client = RPCClient(None, ControlCoordinator) + try: + left_info = _wait_for_robot_info(client, "left_arm") + left_id = _planning_group_id(left_info) + + tasks = coordinator_client.list_tasks() + assert left_info["coordinator_task_name"] in tasks + + _prepare_for_planning(client, ("left_arm",)) + + planned = client.plan_to_joint_targets({left_id: _offset_target(client, "left_arm", 0.02)}) + assert planned, client.get_error() + assert client.has_planned_path() + assert client.execute_plan() + + _wait_for_trajectory_completion(client, "left_arm") + finally: + coordinator_client.stop_rpc_client() + client.stop_rpc_client() + + +def test_dual_arm_plans_and_dispatches_both_arms_through_control_coordinator( + lcm_spy: LcmSpy, + start_blueprint: Callable[..., DimosCliCall], +) -> None: + """Plan one generated plan over both arms and dispatch both JTC tasks.""" + _start_openarm_mock_planner(start_blueprint, lcm_spy) + + client = RPCClient(None, ManipulationModule) + coordinator_client = RPCClient(None, ControlCoordinator) + try: + left_info = _wait_for_robot_info(client, "left_arm") + right_info = _wait_for_robot_info(client, "right_arm") + left_id = _planning_group_id(left_info) + right_id = _planning_group_id(right_info) + + tasks = coordinator_client.list_tasks() + assert left_info["coordinator_task_name"] in tasks + assert right_info["coordinator_task_name"] in tasks + + _prepare_for_planning(client, ("left_arm", "right_arm")) + + planned = client.plan_to_joint_targets( + { + left_id: _offset_target(client, "left_arm", 0.02), + right_id: _offset_target(client, "right_arm", -0.02), + } + ) + assert planned, client.get_error() + assert client.has_planned_path() + assert client.execute_plan() + + _wait_for_trajectory_completion(client, "left_arm") + _wait_for_trajectory_completion(client, "right_arm") + finally: + coordinator_client.stop_rpc_client() + client.stop_rpc_client() diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index 0dd9edbd9e..25844db857 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -24,6 +24,7 @@ ) from dimos.robot.manipulators.xarm.blueprints.basic import ( dual_xarm6_planner as dual_xarm6_planner, + dual_xarm6_planner_coordinator as dual_xarm6_planner_coordinator, xarm6_planner_only as xarm6_planner_only, xarm7_planner_coordinator as xarm7_planner_coordinator, ) diff --git a/dimos/manipulation/control/coordinator_client.py b/dimos/manipulation/control/coordinator_client.py index ed552e0846..6fe78d13cc 100644 --- a/dimos/manipulation/control/coordinator_client.py +++ b/dimos/manipulation/control/coordinator_client.py @@ -24,12 +24,10 @@ Usage: # Terminal 1: Start the coordinator dimos run coordinator-mock # Single arm - dimos run coordinator-dual-mock # Dual arm # Terminal 2: Run this client python -m dimos.manipulation.control.coordinator_client - python -m dimos.manipulation.control.coordinator_client --task traj_left - python -m dimos.manipulation.control.coordinator_client --task traj_right + python -m dimos.manipulation.control.coordinator_client --task traj_arm How it works: 1. Connects to ControlCoordinator via LCM RPC diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index 0d4dfff226..0bcd4c0db9 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -15,7 +15,7 @@ """Manipulation Module - Motion planning with ControlCoordinator execution. Base module providing core manipulation infrastructure: -- @rpc: Low-level building blocks (plan_to_pose, plan_to_joints, preview_path, execute) +- @rpc: Low-level building blocks (plan_to_pose, plan_to_joints, preview_plan, execute) - @skill (short-horizon): Single-step actions (move_to_pose, open_gripper, go_home, go_init) Subclass PickAndPlaceModule (pick_and_place_module.py) adds perception integration @@ -24,12 +24,13 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from enum import Enum import threading import time +import traceback from typing import TYPE_CHECKING, Any, TypeAlias -import numpy as np from pydantic import Field from dimos.agents.annotation import skill @@ -39,6 +40,17 @@ from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In from dimos.manipulation.planning.factory import create_planning_specs, create_world +from dimos.manipulation.planning.groups.identifiers import ( + assert_global_joint_names, + assert_local_joint_names, + make_global_joint_names, +) +from dimos.manipulation.planning.groups.joints import ( + filter_joint_state_to_selected_joints, + joint_target_to_global_names, + planning_group_id_from_selector, +) +from dimos.manipulation.planning.groups.models import PlanningGroup from dimos.manipulation.planning.kinematics.config import ( ManipulationKinematicsConfig, PinkKinematicsConfig, @@ -47,9 +59,13 @@ from dimos.manipulation.planning.spec.config import RobotModelConfig from dimos.manipulation.planning.spec.enums import IKStatus, ObstacleType from dimos.manipulation.planning.spec.models import ( + CollisionCheckResult, + ForwardKinematicsResult, + GeneratedPlan, IKResult, - JointPath, Obstacle, + PlanningGroupID, + PlanningResult, RobotName, WorldRobotID, ) @@ -63,9 +79,11 @@ NoManipulationVisualizationConfig, ) from dimos.manipulation.visualization.factory import create_manipulation_visualization -from dimos.manipulation.visualization.types import TargetEvaluation +from dimos.manipulation.visualization.types import RobotInfo from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.sensor_msgs.JointState import JointState from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory @@ -83,12 +101,6 @@ RobotRegistry: TypeAlias = dict[RobotName, RobotEntry] """Maps robot_name -> RobotEntry""" -PlannedPaths: TypeAlias = dict[RobotName, JointPath] -"""Maps robot_name -> planned joint path""" - -PlannedTrajectories: TypeAlias = dict[RobotName, JointTrajectory] -"""Maps robot_name -> planned trajectory""" - class ManipulationState(Enum): """State machine for manipulation module.""" @@ -116,6 +128,7 @@ class ManipulationModuleConfig(ModuleConfig): # to prevent the planner from routing trajectories below this height. # Set to None to disable. floor_z: float | None = None + coordinator_rpc_timeout: float = 3.0 class ManipulationModule(Module): @@ -149,9 +162,8 @@ def __init__(self, **kwargs: Any) -> None: # Robot registry: maps robot_name -> (world_robot_id, config, trajectory_gen) self._robots: RobotRegistry = {} - # Stored path for plan/preview/execute workflow (per robot) - self._planned_paths: PlannedPaths = {} - self._planned_trajectories: PlannedTrajectories = {} + # Stored generated plan for preview/execute workflow. + self._last_plan: GeneratedPlan | None = None # Coordinator integration (lazy initialized) self._coordinator_client: RPCClient | None = None @@ -285,27 +297,31 @@ def _get_robot( def _on_joint_state(self, msg: JointState) -> None: """Callback when joint state received from driver. - Splits the aggregated JointState by robot using each robot's - coordinator joint names, then routes to the correct monitor. + Splits the aggregated global JointState by robot, then routes local + robot-scoped states to the correct monitor. """ try: if self._world_monitor is None: return + assert_global_joint_names(msg.name) + # Build name → index map once for the whole message name_to_idx = {name: i for i, name in enumerate(msg.name)} for robot_name, (robot_id, config, _) in self._robots.items(): - coord_names = config.get_coordinator_joint_names() - indices = [name_to_idx.get(cn) for cn in coord_names] + global_names = make_global_joint_names(robot_name, config.joint_names) + indices = [name_to_idx.get(global_name) for global_name in global_names] if any(idx is None for idx in indices): missing = [ - cn for cn, idx in zip(coord_names, indices, strict=False) if idx is None + name + for name, idx in zip(global_names, indices, strict=False) + if idx is None ] logger.warning(f"Skipping '{robot_name}': missing joints {missing}") continue - # Build per-robot sub-message (coordinator namespace) + # Build per-robot sub-message (local model namespace) sub_positions = [msg.position[idx] for idx in indices] # type: ignore[index] sub_velocities = ( [msg.velocity[idx] for idx in indices] # type: ignore[index] @@ -313,7 +329,7 @@ def _on_joint_state(self, msg: JointState) -> None: else [] ) sub_msg = JointState( - name=list(coord_names), + name=list(config.joint_names), position=sub_positions, velocity=sub_velocities, ) @@ -331,14 +347,10 @@ def _on_joint_state(self, msg: JointState) -> None: except Exception as e: logger.error(f"Exception in _on_joint_state: {e}") - import traceback - logger.error(traceback.format_exc()) def _tf_publish_loop(self) -> None: """Publish TF transforms at 10Hz for EE and extra links.""" - from dimos.msgs.geometry_msgs.Transform import Transform - period = 0.1 # 10Hz while not self._tf_stop_event.is_set(): try: @@ -346,10 +358,26 @@ def _tf_publish_loop(self) -> None: break transforms: list[Transform] = [] for robot_id, config, _ in self._robots.values(): - # Publish world → EE - ee_pose = self._world_monitor.get_ee_pose(robot_id) - if ee_pose is not None: - ee_tf = Transform.from_pose(config.end_effector_link, ee_pose) + # Publish world → primary planning-group target frame. + # Fall back to robot-scoped EE only for compatibility configs. + # TODO: Publish one TF per pose-targetable group, or expose the + # backend's full robot TF tree, once consumers stop assuming a + # single robot-scoped end-effector frame. + target_frame = config.end_effector_link + ee_pose: PoseStamped | None + pose_group_id = self._primary_pose_group_id_for_robot(config.name) + if pose_group_id is not None: + pose_group = self._world_monitor.planning_groups.get(pose_group_id) + target_frame = pose_group.tip_link + ee_pose = self._world_monitor.get_group_ee_pose(pose_group_id) + else: + ee_pose = ( + self._world_monitor.get_link_pose(robot_id, target_frame) + if target_frame is not None + else None + ) + if ee_pose is not None and target_frame is not None: + ee_tf = Transform.from_pose(target_frame, ee_pose) ee_tf.frame_id = "world" transforms.append(ee_tf) @@ -436,7 +464,13 @@ def get_ee_pose(self, robot_name: RobotName | None = None) -> Pose | None: robot_name: Robot to query (required if multiple robots configured) """ if (robot := self._get_robot(robot_name)) and self._world_monitor: - return self._world_monitor.get_ee_pose(robot[1], joint_state=None) + _, robot_id, config, _ = robot + pose_group_id = self._primary_pose_group_id_for_robot(config.name) + if pose_group_id is not None: + return self._world_monitor.get_group_ee_pose(pose_group_id, joint_state=None) + if config.end_effector_link is None: + return None + return self._world_monitor.get_link_pose(robot_id, config.end_effector_link) return None @rpc @@ -453,26 +487,18 @@ def is_collision_free(self, joints: list[float], robot_name: RobotName | None = return self._world_monitor.is_state_valid(robot_id, joint_state) return False - def _begin_planning( - self, robot_name: RobotName | None = None - ) -> tuple[RobotName, WorldRobotID] | None: - """Check state and begin planning. Returns (robot_name, robot_id) or None. - - Args: - robot_name: Robot to plan for (required if multiple robots configured) - """ + def _begin_planning(self) -> bool: + """Check state and begin planning for the selected planning groups.""" if self._world_monitor is None: logger.error("Planning not initialized") - return None - if (robot := self._get_robot(robot_name)) is None: - return None + return False with self._lock: if self._state not in (ManipulationState.IDLE, ManipulationState.COMPLETED): logger.warning(f"Cannot plan: state is {self._state.name}") - return None + return False self._planning_epoch += 1 self._state = ManipulationState.PLANNING - return robot[0], robot[1] + return True def _fail(self, msg: str) -> bool: """Set FAULT state with error message.""" @@ -481,224 +507,456 @@ def _fail(self, msg: str) -> bool: self._error_message = msg return False - def _dismiss_preview(self, robot_id: WorldRobotID) -> None: + def _default_group_id_for_robot(self, robot_name: RobotName) -> PlanningGroupID | None: + """Return the generated fallback group used by robot-scoped wrappers.""" + assert self._world_monitor is not None + group_id = self._world_monitor.planning_groups.default_group_id_for_robot(robot_name) + if group_id is not None: + return group_id + logger.error( + "Robot '%s' has no generated default planning group; use explicit group APIs", + robot_name, + ) + return None + + def _primary_pose_group_id_for_robot(self, robot_name: RobotName) -> PlanningGroupID | None: + """Return the first pose-targetable group for robot-scoped compatibility paths.""" + assert self._world_monitor is not None + return self._world_monitor.planning_groups.primary_pose_group_id_for_robot(robot_name) + + def _selected_joint_state(self, group_ids: tuple[PlanningGroupID, ...]) -> JointState | None: + """Collect current state for exactly the selected global joints.""" + assert self._world_monitor is not None + selection = self._world_monitor.planning_groups.select(group_ids) + current = self._world_monitor.current_global_joint_state() + if isinstance(current, JointState): + try: + return filter_joint_state_to_selected_joints(current, selection.joint_names) + except ValueError as exc: + logger.error("Current state missing selected joints: %s", exc) + return None + if current is None: + logger.error("No fresh planning-world joint state") + return None + logger.error("Invalid planning-world joint state") + return None + + def _joint_target_to_global_names( + self, group_id: PlanningGroupID, target: JointState + ) -> JointState | None: + """Convert a group joint target to global joint names in group order.""" + assert self._world_monitor is not None + group = self._world_monitor.planning_groups.get(group_id) + try: + return joint_target_to_global_names(group, target) + except ValueError as exc: + logger.error(str(exc)) + return None + + def _affected_robot_names(self, plan: GeneratedPlan) -> list[RobotName]: + """Get stable robot names affected by a generated plan.""" + assert self._world_monitor is not None + return list(self._world_monitor.planning_groups.select(plan.group_ids).robot_names) + + def _store_generated_plan( + self, group_ids: tuple[PlanningGroupID, ...], result: PlanningResult + ) -> None: + """Store the canonical generated plan.""" + self._last_plan = GeneratedPlan( + group_ids=group_ids, + path=result.path, + status=result.status, + planning_time=result.planning_time, + path_length=result.path_length, + iterations=result.iterations, + message=result.message, + ) + + def _plan_selected_path( + self, group_ids: tuple[PlanningGroupID, ...], start: JointState, goal: JointState + ) -> bool: + """Plan over an explicit planning group selection and store the result.""" + assert self._world_monitor and self._planner + result = self._planner.plan_selected_joint_path( + world=self._world_monitor.world, + selection=self._world_monitor.planning_groups.select(group_ids), + start=start, + goal=goal, + timeout=self.config.planning_timeout, + ) + if not result.is_success(): + return self._fail(f"Planning failed: {result.status.name}") + + path_joints = list(result.path[-1].name) if result.path else [] + logger.info( + "Path: %d waypoints, groups=%s, joints=%s", + len(result.path), + group_ids, + path_joints, + ) + self._store_generated_plan(group_ids, result) + self._state = ManipulationState.COMPLETED + return True + + def _dismiss_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: """Hide the preview ghost if the world supports it.""" if self._world_monitor is None: return - self._world_monitor.hide_preview(robot_id) + self._world_monitor.hide_preview(group_ids) self._world_monitor.publish_visualization() - def _solve_ik_for_pose( + @rpc + def check_collision( self, - robot_id: WorldRobotID, - pose: Pose, - seed: JointState, - check_collision: bool, - ) -> IKResult: - """Run the configured kinematics backend for a world-frame pose.""" - assert self._world_monitor and self._kinematics + target_joints: JointState, + max_age: float = 1.0, + ) -> CollisionCheckResult: + """Check a partial global joint target against the planning world.""" + if self._world_monitor is None: + return CollisionCheckResult( + status="UNAVAILABLE", + collision_free=None, + message="Planning is not initialized", + ) + return self._world_monitor.check_collision(target_joints, max_age=max_age) - # Convert Pose to PoseStamped for the IK solver - from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + def _planning_group_models(self) -> list[PlanningGroup]: + """Return all planning group models in stable registry order.""" + if self._world_monitor is None: + return [] + return list(self._world_monitor.planning_groups.list()) - target_pose = PoseStamped( - frame_id="world", - position=pose.position, - orientation=pose.orientation, + @rpc + def list_planning_groups(self) -> list[PlanningGroup]: + """Return all planning groups.""" + return self._planning_group_models() + + def get_current_joint_state(self, robot_name: RobotName) -> JointState | None: + """Return the named robot's current local joint state with names.""" + if self._world_monitor is None: + return None + robot_id = self.robot_id_for_name(robot_name) + if robot_id is None: + return None + return self._world_monitor.get_current_joint_state(robot_id) + + @rpc + def forward_kinematics( + self, + group_id: PlanningGroupID, + target_joints: JointState | None = None, + max_age: float = 1.0, + ) -> ForwardKinematicsResult: + """Compute the selected planning group's end-effector pose.""" + if self._world_monitor is None: + return ForwardKinematicsResult( + status="UNAVAILABLE", + pose=None, + message="Planning is not initialized", + ) + try: + group = self._world_monitor.planning_groups.get(group_id) + except KeyError as exc: + return ForwardKinematicsResult(status="INVALID", pose=None, message=str(exc)) + if not group.has_pose_target: + return ForwardKinematicsResult( + status="INVALID", + pose=None, + message=f"Planning group '{group_id}' has no pose target frame", + ) + robot = self._robots.get(group.robot_name) + if robot is None: + return ForwardKinematicsResult( + status="INVALID", + pose=None, + message=f"Robot '{group.robot_name}' is not registered", + ) + robot_id, config, _ = robot + + if target_joints is None: + monitor = self._world_monitor.get_state_monitor(robot_id) + if monitor is None or monitor.is_state_stale(max_age): + return ForwardKinematicsResult( + status="STALE_STATE", + pose=None, + message="Fresh monitored robot joint state is unavailable", + ) + joint_state = self._world_monitor.get_current_joint_state(robot_id) + if joint_state is None: + return ForwardKinematicsResult( + status="STALE_STATE", + pose=None, + message="Fresh monitored robot joint state is unavailable", + ) + else: + if len(target_joints.name) != len(target_joints.position): + return ForwardKinematicsResult( + status="INVALID", + pose=None, + message="FK target name and position lengths must match", + ) + if len(set(target_joints.name)) != len(target_joints.name): + return ForwardKinematicsResult( + status="INVALID", + pose=None, + message="FK target contains duplicate joint names", + ) + try: + assert_global_joint_names(target_joints.name) + except ValueError as exc: + return ForwardKinematicsResult(status="INVALID", pose=None, message=str(exc)) + positions_by_global_name = dict( + zip(target_joints.name, target_joints.position, strict=True) + ) + missing = [name for name in group.joint_names if name not in positions_by_global_name] + if missing: + return ForwardKinematicsResult( + status="INVALID", + pose=None, + message=f"FK target missing group joints: {missing}", + ) + current = self._world_monitor.get_current_joint_state(robot_id) + current_by_local_name = ( + dict(zip(current.name, current.position, strict=False)) + if current is not None + else {} + ) + positions: list[float] = [] + for local_name, global_name in zip( + config.joint_names, + make_global_joint_names(group.robot_name, config.joint_names), + strict=True, + ): + if global_name in positions_by_global_name: + positions.append(float(positions_by_global_name[global_name])) + else: + positions.append(float((current_by_local_name or {}).get(local_name, 0.0))) + joint_state = JointState(name=list(config.joint_names), position=positions) + + try: + pose = self._world_monitor.get_group_ee_pose(group_id, joint_state) + except Exception as exc: + return ForwardKinematicsResult( + status="UNAVAILABLE", + pose=None, + message=f"Forward kinematics failed: {exc}", + ) + return ForwardKinematicsResult( + status="VALID", pose=pose, message="Forward kinematics solved" ) - return self._kinematics.solve( + @rpc + def inverse_kinematics( + self, + pose_targets: Mapping[PlanningGroupID, PoseStamped], + auxiliary_group_ids: Sequence[PlanningGroupID] = (), + seed: JointState | None = None, + ) -> IKResult: + """Solve planning-group pose targets without collision filtering.""" + if self._kinematics is None or self._world_monitor is None: + return IKResult(status=IKStatus.NO_SOLUTION, message="Planning not initialized") + if not pose_targets: + return IKResult( + status=IKStatus.NO_SOLUTION, message="At least one pose target is required" + ) + group_ids = tuple(dict.fromkeys((*pose_targets.keys(), *auxiliary_group_ids))) + try: + target_groups = { + self._world_monitor.planning_groups.get(group_id): pose + for group_id, pose in pose_targets.items() + } + auxiliary_groups = tuple( + self._world_monitor.planning_groups.get(group_id) + for group_id in auxiliary_group_ids + ) + seed_state = seed or self._selected_joint_state(group_ids) + except (KeyError, ValueError) as exc: + return IKResult(status=IKStatus.NO_SOLUTION, message=str(exc)) + if seed_state is None: + return IKResult(status=IKStatus.NO_SOLUTION, message="No joint state") + return self._kinematics.solve_pose_targets( world=self._world_monitor.world, - robot_id=robot_id, - target_pose=target_pose, - seed=seed, - check_collision=check_collision, + pose_targets=target_groups, + auxiliary_groups=auxiliary_groups, + seed=seed_state, ) @rpc - def solve_ik( + def inverse_kinematics_single( self, pose: Pose, robot_name: RobotName | None = None, - check_collision: bool = True, seed: JointState | None = None, ) -> IKResult: - """Solve IK for a pose without planning a joint path. + """Solve IK for one robot's primary pose-targetable planning group. Args: pose: Target end-effector pose - robot_name: Robot to solve for (required if multiple robots configured) - check_collision: Whether to reject IK candidates in collision + robot_name: Robot to solve for (required if multiple robots configured). seed: Optional joint state to initialize local IK. Uses current state when omitted. """ - if self._kinematics is None or self._world_monitor is None: + if self._world_monitor is None: return IKResult(status=IKStatus.NO_SOLUTION, message="Planning not initialized") robot = self._get_robot(robot_name) if robot is None: return IKResult(status=IKStatus.NO_SOLUTION, message="Robot not found") + selected_robot_name, _, _, _ = robot + group_id = self._primary_pose_group_id_for_robot(selected_robot_name) + if group_id is None: + return IKResult( + status=IKStatus.NO_SOLUTION, message="No pose-targetable planning group" + ) + target_pose = PoseStamped( + frame_id="world", + position=pose.position, + orientation=pose.orientation, + ) + return self.inverse_kinematics({group_id: target_pose}, seed=seed) - with self._lock: - if self._state not in (ManipulationState.IDLE, ManipulationState.COMPLETED): - return IKResult( - status=IKStatus.NO_SOLUTION, - message=f"Cannot solve IK while state is {self._state.name}", - ) - self._state = ManipulationState.PLANNING - - _, robot_id, _, _ = robot - seed_state = seed or self._world_monitor.get_current_joint_state(robot_id) - if seed_state is None: - self._state = ManipulationState.IDLE - return IKResult(status=IKStatus.NO_SOLUTION, message="No joint state") - - result = self._solve_ik_for_pose(robot_id, pose, seed_state, check_collision) - self._state = ManipulationState.COMPLETED if result.is_success() else ManipulationState.IDLE - if result.is_success(): - logger.info(f"IK solved, error: {result.position_error:.4f}m") - return result + @rpc + def solve_ik( + self, + pose: Pose, + robot_name: RobotName | None = None, + seed: JointState | None = None, + ) -> IKResult: + """Compatibility wrapper for inverse_kinematics_single().""" + return self.inverse_kinematics_single(pose, robot_name=robot_name, seed=seed) @rpc def plan_to_pose(self, pose: Pose, robot_name: RobotName | None = None) -> bool: - """Plan motion to pose. Use preview_path() then execute(). + """Plan motion to pose. Use preview_plan() then execute(). Args: pose: Target end-effector pose robot_name: Robot to plan for (required if multiple robots configured) """ - if self._kinematics is None or (r := self._begin_planning(robot_name)) is None: + if self._kinematics is None or self._world_monitor is None: return False - robot_name, robot_id = r - planning_epoch = self._planning_epoch - assert self._world_monitor # guaranteed by _begin_planning + robot = self._get_robot(robot_name) + if robot is None: + return False + selected_robot_name, _, _, _ = robot + group_id = self._default_group_id_for_robot(selected_robot_name) + if group_id is None: + return False + return self.plan_to_pose_targets({group_id: pose}) - current = self._world_monitor.get_current_joint_state(robot_id) - if current is None: + @rpc + def plan_to_pose_targets( + self, + pose_targets: Mapping[PlanningGroupID | PlanningGroup, Pose], + auxiliary_groups: Sequence[PlanningGroupID | PlanningGroup] = (), + ) -> bool: + """Plan to one or more group pose targets with optional auxiliary groups.""" + if self._world_monitor is None or self._kinematics is None or self._planner is None: + return False + if not pose_targets: + logger.error("At least one pose target is required") + return False + + stamped_targets = { + planning_group_id_from_selector(group): PoseStamped( + frame_id="world", + position=pose.position, + orientation=pose.orientation, + ) + for group, pose in pose_targets.items() + } + auxiliary_ids = tuple(planning_group_id_from_selector(group) for group in auxiliary_groups) + group_ids = tuple(dict.fromkeys((*stamped_targets.keys(), *auxiliary_ids))) + if not self._begin_planning(): + return False + + try: + start = self._selected_joint_state(group_ids) + except Exception as exc: + return self._fail(f"Failed to resolve planning groups: {exc}") + if start is None: return self._fail("No joint state") - ik = self._solve_ik_for_pose(robot_id, pose, current, check_collision=True) + ik = self.inverse_kinematics( + pose_targets=stamped_targets, + auxiliary_group_ids=auxiliary_ids, + seed=start, + ) if not ik.is_success() or ik.joint_state is None: return self._fail(f"IK failed: {ik.status.name}") - - logger.info(f"IK solved, error: {ik.position_error:.4f}m") - return self._plan_path_only(robot_name, robot_id, ik.joint_state, planning_epoch) + return self._plan_selected_path(group_ids, start, ik.joint_state) @rpc def plan_to_joints(self, joints: JointState, robot_name: RobotName | None = None) -> bool: - """Plan motion to joint config. Use preview_path() then execute(). + """Plan motion to joint config. Use preview_plan() then execute(). Args: joints: Target joint state (names + positions) robot_name: Robot to plan for (required if multiple robots configured) """ - if (r := self._begin_planning(robot_name)) is None: + robot = self._get_robot(robot_name) + if robot is None: return False - robot_name, robot_id = r - planning_epoch = self._planning_epoch + robot_name, _, _, _ = robot logger.info(f"Planning to joints for {robot_name}: {[f'{j:.3f}' for j in joints.position]}") - return self._plan_path_only(robot_name, robot_id, joints, planning_epoch) + group_id = self._default_group_id_for_robot(robot_name) + if group_id is None: + return False + return self.plan_to_joint_targets({group_id: joints}) - def _plan_path_only( - self, - robot_name: RobotName, - robot_id: WorldRobotID, - goal: JointState, - planning_epoch: int, + @rpc + def plan_to_joint_targets( + self, joint_targets: Mapping[PlanningGroupID | PlanningGroup, JointState] ) -> bool: - """Plan path from current position to goal, store result.""" - assert self._world_monitor and self._planner # guaranteed by _begin_planning - self._dismiss_preview(robot_id) - start = self._world_monitor.get_current_joint_state(robot_id) - if start is None: - return self._fail("No joint state") - - # Trim goal to planner DOF (e.g. strip gripper joint from coordinator state) - planner_dof = len(start.position) - if len(goal.position) > planner_dof: - goal = JointState( - name=list(goal.name[:planner_dof]) if goal.name else [], - position=list(goal.position[:planner_dof]), - ) + """Plan to joint targets keyed by planning group.""" + if self._world_monitor is None or self._planner is None: + return False + if not joint_targets: + logger.error("At least one joint target is required") + return False - result = self._planner.plan_joint_path( - world=self._world_monitor.world, - robot_id=robot_id, - start=start, - goal=goal, - timeout=self.config.planning_timeout, + group_ids = tuple( + dict.fromkeys(planning_group_id_from_selector(group) for group in joint_targets) ) - if self._state != ManipulationState.PLANNING or planning_epoch != self._planning_epoch: - logger.info("Discarding cancelled planning result") + if not self._begin_planning(): return False - if not result.is_success(): - return self._fail(f"Planning failed: {result.status.name}") - - logger.info(f"Path: {len(result.path)} waypoints") - self._planned_paths[robot_name] = result.path + try: + start = self._selected_joint_state(group_ids) + except Exception as exc: + return self._fail(f"Failed to resolve planning groups: {exc}") + if start is None: + return self._fail("No joint state") - _, _, traj_gen = self._robots[robot_name] - # Convert JointState path to list of position lists for trajectory generator - traj = traj_gen.generate([list(state.position) for state in result.path]) - self._planned_trajectories[robot_name] = traj - logger.info(f"Trajectory: {traj.duration:.3f}s") + goal_names: list[str] = [] + goal_positions: list[float] = [] + for group, target in joint_targets.items(): + group_id = planning_group_id_from_selector(group) + target_global = self._joint_target_to_global_names(group_id, target) + if target_global is None: + return self._fail(f"Invalid joint target for '{group_id}'") + goal_names.extend(target_global.name) + goal_positions.extend(target_global.position) - self._state = ManipulationState.COMPLETED - return True + goal = JointState(name=goal_names, position=goal_positions) + return self._plan_selected_path(group_ids, start, goal) @rpc - def preview_path( + def preview_plan( self, + plan: GeneratedPlan | None = None, duration: float | None = None, robot_name: RobotName | None = None, - target_fps: float = 30.0, ) -> bool: - """Preview the planned path in the visualizer. - - Args: - duration: Total animation duration in seconds. Uses trajectory duration if None. - robot_name: Robot to preview (required if multiple robots configured) - target_fps: Nominal preview update rate. Set <= 0 to use planned waypoints directly. - """ + """Preview a generated plan, defaulting to `_last_plan` when omitted.""" if self._world_monitor is None: return False - - robot = self._get_robot(robot_name) - if robot is None: + plan = plan or self._last_plan + if plan is None or not plan.path: + logger.warning("No generated plan to preview") return False - robot_name, robot_id, _, _ = robot - - planned_path = self._planned_paths.get(robot_name) - if planned_path is None or len(planned_path) == 0: - logger.warning(f"No planned path to preview for {robot_name}") + if robot_name is not None and robot_name not in self._affected_robot_names(plan): + logger.error("Generated plan does not affect robot '%s'", robot_name) return False - - if duration is None: - trajectory = self._planned_trajectories.get(robot_name) - animation_duration = trajectory.duration if trajectory is not None else 3.0 - else: - trajectory = self._planned_trajectories.get(robot_name) - animation_duration = duration - - interpolated = list(planned_path) - if trajectory is not None and target_fps > 0 and animation_duration > 0: - times = np.array( - [point.time_from_start for point in trajectory.points], dtype=np.float64 - ) - positions = np.array([point.positions for point in trajectory.points], dtype=np.float64) - if len(times) > 1 and positions.ndim == 2 and times[-1] > times[0]: - frame_count = int(np.ceil(animation_duration * target_fps)) + 1 - sample_times = np.linspace(times[0], times[-1], frame_count) - joint_names = trajectory.joint_names or planned_path[0].name - sampled_positions = np.column_stack( - [ - np.interp(sample_times, times, positions[:, joint]) - for joint in range(positions.shape[1]) - ] - ) - interpolated = [ - JointState(name=joint_names, position=position.tolist()) - for position in sampled_positions - ] - self._world_monitor.animate_path(robot_id, interpolated, animation_duration) + animation_duration = duration if duration is not None else 1.0 + self._world_monitor.animate_plan(plan, animation_duration) return True @rpc @@ -708,13 +966,7 @@ def has_planned_path(self) -> bool: Returns: True if a path is planned and ready """ - robot = self._get_robot() - if robot is None: - return False - robot_name, _, _, _ = robot - - path = self._planned_paths.get(robot_name) - return path is not None and len(path) > 0 + return self._last_plan is not None and bool(self._last_plan.path) @rpc def get_visualization_url(self) -> str | None: @@ -734,15 +986,7 @@ def clear_planned_path(self) -> bool: Returns: True if cleared """ - if self._world_monitor is None: - return False - robot = self._get_robot() - if robot is None: - return False - robot_name, _, _, _ = robot - - self._planned_paths.pop(robot_name, None) - self._planned_trajectories.pop(robot_name, None) + self._last_plan = None return True @rpc @@ -755,7 +999,7 @@ def list_robots(self) -> list[str]: return list(self._robots.keys()) @rpc - def get_robot_info(self, robot_name: RobotName | None = None) -> dict[str, Any] | None: + def get_robot_info(self, robot_name: RobotName | None = None) -> RobotInfo | None: """Get information about a robot. Args: @@ -769,11 +1013,17 @@ def get_robot_info(self, robot_name: RobotName | None = None) -> dict[str, Any] return None robot_name, robot_id, config, _ = robot + planning_groups = ( + [group for group in self._world_monitor.planning_groups.groups_for_robot(robot_name)] + if self._world_monitor is not None + else [] + ) - return { + info: RobotInfo = { "name": config.name, "world_robot_id": robot_id, "joint_names": config.joint_names, + "planning_groups": planning_groups, "end_effector_link": config.end_effector_link, "base_link": config.base_link, "max_velocity": config.max_velocity, @@ -786,6 +1036,7 @@ def get_robot_info(self, robot_name: RobotName | None = None) -> dict[str, Any] if (init := self._init_joints.get(robot_name)) else None, } + return info def robot_items(self) -> list[tuple[RobotName, WorldRobotID, RobotModelConfig]]: """Return configured robots for in-process visualization adapters.""" @@ -820,97 +1071,6 @@ def get_init_joints(self, robot_name: RobotName | None = None) -> JointState | N return None return self._init_joints.get(robot[0]) - def evaluate_joint_target( - self, joints: JointState | None, robot_name: RobotName - ) -> TargetEvaluation: - """Evaluate a joint target for visualization without planning a path.""" - robot_id = self.robot_id_for_name(robot_name) - if robot_id is None or self._world_monitor is None: - return { - "success": False, - "status": "NO_ROBOT", - "message": f"Unknown robot: {robot_name}", - "collision_free": False, - "ee_pose": None, - "joint_state": None, - } - if joints is None: - return { - "success": False, - "status": "NO_TARGET", - "message": "No joint target provided", - "collision_free": False, - "ee_pose": None, - "joint_state": None, - } - target = JointState(joints) - collision_free = self._world_monitor.is_state_valid(robot_id, target) - return { - "success": True, - "status": "FEASIBLE" if collision_free else "COLLISION", - "message": "Target is collision-free" if collision_free else "Target is in collision", - "collision_free": collision_free, - "ee_pose": self._world_monitor.get_ee_pose(robot_id, target), - "joint_state": target, - } - - def evaluate_pose_target(self, pose: Pose, robot_name: RobotName) -> TargetEvaluation: - """Evaluate a Cartesian target for visualization without planning a path.""" - robot_id = self.robot_id_for_name(robot_name) - if robot_id is None: - return { - "success": False, - "joint_state": None, - "status": "UNKNOWN_ROBOT", - "message": f"Unknown robot: {robot_name}", - "collision_free": False, - } - if self._world_monitor is None or self._kinematics is None: - return { - "success": False, - "joint_state": None, - "status": "UNAVAILABLE", - "message": "Planning is not initialized or current state is unavailable", - "collision_free": False, - } - current = self._world_monitor.get_current_joint_state(robot_id) - if current is None: - return { - "success": False, - "joint_state": None, - "status": "UNAVAILABLE", - "message": "Planning is not initialized or current state is unavailable", - "collision_free": False, - } - ik = self._solve_ik_for_pose(robot_id, pose, current, check_collision=True) - joint_state = JointState(ik.joint_state) if ik.is_success() and ik.joint_state else None - collision_free = bool( - joint_state is not None and self._world_monitor.is_state_valid(robot_id, joint_state) - ) - return { - "success": joint_state is not None and collision_free, - "joint_state": joint_state, - "status": ik.status.name, - "message": ik.message, - "position_error": ik.position_error, - "orientation_error": ik.orientation_error, - "collision_free": collision_free, - } - - def get_planned_path(self, robot_name: RobotName) -> JointPath | None: - """Return a copy of the stored planned path for visualization.""" - path = self._planned_paths.get(robot_name) - if path is None: - return None - return [JointState(point) for point in path] - - def get_planned_trajectory_duration(self, robot_name: RobotName) -> float | None: - """Return the stored planned trajectory duration for visualization.""" - trajectory = self._planned_trajectories.get(robot_name) - if trajectory is None: - return None - return float(trajectory.duration) - @rpc def set_init_joints(self, joint_state: JointState, robot_name: RobotName | None = None) -> bool: """Set the init joint state. @@ -922,13 +1082,46 @@ def set_init_joints(self, joint_state: JointState, robot_name: RobotName | None robot = self._get_robot(robot_name) if robot is None: return False - self._init_joints[robot[0]] = joint_state + robot_name_resolved, _, config, _ = robot + try: + normalized = self._local_robot_joint_state(config, joint_state) + except ValueError as exc: + logger.error(str(exc)) + return False + self._init_joints[robot_name_resolved] = normalized logger.info( - f"Init joints set for '{robot[0]}': " - f"[{', '.join(f'{j:.3f}' for j in joint_state.position)}]" + f"Init joints set for '{robot_name_resolved}': " + f"[{', '.join(f'{j:.3f}' for j in normalized.position)}]" ) return True + def _local_robot_joint_state( + self, config: RobotModelConfig, joint_state: JointState + ) -> JointState: + """Normalize a robot-scoped joint state to local model joint order.""" + if not joint_state.name: + if len(joint_state.position) != len(config.joint_names): + raise ValueError( + f"JointState has {len(joint_state.position)} positions, " + f"expected {len(config.joint_names)} for robot '{config.name}'" + ) + return JointState(name=list(config.joint_names), position=list(joint_state.position)) + + assert_local_joint_names(joint_state.name) + positions_by_name = dict(zip(joint_state.name, joint_state.position, strict=False)) + missing = [name for name in config.joint_names if name not in positions_by_name] + if missing: + raise ValueError(f"JointState for robot '{config.name}' is missing joints: {missing}") + extra = set(joint_state.name) - set(config.joint_names) + if extra: + raise ValueError( + f"JointState for robot '{config.name}' has extra joints: {sorted(extra)}" + ) + return JointState( + name=list(config.joint_names), + position=[positions_by_name[name] for name in config.joint_names], + ) + @rpc def set_init_joints_to_current(self, robot_name: RobotName | None = None) -> bool: """Set init joints to the current joint positions. @@ -966,72 +1159,163 @@ def _get_coordinator_client(self) -> RPCClient | None: self._coordinator_client = RPCClient(None, ControlCoordinator) return self._coordinator_client - def _translate_trajectory_to_coordinator( + def _invoke_coordinator_task( self, - trajectory: JointTrajectory, - robot_config: RobotModelConfig, - ) -> JointTrajectory: - """Translate trajectory joint names from URDF to coordinator namespace. - - Args: - trajectory: Trajectory with URDF joint names - robot_config: Robot config with joint name mapping - - Returns: - Trajectory with coordinator joint names - """ - if not robot_config.joint_name_mapping: - return trajectory # No translation needed - - # Translate joint names - coordinator_names = [ - robot_config.get_coordinator_joint_name(j) for j in trajectory.joint_names - ] - - # Create new trajectory with translated names - # Note: duration is computed automatically from points in JointTrajectory.__init__ - return JointTrajectory( - joint_names=coordinator_names, - points=trajectory.points, - timestamp=trajectory.timestamp, - ) + client: RPCClient, + task_name: str, + method: str, + kwargs: dict[str, Any], + ) -> Any: + """Invoke a ControlCoordinator task with an execution-specific timeout.""" + remote_name = getattr(client, "remote_name", None) + rpc_client = getattr(client, "rpc", None) + call_sync = getattr(rpc_client, "call_sync", None) + if isinstance(remote_name, str) and callable(call_sync): + result, unsub_fn = call_sync( + f"{remote_name}/task_invoke", + ([task_name, method, kwargs], {}), + rpc_timeout=self.config.coordinator_rpc_timeout, + ) + unsub_fns = getattr(client, "_unsub_fns", None) + if isinstance(unsub_fns, list): + unsub_fns.append(unsub_fn) + return result + return client.task_invoke(task_name, method, kwargs) @rpc - def execute(self, robot_name: RobotName | None = None) -> bool: + def execute(self) -> bool: """Execute planned trajectory via ControlCoordinator.""" - if (robot := self._get_robot(robot_name)) is None: - return False - robot_name, _, config, _ = robot + return self.execute_plan(self._last_plan) - if (traj := self._planned_trajectories.get(robot_name)) is None: - logger.warning("No planned trajectory") - return False - if not config.coordinator_task_name: - logger.error(f"No coordinator_task_name for '{robot_name}'") + @rpc + def execute_plan(self, plan: GeneratedPlan | None = None) -> bool: + """Project and execute a generated plan through affected trajectory tasks. + + TODO: proper time parametrization. + """ + plan = plan or self._last_plan + if plan is None or not plan.path: + logger.warning("No generated plan") return False if (client := self._get_coordinator_client()) is None: logger.error("No coordinator client") return False - translated = self._translate_trajectory_to_coordinator(traj, config) + try: + affected = self._affected_robot_names(plan) + except Exception as exc: + return self._fail(f"Failed to resolve generated plan: {exc}") logger.info( - f"Executing: task='{config.coordinator_task_name}', {len(translated.points)} pts, {translated.duration:.2f}s" + "Execute plan: groups=%s, affected=%s", + plan.group_ids, + affected, ) + assert self._world_monitor is not None + + dispatches: list[tuple[RobotName, str, RobotModelConfig, JointTrajectory]] = [] + for name in affected: + robot = self._get_robot(name) + if robot is None: + return False + resolved_name, robot_id, config, traj_gen = robot + task_name = config.coordinator_task_name + if not task_name: + logger.error(f"No coordinator_task_name for '{resolved_name}'") + return False + + current = self._world_monitor.get_current_joint_state(robot_id) + current_by_name = ( + dict(zip(current.name, current.position, strict=False)) + if current is not None + else {} + ) + + global_joint_names = make_global_joint_names(resolved_name, config.joint_names) + local_path: list[JointState] = [] + for waypoint in plan.path: + if len(waypoint.name) != len(waypoint.position): + logger.error( + "Cannot execute plan for '%s': waypoint has %d names but %d positions", + resolved_name, + len(waypoint.name), + len(waypoint.position), + ) + return False + try: + assert_global_joint_names(waypoint.name) + except ValueError as exc: + logger.error("Cannot execute plan for '%s': %s", resolved_name, exc) + return False + selected_positions = dict(zip(waypoint.name, waypoint.position, strict=True)) + positions: list[float] = [] + for local_name, global_name in zip( + config.joint_names, global_joint_names, strict=True + ): + if global_name in selected_positions: + positions.append(selected_positions[global_name]) + elif local_name in current_by_name: + positions.append(current_by_name[local_name]) + else: + logger.error( + "Cannot execute plan for '%s': missing joint '%s'", + resolved_name, + global_name, + ) + return False + local_path.append(JointState(name=list(config.joint_names), position=positions)) + if len(local_path) < 2: + logger.error("Plan projection for '%s' has fewer than two waypoints", resolved_name) + return False + local_trajectory = traj_gen.generate([list(state.position) for state in local_path]) + trajectory = JointTrajectory( + joint_names=list(global_joint_names), + points=local_trajectory.points, + timestamp=local_trajectory.timestamp, + ) + dispatches.append((resolved_name, task_name, config, trajectory)) self._state = ManipulationState.EXECUTING - result = client.task_invoke( - config.coordinator_task_name, "execute", {"trajectory": translated} - ) - if result: - logger.info("Trajectory accepted") - self._state = ManipulationState.COMPLETED - return True - else: - return self._fail("Coordinator rejected trajectory") + for _name, task_name, config, trajectory in dispatches: + logger.info( + "Executing: task='%s', %d pts, %.2fs", + task_name, + len(trajectory.points), + trajectory.duration, + ) + try: + result = self._invoke_coordinator_task( + client, + task_name, + "execute", + {"trajectory": trajectory}, + ) + except TimeoutError as exc: + return self._fail(f"Coordinator RPC timed out for task '{task_name}': {exc}") + except Exception as exc: + return self._fail(f"Coordinator RPC failed for task '{task_name}': {exc}") + logger.info( + "Coordinator execute result: task='%s', result=%r", + config.coordinator_task_name, + result, + ) + if not result: + return self._fail("Coordinator rejected trajectory") + + logger.info("Trajectory accepted") + self._state = ManipulationState.COMPLETED + return True @rpc def get_trajectory_status(self, robot_name: RobotName | None = None) -> dict[str, Any] | None: """Get trajectory execution status via coordinator task_invoke.""" + last_plan = self._last_plan + if robot_name is None and last_plan is not None and last_plan.path: + statuses = { + name: self.get_trajectory_status(name) + for name in self._affected_robot_names(last_plan) + } + return {"robots": statuses} + if (robot := self._get_robot(robot_name)) is None: return None _, _, config, _ = robot @@ -1080,9 +1364,6 @@ def add_obstacle( logger.warning("mesh_path required for mesh obstacles") return "" - # Import PoseStamped here to avoid circular imports - from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - obstacle = Obstacle( name=name, obstacle_type=obstacle_type, @@ -1189,6 +1470,18 @@ def _wait_for_trajectory_completion( Returns: True if trajectory completed successfully """ + last_plan = self._last_plan + if robot_name is None and last_plan is not None and last_plan.path: + try: + robot_names = self._affected_robot_names(last_plan) + except Exception as exc: + logger.warning("Failed to resolve generated plan while waiting: %s", exc) + return False + return all( + self._wait_for_trajectory_completion(name, timeout, poll_interval) + for name in robot_names + ) + robot = self._get_robot(robot_name) if robot is None: return True @@ -1196,11 +1489,7 @@ def _wait_for_trajectory_completion( client = self._get_coordinator_client() if client is None or not config.coordinator_task_name: - # No coordinator — wait for trajectory duration as fallback - traj = self._planned_trajectories.get(rname) - if traj is not None: - logger.info(f"No coordinator status — waiting {traj.duration:.1f}s for trajectory") - time.sleep(traj.duration + 0.5) + logger.info("No coordinator status available for '%s'", rname) return True # Poll task state via task_invoke @@ -1221,13 +1510,7 @@ def _wait_for_trajectory_completion( # task_invoke returned None — task not found, assume done return True except Exception: - # Fallback: wait for trajectory duration - traj = self._planned_trajectories.get(rname) - if traj is not None: - remaining = traj.duration - (time.time() - start) - if remaining > 0: - logger.info(f"Status poll failed — waiting {remaining:.1f}s for trajectory") - time.sleep(remaining + 0.5) + logger.info("Status poll failed for '%s'", rname) return True time.sleep(poll_interval) @@ -1262,10 +1545,10 @@ def _preview_execute_wait( preview_duration: Duration to animate the preview in Meshcat (seconds) """ logger.info("Previewing trajectory...") - self.preview_path(preview_duration, robot_name) + self.preview_plan(duration=preview_duration, robot_name=robot_name) logger.info("Executing trajectory...") - if not self.execute(robot_name): + if not self.execute(): return SkillResult.fail("EXECUTION_FAILED", "Trajectory execution failed") if not self._wait_for_trajectory_completion(robot_name): diff --git a/dimos/manipulation/planning/README.md b/dimos/manipulation/planning/README.md index 6e5a53eb57..96d8176268 100644 --- a/dimos/manipulation/planning/README.md +++ b/dimos/manipulation/planning/README.md @@ -70,7 +70,6 @@ config = RobotModelConfig( joint_names=["joint1", "joint2", "joint3", "joint4", "joint5", "joint6", "joint7"], end_effector_link="link7", base_link="link_base", - joint_name_mapping={"arm_joint1": "joint1", ...}, # coordinator <-> URDF coordinator_task_name="traj_arm", ) @@ -93,12 +92,11 @@ module.execute() # Sends to coordinator | `name` | Robot identifier | | `model_path` | Path to URDF/XACRO file | | `base_pose` | PoseStamped for robot base in world frame | -| `joint_names` | Joint names in URDF | +| `joint_names` | Ordered controllable local model joint names | | `end_effector_link` | EE link name | | `base_link` | Base link name | | `max_velocity` | Max joint velocity (rad/s) | | `max_acceleration` | Max acceleration (rad/s²) | -| `joint_name_mapping` | Coordinator → URDF name mapping | | `coordinator_task_name` | Task name for execution RPC | | `package_paths` | ROS package paths for meshes | | `xacro_args` | Xacro arguments (e.g., `{"dof": "7"}`) | @@ -141,7 +139,7 @@ accepted. |-----------|-------------| | `xarm6_planner_only` | XArm 6-DOF standalone (no coordinator) | | `xarm7-planner-coordinator` | XArm 7-DOF with coordinator | -| `dual-xarm6-planner` | Dual XArm 6-DOF | +| `dual-xarm6-planner-coordinator` | Dual XArm 6-DOF with coordinator | | `xarm-perception-sim` | XArm 7-DOF simulation perception stack | ## Directory Structure diff --git a/dimos/manipulation/planning/examples/manipulation_client.py b/dimos/manipulation/planning/examples/manipulation_client.py index 1185f28f21..f13515f582 100644 --- a/dimos/manipulation/planning/examples/manipulation_client.py +++ b/dimos/manipulation/planning/examples/manipulation_client.py @@ -124,7 +124,6 @@ def ik_pose( pitch: float | None = None, yaw: float | None = None, robot_name: str | None = None, - check_collision: bool = True, seed_joints: list[float] | JointState | None = None, ) -> IKResult: """Solve IK for a Cartesian pose without path planning. @@ -137,13 +136,12 @@ def ik_pose( pitch: Optional target pitch. Preserves current orientation if omitted. yaw: Optional target yaw. Preserves current orientation if omitted. robot_name: Robot to solve for when multiple robots are configured. - check_collision: Whether to reject IK candidates in collision. seed_joints: Optional initial joint configuration for local IK. Pass either a list of joint positions in robot joint order or a named JointState. """ target = _make_target_pose(x, y, z, roll, pitch, yaw, robot_name) seed = _make_seed_joint_state(seed_joints, robot_name) - return _client.solve_ik(target, robot_name, check_collision, seed) + return _client.inverse_kinematics_single(target, robot_name, seed) def plan_pose( @@ -163,15 +161,14 @@ def plan_pose( def preview( duration: float | None = None, robot_name: str | None = None, - target_fps: float = 30.0, ) -> bool: - """Preview planned path in Meshcat.""" - return _client.preview_path(duration, robot_name, target_fps) + """Preview the last generated plan in Visualizer.""" + return _client.preview_plan(None, duration, robot_name) -def execute(robot_name: str | None = None) -> bool: +def execute() -> bool: """Execute planned trajectory via coordinator.""" - return _client.execute(robot_name) + return _client.execute() def home(robot_name: str | None = None) -> bool: @@ -181,7 +178,7 @@ def home(robot_name: str | None = None) -> bool: home_joints = _client.get_robot_info(robot_name).get("home_joints", [0.0] * 7) success = _client.plan_to_joints(JointState(position=home_joints), robot_name) if success: - return _client.execute(robot_name) + return _client.execute() return False diff --git a/dimos/manipulation/planning/groups/discovery.py b/dimos/manipulation/planning/groups/discovery.py new file mode 100644 index 0000000000..5e9ddfa25d --- /dev/null +++ b/dimos/manipulation/planning/groups/discovery.py @@ -0,0 +1,358 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Planning group discovery from SRDF or conservative model fallback.""" + +from __future__ import annotations + +import itertools +from pathlib import Path +import warnings +import xml.etree.ElementTree as ET + +from dimos.manipulation.planning.groups.models import PlanningGroupDefinition +from dimos.robot.model_parser import JointDescription, ModelDescription +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +FALLBACK_PLANNING_GROUP_NAME = "manipulator" + + +class PlanningGroupDiscoveryError(ValueError): + """Raised when planning groups cannot be discovered for a model.""" + + +def _warn(message: str) -> None: + logger.warning(message) + warnings.warn(message, UserWarning, stacklevel=2) + + +def discover_planning_group_definitions( + *, + robot_name: str, + model_path: Path, + model: ModelDescription, + controllable_joint_names: list[str], + srdf_path: Path | None = None, +) -> list[PlanningGroupDefinition]: + """Discover planning groups from SRDF or fallback generation. + + Precedence is explicit SRDF path, conservative auto-discovery with warning, + then fallback generation from the controllable joint set. + """ + resolved_srdf_path = _resolve_srdf_path(model_path, srdf_path) + if resolved_srdf_path is not None: + groups = parse_srdf_planning_groups( + resolved_srdf_path, + model=model, + controllable_joint_names=controllable_joint_names, + ) + if groups: + return groups + _warn( + f"No supported planning groups found in SRDF {resolved_srdf_path} " + f"for robot {robot_name}; trying fallback generation" + ) + + return [ + generate_fallback_planning_group( + model=model, + controllable_joint_names=controllable_joint_names, + ) + ] + + +def parse_srdf_planning_groups( + srdf_path: Path, + *, + model: ModelDescription, + controllable_joint_names: list[str], +) -> list[PlanningGroupDefinition]: + """Extract supported SRDF planning group definitions. + + Supported forms are a single ```` + child or an ordered list of ```` children. Other forms, + including SRDF ```` metadata, are ignored for planning group + extraction. This is intentionally a minimal SRDF group extractor rather + than a full SRDF parser; adopting a ROS/MoveIt parser such as srdfdom would + add substantial dependency overhead for this narrow subset. + """ + root = ET.parse(srdf_path).getroot() + groups: list[PlanningGroupDefinition] = [] + for group_elem in root.findall("group"): + group_name = group_elem.get("name") + if not group_name: + _warn(f"Skipping SRDF group without a name in {srdf_path}") + continue + + children = [child for child in list(group_elem) if isinstance(child.tag, str)] + chain_children = [child for child in children if child.tag == "chain"] + joint_children = [child for child in children if child.tag == "joint"] + unsupported_children = [child for child in children if child.tag not in {"chain", "joint"}] + + if len(chain_children) == 1 and not joint_children and not unsupported_children: + definition = _parse_chain_group( + group_name, + chain_children[0], + model=model, + controllable_joint_names=controllable_joint_names, + srdf_path=srdf_path, + ) + elif joint_children and len(joint_children) == len(children): + definition = _parse_joint_list_group( + group_name, + joint_children, + model=model, + controllable_joint_names=controllable_joint_names, + srdf_path=srdf_path, + ) + else: + child_tags = [child.tag for child in children] + _warn( + f"Skipping unsupported SRDF planning group {group_name} in " + f"{srdf_path} with child tags {child_tags}" + ) + definition = None + + if definition is not None: + groups.append(definition) + + return groups + + +def generate_fallback_planning_group( + *, + model: ModelDescription, + controllable_joint_names: list[str], +) -> PlanningGroupDefinition: + """Generate one conservative fallback planning group named ``manipulator``.""" + ordered_joints = _validate_and_order_serial_joints(model, controllable_joint_names) + while ordered_joints and ordered_joints[-1].type == "prismatic": + removed = ordered_joints.pop() + _warn( + f"Excluding terminal prismatic joint {removed.name} from " + f"fallback planning group {FALLBACK_PLANNING_GROUP_NAME}" + ) + + if not ordered_joints: + raise PlanningGroupDiscoveryError( + "Fallback planning group generation removed all candidate joints; provide SRDF" + ) + + return PlanningGroupDefinition( + name=FALLBACK_PLANNING_GROUP_NAME, + joint_names=tuple(joint.name for joint in ordered_joints), + base_link=ordered_joints[0].parent_link, + tip_link=ordered_joints[-1].child_link, + source="fallback", + ) + + +def _resolve_srdf_path(model_path: Path, srdf_path: Path | None) -> Path | None: + if srdf_path is not None: + if srdf_path.exists(): + return srdf_path + raise FileNotFoundError(f"SRDF file not found: {srdf_path}") + + for candidate in _srdf_auto_discovery_candidates(model_path): + if candidate.exists(): + _warn(f"Auto-discovered SRDF at {candidate}") + return candidate + return None + + +def _srdf_auto_discovery_candidates(model_path: Path) -> list[Path]: + candidates: list[Path] = [] + name = model_path.name + if name.endswith(".urdf.xacro"): + candidates.append(model_path.with_name(name.removesuffix(".urdf.xacro") + ".srdf")) + elif model_path.suffix: + candidates.append(model_path.with_suffix(".srdf")) + candidates.append(model_path.parent / "config" / "robot.srdf") + candidates.append(model_path.parent.parent / "config" / "robot.srdf") + return list(dict.fromkeys(candidates)) + + +def _parse_chain_group( + group_name: str, + chain_elem: ET.Element, + *, + model: ModelDescription, + controllable_joint_names: list[str], + srdf_path: Path, +) -> PlanningGroupDefinition | None: + base_link = chain_elem.get("base_link") + tip_link = chain_elem.get("tip_link") + if not base_link or not tip_link: + _warn( + f"Skipping SRDF chain group {group_name} in {srdf_path} because " + "base_link or tip_link is missing" + ) + return None + + try: + ordered_joints = _ordered_joints_between_links(model, base_link, tip_link) + controlled_joints = [joint for joint in ordered_joints if joint.type != "fixed"] + _validate_controllable(group_name, controlled_joints, controllable_joint_names) + except PlanningGroupDiscoveryError as exc: + _warn(f"Skipping SRDF chain group {group_name} in {srdf_path}: {exc}") + return None + + return PlanningGroupDefinition( + name=group_name, + joint_names=tuple(joint.name for joint in controlled_joints), + base_link=base_link, + tip_link=tip_link, + source="srdf", + ) + + +def _parse_joint_list_group( + group_name: str, + joint_children: list[ET.Element], + *, + model: ModelDescription, + controllable_joint_names: list[str], + srdf_path: Path, +) -> PlanningGroupDefinition | None: + joint_names = [child.get("name", "") for child in joint_children] + if any(not name for name in joint_names): + _warn(f"Skipping SRDF joint-list group {group_name} in {srdf_path} with empty joint name") + return None + try: + ordered_joints = _validate_ordered_serial_joints(model, joint_names) + _validate_controllable(group_name, ordered_joints, controllable_joint_names) + except PlanningGroupDiscoveryError as exc: + _warn(f"Skipping SRDF joint-list group {group_name} in {srdf_path}: {exc}") + return None + + return PlanningGroupDefinition( + name=group_name, + joint_names=tuple(joint.name for joint in ordered_joints), + base_link=ordered_joints[0].parent_link, + tip_link=ordered_joints[-1].child_link, + source="srdf", + ) + + +def _ordered_joints_between_links( + model: ModelDescription, + base_link: str, + tip_link: str, +) -> list[JointDescription]: + joints_by_parent: dict[str, list[JointDescription]] = {} + for joint in model.joints: + joints_by_parent.setdefault(joint.parent_link, []).append(joint) + + ordered_joints: list[JointDescription] = [] + current_link = base_link + visited_links = {base_link} + while current_link != tip_link: + children = joints_by_parent.get(current_link, []) + if len(children) != 1: + raise PlanningGroupDiscoveryError( + f"chain from {base_link} to {tip_link} is branching or disconnected at {current_link}" + ) + joint = children[0] + ordered_joints.append(joint) + current_link = joint.child_link + if current_link in visited_links: + raise PlanningGroupDiscoveryError("chain contains a cycle") + visited_links.add(current_link) + + return ordered_joints + + +def _validate_ordered_serial_joints( + model: ModelDescription, + joint_names: list[str], +) -> list[JointDescription]: + ordered_joints: list[JointDescription] = [] + for joint_name in joint_names: + joint = model.get_joint(joint_name) + if joint is None: + raise PlanningGroupDiscoveryError(f"joint {joint_name} does not exist in model") + if joint.type == "fixed": + raise PlanningGroupDiscoveryError(f"joint {joint_name} is fixed") + ordered_joints.append(joint) + + if not ordered_joints: + raise PlanningGroupDiscoveryError("planning group contains no joints") + + for previous, current in itertools.pairwise(ordered_joints): + if previous.child_link != current.parent_link: + raise PlanningGroupDiscoveryError( + f"joints {previous.name} and {current.name} are not adjacent in a serial chain" + ) + return ordered_joints + + +def _validate_and_order_serial_joints( + model: ModelDescription, + joint_names: list[str], +) -> list[JointDescription]: + if not joint_names: + raise PlanningGroupDiscoveryError("fallback requires at least one controllable joint") + + joints: list[JointDescription] = [] + for joint_name in joint_names: + joint = model.get_joint(joint_name) + if joint is None: + raise PlanningGroupDiscoveryError(f"joint {joint_name} does not exist in model") + if joint.type == "fixed": + raise PlanningGroupDiscoveryError(f"joint {joint_name} is fixed") + joints.append(joint) + + by_parent = {joint.parent_link: joint for joint in joints} + by_child = {joint.child_link: joint for joint in joints} + if len(by_parent) != len(joints) or len(by_child) != len(joints): + raise PlanningGroupDiscoveryError("controllable joints branch or merge; provide SRDF") + + starts = [joint for joint in joints if joint.parent_link not in by_child] + ends = [joint for joint in joints if joint.child_link not in by_parent] + if len(starts) != 1 or len(ends) != 1: + raise PlanningGroupDiscoveryError( + "controllable joints are disconnected or cyclic; provide SRDF" + ) + + ordered_joints: list[JointDescription] = [] + current = starts[0] + while True: + ordered_joints.append(current) + next_joint = by_parent.get(current.child_link) + if next_joint is None: + break + current = next_joint + + if len(ordered_joints) != len(joints): + raise PlanningGroupDiscoveryError("controllable joints are disconnected; provide SRDF") + return ordered_joints + + +def _validate_controllable( + group_name: str, + joints: list[JointDescription], + controllable_joint_names: list[str], +) -> None: + if not joints: + raise PlanningGroupDiscoveryError( + f"planning group {group_name} contains no controllable joints" + ) + controllable = set(controllable_joint_names) + missing = [joint.name for joint in joints if joint.name not in controllable] + if missing: + raise PlanningGroupDiscoveryError( + f"planning group {group_name} includes joints outside controllable set: {missing}" + ) diff --git a/dimos/manipulation/planning/groups/identifiers.py b/dimos/manipulation/planning/groups/identifiers.py new file mode 100644 index 0000000000..7ce0fc0774 --- /dev/null +++ b/dimos/manipulation/planning/groups/identifiers.py @@ -0,0 +1,129 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""String grammar helpers for public manipulation identifiers.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from dimos.manipulation.planning.spec.models import ( + GlobalJointName, + LocalModelJointName, + PlanningGroupID, + RobotName, +) + + +def assert_valid_robot_name(robot_name: RobotName) -> None: + """Validate a robot name for delimiter-based public IDs.""" + if not robot_name or "/" in robot_name: + raise ValueError(f"Invalid robot name: {robot_name!r}") + + +def assert_valid_local_joint_name(local_joint_name: LocalModelJointName) -> None: + """Validate a local model joint name for delimiter-based global joint names.""" + if not local_joint_name or "/" in local_joint_name: + raise ValueError(f"Invalid local joint name: {local_joint_name!r}") + + +def assert_local_joint_names(names: Sequence[LocalModelJointName]) -> None: + """Validate that names are local model joint names, not global joint names.""" + for name in names: + assert_valid_local_joint_name(name) + + +def make_planning_group_id(robot_name: RobotName, group_name: str) -> PlanningGroupID: + """Build a public planning group ID.""" + assert_valid_robot_name(robot_name) + if not group_name or "/" in group_name: + raise ValueError(f"Invalid planning group name: {group_name!r}") + return f"{robot_name}/{group_name}" + + +def parse_planning_group_id(group_id: PlanningGroupID) -> tuple[RobotName, str]: + """Split and validate a planning group ID.""" + parts = group_id.split("/", maxsplit=1) + if len(parts) != 2 or not parts[0] or not parts[1] or "/" in parts[1]: + raise ValueError( + f"Invalid planning group ID {group_id!r}; expected '{{robot_name}}/{{group_name}}'" + ) + return parts[0], parts[1] + + +def make_global_joint_name( + robot_name: RobotName, + local_joint_name: LocalModelJointName, +) -> GlobalJointName: + """Convert a local model joint name to a public global joint name.""" + assert_valid_robot_name(robot_name) + assert_valid_local_joint_name(local_joint_name) + return f"{robot_name}/{local_joint_name}" + + +def make_global_joint_names( + robot_name: RobotName, + local_joint_names: Sequence[LocalModelJointName], +) -> list[GlobalJointName]: + """Convert local model joint names to public global joint names.""" + return [make_global_joint_name(robot_name, name) for name in local_joint_names] + + +def is_global_joint_name(name: str) -> bool: + """Return whether name has the exact global joint-name shape.""" + parts = name.split("/") + return len(parts) == 2 and bool(parts[0]) and bool(parts[1]) + + +def parse_global_joint_name( + global_joint_name: GlobalJointName, +) -> tuple[RobotName, LocalModelJointName]: + """Split and validate a global joint name.""" + parts = global_joint_name.split("/", maxsplit=1) + if len(parts) != 2: + raise ValueError( + f"Invalid global joint name {global_joint_name!r}; " + "expected '{robot_name}/{local_joint_name}'" + ) + robot_name, local_name = parts + try: + assert_valid_robot_name(robot_name) + assert_valid_local_joint_name(local_name) + except ValueError as exc: + raise ValueError( + f"Invalid global joint name {global_joint_name!r}; " + "expected '{robot_name}/{local_joint_name}'" + ) from exc + return robot_name, local_name + + +def assert_global_joint_names(names: Sequence[GlobalJointName]) -> None: + """Validate that names are global joint names.""" + invalid = [name for name in names if not is_global_joint_name(name)] + if invalid: + raise ValueError(f"Expected global joint names; got invalid names: {invalid}") + + +def local_joint_name_from_global( + robot_name: RobotName, + global_joint_name: GlobalJointName, +) -> LocalModelJointName: + """Validate and strip a global joint name for backend internals.""" + assert_valid_robot_name(robot_name) + parsed_robot_name, local_name = parse_global_joint_name(global_joint_name) + if parsed_robot_name != robot_name: + raise ValueError( + f"Global joint name {global_joint_name!r} does not belong to robot {robot_name!r}" + ) + return local_name diff --git a/dimos/manipulation/planning/groups/joints.py b/dimos/manipulation/planning/groups/joints.py new file mode 100644 index 0000000000..c66b4410d6 --- /dev/null +++ b/dimos/manipulation/planning/groups/joints.py @@ -0,0 +1,136 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Planning-group joint target and joint-state helpers.""" + +from collections.abc import Mapping, Sequence + +from dimos.manipulation.planning.groups.identifiers import ( + assert_global_joint_names, + assert_local_joint_names, + is_global_joint_name, +) +from dimos.manipulation.planning.groups.models import PlanningGroup +from dimos.manipulation.planning.spec.models import ( + GlobalJointName, + LocalModelJointName, + PlanningGroupID, +) +from dimos.msgs.sensor_msgs.JointState import JointState + + +def planning_group_id_from_selector(selector: PlanningGroupID | PlanningGroup) -> PlanningGroupID: + """Return the planning-group ID represented by a selector.""" + if isinstance(selector, PlanningGroup): + return selector.id + return selector + + +def matching_global_joint_name( + positions_by_name: Mapping[str, float], local_joint_name: LocalModelJointName +) -> GlobalJointName | None: + """Find the unique global joint name ending with a local joint name.""" + suffix = f"/{local_joint_name}" + matches = [name for name in positions_by_name if name.endswith(suffix)] + if len(matches) == 1: + return matches[0] + return None + + +def filter_joint_state_to_selected_joints( + joint_state: JointState, + global_joint_names: Sequence[GlobalJointName], + local_joint_names: Sequence[LocalModelJointName] = (), +) -> JointState: + """Project a joint state to selected global joints. + + Values are looked up by global name first. When ``local_joint_names`` is + provided, each corresponding local name is used as a fallback. + """ + if local_joint_names and len(global_joint_names) != len(local_joint_names): + raise ValueError("Global and local selected joint lists must have the same length") + + positions_by_name = dict(zip(joint_state.name, joint_state.position, strict=True)) + selected_positions: list[float] = [] + missing: list[str] = [] + for index, global_name in enumerate(global_joint_names): + if global_name in positions_by_name: + selected_positions.append(float(positions_by_name[global_name])) + continue + if local_joint_names: + local_name = local_joint_names[index] + if local_name in positions_by_name: + selected_positions.append(float(positions_by_name[local_name])) + continue + missing.append(global_name) + + if missing: + raise ValueError(f"IK result is missing selected joints: {missing}") + + return JointState({"name": list(global_joint_names), "position": selected_positions}) + + +def joint_target_to_global_names( + group: PlanningGroup, + target: JointState, +) -> JointState: + """Convert a group joint target to global joint names in group order. + + Named targets may use either the public global planning names or the + robot-local model names used by legacy robot-scoped callers, but the two + namespaces must not be mixed in one target. + """ + if not target.name: + if len(target.position) != len(group.joint_names): + raise ValueError( + f"Target for '{group.id}' has {len(target.position)} positions, " + f"expected {len(group.joint_names)}" + ) + return JointState(name=list(group.joint_names), position=list(target.position)) + + if len(target.name) != len(target.position): + raise ValueError( + f"Target for '{group.id}' has {len(target.name)} names but " + f"{len(target.position)} positions" + ) + + target_names = list(target.name) + global_flags = [is_global_joint_name(name) for name in target_names] + if any(global_flags) and not all(global_flags): + raise ValueError( + f"Target for '{group.id}' mixes global and local joint names: {target_names}" + ) + + if all(global_flags): + assert_global_joint_names(target_names) + expected_names = group.joint_names + else: + assert_local_joint_names(target_names) + expected_names = group.local_joint_names + + positions_by_name = dict(zip(target_names, target.position, strict=True)) + global_positions: list[float] = [] + missing: list[str] = [] + for expected_name in expected_names: + if expected_name in positions_by_name: + global_positions.append(positions_by_name[expected_name]) + else: + missing.append(expected_name) + if missing: + raise ValueError(f"Target for '{group.id}' is missing joints: {missing}") + + extra = set(target_names) - set(expected_names) + if extra: + raise ValueError(f"Target for '{group.id}' has extra joints: {sorted(extra)}") + return JointState(name=list(group.joint_names), position=global_positions) diff --git a/dimos/manipulation/planning/groups/models.py b/dimos/manipulation/planning/groups/models.py new file mode 100644 index 0000000000..c08b1bd4b5 --- /dev/null +++ b/dimos/manipulation/planning/groups/models.py @@ -0,0 +1,112 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Backend-independent planning-group domain models.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal, TypeAlias + +from dimos.manipulation.planning.spec.models import ( + GlobalJointName, + LocalModelJointName, + PlanningGroupID, + RobotName, +) + +PlanningGroupSource: TypeAlias = Literal["srdf", "fallback"] + + +@dataclass(frozen=True) +class PlanningGroupDefinition: + """Model-level declaration of a planning group. + + Joint names are local model names. The definition is safe to store on + ``RobotModelConfig`` and is not bound to any runtime world robot ID. + """ + + name: str + joint_names: tuple[LocalModelJointName, ...] + base_link: str + tip_link: str | None = None + source: PlanningGroupSource = "srdf" + + @property + def has_pose_target(self) -> bool: + """Whether this group has a valid pose target frame.""" + return self.tip_link is not None + + +@dataclass(frozen=True) +class PlanningGroup: + """Public backend-independent planning group. + + A planning group exposes stable public IDs and global joint names for + planning APIs. It intentionally does not include backend runtime robot IDs. + """ + + id: PlanningGroupID + robot_name: RobotName + group_name: str + joint_names: tuple[GlobalJointName, ...] + local_joint_names: tuple[LocalModelJointName, ...] + base_link: str + tip_link: str | None = None + source: PlanningGroupSource = "srdf" + + @property + def has_pose_target(self) -> bool: + """Whether this group can be directly pose-targeted.""" + return self.tip_link is not None + + +@dataclass(frozen=True) +class PlanningGroupSelection: + """Validated ordered selection of planning groups. + + Selection validates ID existence and selected-joint overlap outside any + world backend. Requested group order is preserved. + """ + + groups: tuple[PlanningGroup, ...] + group_ids: tuple[PlanningGroupID, ...] + joint_names: tuple[GlobalJointName, ...] + robot_names: tuple[RobotName, ...] + + @classmethod + def from_groups(cls, groups: tuple[PlanningGroup, ...]) -> PlanningGroupSelection: + """Build a selection, rejecting overlapping selected global joints.""" + seen_joints: dict[GlobalJointName, PlanningGroupID] = {} + joint_names: list[GlobalJointName] = [] + robot_names: list[RobotName] = [] + for group in groups: + if group.robot_name not in robot_names: + robot_names.append(group.robot_name) + for joint_name in group.joint_names: + previous_group_id = seen_joints.get(joint_name) + if previous_group_id is not None: + raise ValueError( + "Selected planning groups overlap on global joint " + f"{joint_name}: {previous_group_id} and {group.id}" + ) + seen_joints[joint_name] = group.id + joint_names.append(joint_name) + + return cls( + groups=groups, + group_ids=tuple(group.id for group in groups), + joint_names=tuple(joint_names), + robot_names=tuple(robot_names), + ) diff --git a/dimos/manipulation/planning/groups/registry.py b/dimos/manipulation/planning/groups/registry.py new file mode 100644 index 0000000000..608ac8c385 --- /dev/null +++ b/dimos/manipulation/planning/groups/registry.py @@ -0,0 +1,103 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Backend-independent planning-group registry.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING + +from dimos.manipulation.planning.groups.discovery import FALLBACK_PLANNING_GROUP_NAME +from dimos.manipulation.planning.groups.identifiers import ( + make_global_joint_names, + make_planning_group_id, +) +from dimos.manipulation.planning.groups.models import PlanningGroup, PlanningGroupSelection +from dimos.manipulation.planning.spec.models import PlanningGroupID, RobotName + +if TYPE_CHECKING: + from dimos.manipulation.planning.spec.config import RobotModelConfig + + +class PlanningGroupRegistry: + """Registry of public planning groups derived from robot configs.""" + + def __init__(self, robot_configs: Iterable[RobotModelConfig] = ()) -> None: + self._groups: dict[PlanningGroupID, PlanningGroup] = {} + self._groups_by_robot: dict[RobotName, list[PlanningGroup]] = {} + for config in robot_configs: + self.add_robot(config) + + def add_robot(self, config: RobotModelConfig) -> None: + """Register all planning groups declared by one robot config.""" + if config.name in self._groups_by_robot: + raise ValueError(f"Robot '{config.name}' is already registered") + + robot_groups: list[PlanningGroup] = [] + for definition in config.planning_groups: + group_id = make_planning_group_id(config.name, definition.name) + if group_id in self._groups: + raise ValueError(f"Planning group '{group_id}' is already registered") + group = PlanningGroup( + id=group_id, + robot_name=config.name, + group_name=definition.name, + joint_names=tuple(make_global_joint_names(config.name, definition.joint_names)), + local_joint_names=definition.joint_names, + base_link=definition.base_link, + tip_link=definition.tip_link, + source=definition.source, + ) + self._groups[group_id] = group + robot_groups.append(group) + self._groups_by_robot[config.name] = robot_groups + + def list(self) -> tuple[PlanningGroup, ...]: + """List planning groups in robot registration order.""" + groups: list[PlanningGroup] = [] + for robot_groups in self._groups_by_robot.values(): + groups.extend(robot_groups) + return tuple(groups) + + def get(self, group_id: PlanningGroupID) -> PlanningGroup: + """Return one planning group by public ID.""" + try: + return self._groups[group_id] + except KeyError as exc: + raise KeyError(f"Unknown planning group ID: {group_id}") from exc + + def select(self, group_ids: Iterable[PlanningGroupID]) -> PlanningGroupSelection: + """Validate and return an ordered planning-group selection.""" + return PlanningGroupSelection.from_groups( + tuple(self.get(group_id) for group_id in group_ids) + ) + + def groups_for_robot(self, robot_name: RobotName) -> tuple[PlanningGroup, ...]: + """Return planning groups for one robot.""" + return tuple(self._groups_by_robot.get(robot_name, ())) + + def default_group_id_for_robot(self, robot_name: RobotName) -> PlanningGroupID | None: + """Return the generated fallback group ID for robot-scoped wrappers.""" + group_id = make_planning_group_id(robot_name, FALLBACK_PLANNING_GROUP_NAME) + return group_id if group_id in self._groups else None + + def primary_pose_group_id_for_robot(self, robot_name: RobotName) -> PlanningGroupID | None: + """Return the first pose-targetable group ID for compatibility paths.""" + # TODO: Replace this compatibility selection with either one TF publication per + # pose-targetable planning group or backend-level whole-robot TF publishing. + for group in self.groups_for_robot(robot_name): + if group.has_pose_target: + return group.id + return None diff --git a/dimos/manipulation/planning/kinematics/drake_optimization_ik.py b/dimos/manipulation/planning/kinematics/drake_optimization_ik.py index abccee119d..5e9679e656 100644 --- a/dimos/manipulation/planning/kinematics/drake_optimization_ik.py +++ b/dimos/manipulation/planning/kinematics/drake_optimization_ik.py @@ -16,10 +16,12 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING import numpy as np +from dimos.manipulation.planning.groups.models import PlanningGroup from dimos.manipulation.planning.spec.enums import IKStatus from dimos.manipulation.planning.spec.models import IKResult, WorldRobotID from dimos.manipulation.planning.spec.protocols import WorldSpec @@ -76,10 +78,9 @@ def solve( seed: JointState | None = None, position_tolerance: float = 0.001, orientation_tolerance: float = 0.01, - check_collision: bool = True, max_attempts: int = 10, ) -> IKResult: - """Solve IK with multiple random restarts, returning the best collision-free solution.""" + """Solve IK with multiple random restarts, returning the best solution.""" error = self._validate_world(world) if error is not None: return error @@ -130,11 +131,6 @@ def solve( ) if result.is_success() and result.joint_state is not None: - # Check collision if requested - if check_collision: - if not world.check_config_collision_free(robot_id, result.joint_state): - continue # Try another seed - # Check error total_error = result.position_error + result.orientation_error if total_error < best_error: @@ -156,6 +152,44 @@ def solve( f"IK failed after {max_attempts} attempts", ) + def solve_pose_targets( + self, + world: WorldSpec, + pose_targets: Mapping[PlanningGroup, PoseStamped], + auxiliary_groups: Sequence[PlanningGroup] = (), + seed: JointState | None = None, + position_tolerance: float = 0.001, + orientation_tolerance: float = 0.01, + max_attempts: int = 10, + ) -> IKResult: + """Solve a single planning-group pose target for protocol compatibility.""" + if auxiliary_groups: + return _create_failure_result( + IKStatus.NO_SOLUTION, + "DrakeOptimizationIK does not support auxiliary planning groups", + ) + if len(pose_targets) != 1: + return _create_failure_result( + IKStatus.NO_SOLUTION, + "DrakeOptimizationIK supports exactly one pose target", + ) + group, target_pose = next(iter(pose_targets.items())) + robot_id = _robot_id_for_name(world, group.robot_name) + if robot_id is None: + return _create_failure_result( + IKStatus.NO_SOLUTION, + f"No robot named '{group.robot_name}'", + ) + return self.solve( + world=world, + robot_id=robot_id, + target_pose=target_pose, + seed=seed, + position_tolerance=position_tolerance, + orientation_tolerance=orientation_tolerance, + max_attempts=max_attempts, + ) + def _solve_single( self, world: WorldSpec, @@ -259,6 +293,13 @@ def _create_success_result( ) +def _robot_id_for_name(world: WorldSpec, robot_name: str) -> WorldRobotID | None: + for robot_id in world.get_robot_ids(): + if world.get_robot_config(robot_id).name == robot_name: + return robot_id + return None + + def _create_failure_result( status: IKStatus, message: str, diff --git a/dimos/manipulation/planning/kinematics/jacobian_ik.py b/dimos/manipulation/planning/kinematics/jacobian_ik.py index 7727b6fa0f..e250fd26d1 100644 --- a/dimos/manipulation/planning/kinematics/jacobian_ik.py +++ b/dimos/manipulation/planning/kinematics/jacobian_ik.py @@ -24,12 +24,20 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING +import warnings import numpy as np +from dimos.manipulation.planning.groups.joints import filter_joint_state_to_selected_joints +from dimos.manipulation.planning.groups.models import PlanningGroup, PlanningGroupSelection from dimos.manipulation.planning.spec.enums import IKStatus -from dimos.manipulation.planning.spec.models import IKResult, WorldRobotID +from dimos.manipulation.planning.spec.models import ( + IKResult, + RobotName, + WorldRobotID, +) from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.kinematics_utils import ( check_singularity, @@ -53,7 +61,11 @@ class JacobianIK: - """Backend-agnostic Jacobian-based IK solver. + """Deprecated backend-agnostic Jacobian-based IK solver. + + Prefer PinkIK or DrakeOptimizationIK for new planning-group-aware code. + This class is retained as a compatibility/smoke-test backend and only + supports one directly pose-targeted planning group with no auxiliary groups. This class provides iterative and differential IK methods using only the standard WorldSpec interface. It works with any physics backend @@ -89,6 +101,11 @@ def __init__( max_iterations: Default maximum iterations for iterative IK singularity_threshold: Manipulability threshold for singularity detection """ + warnings.warn( + "JacobianIK is deprecated; use PinkIK or DrakeOptimizationIK for new code.", + DeprecationWarning, + stacklevel=2, + ) self._damping = damping self._max_iterations = max_iterations self._singularity_threshold = singularity_threshold @@ -101,7 +118,6 @@ def solve( seed: JointState | None = None, position_tolerance: float = 0.001, orientation_tolerance: float = 0.01, - check_collision: bool = True, max_attempts: int = 10, ) -> IKResult: """Solve IK with multiple random restarts. @@ -116,7 +132,6 @@ def solve( seed: Initial guess (uses current state if None) position_tolerance: Required position accuracy (meters) orientation_tolerance: Required orientation accuracy (radians) - check_collision: Whether to check collision of solution max_attempts: Maximum random restart attempts Returns: @@ -159,11 +174,6 @@ def solve( ) if result.is_success() and result.joint_state is not None: - # Check collision if requested - if check_collision: - if not world.check_config_collision_free(robot_id, result.joint_state): - continue # Try another seed - # Check error total_error = result.position_error + result.orientation_error if total_error < best_error: @@ -185,6 +195,82 @@ def solve( f"IK failed after {max_attempts} attempts", ) + def solve_pose_targets( + self, + world: WorldSpec, + pose_targets: Mapping[PlanningGroup, PoseStamped], + auxiliary_groups: Sequence[PlanningGroup] = (), + seed: JointState | None = None, + position_tolerance: float = 0.001, + orientation_tolerance: float = 0.01, + max_attempts: int = 10, + ) -> IKResult: + """Solve pose targets keyed by planning group with request-scoped auxiliaries. + + This backend currently supports exactly one directly pose-targeted planning group. + """ + if not pose_targets: + return _create_failure_result( + IKStatus.NO_SOLUTION, "At least one pose target is required" + ) + + pose_groups = tuple(pose_targets.keys()) + if len(pose_groups) != 1 or auxiliary_groups: + return _create_failure_result( + IKStatus.NO_SOLUTION, + "JacobianIK supports exactly one pose target and no auxiliary planning groups", + ) + + try: + selection = PlanningGroupSelection.from_groups(pose_groups + tuple(auxiliary_groups)) + robot_ids_by_name = _robot_ids_by_name(world, selection.robot_names) + except (KeyError, ValueError) as exc: + return _create_failure_result(IKStatus.NO_SOLUTION, str(exc)) + + target_group = pose_groups[0] + if not target_group.has_pose_target: + return _create_failure_result( + IKStatus.NO_SOLUTION, + f"Planning group '{target_group.id}' has no pose target frame", + ) + + robot_ids = {robot_ids_by_name[group.robot_name] for group in selection.groups} + if len(robot_ids) != 1: + return _create_failure_result( + IKStatus.NO_SOLUTION, + "JacobianIK does not support cross-robot pose IK", + ) + + robot_id = robot_ids_by_name[target_group.robot_name] + full_seed = seed + if full_seed is None: + with world.scratch_context() as ctx: + full_seed = world.get_joint_state(ctx, robot_id) + + target_pose = pose_targets[next(iter(pose_targets.keys()))] + result = self.solve( + world=world, + robot_id=robot_id, + target_pose=target_pose, + seed=full_seed, + position_tolerance=position_tolerance, + orientation_tolerance=orientation_tolerance, + max_attempts=max_attempts, + ) + if not result.is_success() or result.joint_state is None: + return result + + selected_joint_names: list[str] = [] + for group in selection.groups: + selected_joint_names.extend(group.joint_names) + try: + result.joint_state = filter_joint_state_to_selected_joints( + result.joint_state, selected_joint_names + ) + except ValueError as exc: + return _create_failure_result(IKStatus.NO_SOLUTION, str(exc)) + return result + def solve_iterative( self, world: WorldSpec, @@ -433,3 +519,21 @@ def _create_failure_result( iterations=iterations, message=message, ) + + +def _robot_ids_by_name( + world: WorldSpec, robot_names: tuple[RobotName, ...] +) -> dict[RobotName, WorldRobotID]: + robot_ids_by_name: dict[RobotName, WorldRobotID] = {} + for robot_name in robot_names: + matches = [ + robot_id + for robot_id in world.get_robot_ids() + if world.get_robot_config(robot_id).name == robot_name + ] + if not matches: + raise KeyError(f"Robot '{robot_name}' not found") + if len(matches) > 1: + raise ValueError(f"Robot name '{robot_name}' is not unique in planning world") + robot_ids_by_name[robot_name] = matches[0] + return robot_ids_by_name diff --git a/dimos/manipulation/planning/kinematics/pink_ik.py b/dimos/manipulation/planning/kinematics/pink_ik.py index 1245e5aea4..f5169d140b 100644 --- a/dimos/manipulation/planning/kinematics/pink_ik.py +++ b/dimos/manipulation/planning/kinematics/pink_ik.py @@ -16,18 +16,25 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from dataclasses import dataclass -import importlib from pathlib import Path -from types import ModuleType from typing import TYPE_CHECKING, Any import numpy as np +from dimos.manipulation.planning.groups.identifiers import make_global_joint_name +from dimos.manipulation.planning.groups.joints import matching_global_joint_name +from dimos.manipulation.planning.groups.models import PlanningGroup, PlanningGroupSelection +from dimos.manipulation.planning.groups.registry import PlanningGroupRegistry from dimos.manipulation.planning.kinematics.config import PinkKinematicsConfig from dimos.manipulation.planning.spec.config import RobotModelConfig from dimos.manipulation.planning.spec.enums import IKStatus -from dimos.manipulation.planning.spec.models import IKResult, WorldRobotID +from dimos.manipulation.planning.spec.models import ( + IKResult, + RobotName, + WorldRobotID, +) from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.kinematics_utils import compute_pose_error from dimos.manipulation.planning.utils.mesh_utils import prepare_urdf_for_drake @@ -39,6 +46,18 @@ if TYPE_CHECKING: from numpy.typing import NDArray +try: + import pink # type: ignore[import-not-found, import-untyped] + import pinocchio # type: ignore[import-not-found] + import qpsolvers # type: ignore[import-not-found] +except ImportError as exc: + pink = None # type: ignore[assignment] + pinocchio = None # type: ignore[assignment] + qpsolvers = None # type: ignore[assignment] + _PINK_IMPORT_ERROR: ImportError | None = exc +else: + _PINK_IMPORT_ERROR = None + logger = setup_logger() @@ -49,12 +68,6 @@ class PinkIKDependencyError(ImportError): PinkIKConfig = PinkKinematicsConfig -@dataclass(frozen=True) -class _PinkModules: - pink: ModuleType - pinocchio: ModuleType - - _MANIPULATION_EXTRA_HINT = "Install manipulation dependencies with: uv sync --extra manipulation." @@ -63,6 +76,15 @@ class _JointMapping: dimos_joint_names: list[str] model_joint_names: list[str] idx_q: list[int] + idx_q_array: NDArray[np.int64] + + +@dataclass +class _PinkRobotModelContext: + model: Any + mapping: _JointMapping + neutral_q: NDArray[np.float64] + frame_ids: dict[str, int] @dataclass @@ -72,6 +94,11 @@ class _PinkRobotContext: frame_id: int frame_name: str mapping: _JointMapping + neutral_q: NDArray[np.float64] | None = None + + +class _CurrentStateRequiredError(ValueError): + """Raised when normalizing a seed requires the world's current state.""" class PinkIK: @@ -97,8 +124,8 @@ def __init__( config_values = (config or PinkKinematicsConfig()).model_dump() config_values.update(overrides) self.config = PinkKinematicsConfig(**config_values) - self._modules = _load_optional_dependencies(self.config.solver) - self._robot_contexts: dict[str, _PinkRobotContext] = {} + _check_optional_dependencies(self.config.solver) + self._robot_model_contexts: dict[tuple[object, ...], _PinkRobotModelContext] = {} def solve( self, @@ -108,39 +135,177 @@ def solve( seed: JointState | None = None, position_tolerance: float = 0.001, orientation_tolerance: float = 0.01, - check_collision: bool = True, max_attempts: int = 10, ) -> IKResult: """Solve IK with Pink, returning the standard planning ``IKResult``.""" - if not world.is_finalized: - return _failure(IKStatus.NO_SOLUTION, "World must be finalized before IK") + try: + config = world.get_robot_config(robot_id) + group = _single_pose_group_for_robot(world, config.name) + except (KeyError, ValueError) as exc: + return _failure(IKStatus.NO_SOLUTION, str(exc)) + return self.solve_pose_targets( + world=world, + pose_targets={group: target_pose}, + seed=seed, + position_tolerance=position_tolerance, + orientation_tolerance=orientation_tolerance, + max_attempts=max_attempts, + ) + def solve_pose_targets( + self, + world: WorldSpec, + pose_targets: Mapping[PlanningGroup, PoseStamped], + auxiliary_groups: Sequence[PlanningGroup] = (), + seed: JointState | None = None, + position_tolerance: float = 0.001, + orientation_tolerance: float = 0.01, + max_attempts: int = 10, + ) -> IKResult: + """Solve planning-group pose targets and return selected global joints.""" + if not pose_targets: + return _failure(IKStatus.NO_SOLUTION, "At least one pose target is required") + + pose_groups = tuple(pose_targets.keys()) try: - robot_context = self._get_robot_context(world, robot_id) - except (FileNotFoundError, ImportError, ValueError) as exc: - return _failure(IKStatus.NO_SOLUTION, f"Pink IK model setup failed: {exc}") + selection = PlanningGroupSelection.from_groups(pose_groups + tuple(auxiliary_groups)) + robot_ids_by_name = _robot_ids_by_name(world, selection.robot_names) + except (KeyError, ValueError) as exc: + return _failure(IKStatus.NO_SOLUTION, str(exc)) + + groups_by_robot: dict[RobotName, list[PlanningGroup]] = {} + pose_groups_by_robot: dict[RobotName, list[PlanningGroup]] = {} + for group in selection.groups: + groups_by_robot.setdefault(group.robot_name, []).append(group) + for group in pose_groups: + if not group.has_pose_target or group.tip_link is None: + return _failure( + IKStatus.NO_SOLUTION, + f"Planning group '{group.id}' has no pose target frame", + ) + pose_groups_by_robot.setdefault(group.robot_name, []).append(group) + + selected_positions_by_name: dict[str, float] = {} + max_position_error = 0.0 + max_orientation_error = 0.0 + total_iterations = 0 + for robot_name, groups in groups_by_robot.items(): + robot_id = robot_ids_by_name[robot_name] + robot_pose_groups = pose_groups_by_robot.get(robot_name, []) + robot_pose_targets = {group: pose_targets[group] for group in robot_pose_groups} + config = world.get_robot_config(robot_id) + seed_for_robot = _seed_for_robot_with_world_fallback(world, robot_id, seed) + if robot_pose_targets: + lower_limits, upper_limits = world.get_joint_limits(robot_id) + result = self._solve_pose_targets_for_robot( + world=world, + robot_id=robot_id, + pose_targets=robot_pose_targets, + seed=seed_for_robot, + position_tolerance=position_tolerance, + orientation_tolerance=orientation_tolerance, + max_attempts=max_attempts, + config=config, + lower_limits=lower_limits, + upper_limits=upper_limits, + target_models=self._targets_in_model_frame(config, robot_pose_targets), + ) + if not result.is_success() or result.joint_state is None: + return result + else: + result = IKResult( + status=IKStatus.SUCCESS, + joint_state=seed_for_robot, + message="Auxiliary group retained seed state", + ) + joint_state = result.joint_state + if joint_state is None: + return _failure( + IKStatus.NO_SOLUTION, + f"Pink IK result for robot '{robot_name}' has no joint state", + ) - if seed is None: - with world.scratch_context() as ctx: - seed = world.get_joint_state(ctx, robot_id) + max_position_error = max(max_position_error, result.position_error) + max_orientation_error = max(max_orientation_error, result.orientation_error) + total_iterations += result.iterations + local_positions = dict(zip(joint_state.name, joint_state.position, strict=True)) + for group in groups: + for global_name, local_name in zip( + group.joint_names, group.local_joint_names, strict=True + ): + if local_name not in local_positions: + return _failure( + IKStatus.NO_SOLUTION, + f"Pink IK result for robot '{robot_name}' is missing joint '{local_name}'", + ) + selected_positions_by_name[global_name] = float(local_positions[local_name]) + + selected_positions = [selected_positions_by_name[name] for name in selection.joint_names] + return IKResult( + status=IKStatus.SUCCESS, + joint_state=JointState( + {"name": list(selection.joint_names), "position": selected_positions} + ), + position_error=max_position_error, + orientation_error=max_orientation_error, + iterations=total_iterations, + message="Pink IK target set solution found", + ) - lower_limits, upper_limits = world.get_joint_limits(robot_id) - target_model = self._target_in_model_frame(world.get_robot_config(robot_id), target_pose) + def _solve_pose_targets_for_robot( + self, + world: WorldSpec, + robot_id: WorldRobotID, + pose_targets: Mapping[PlanningGroup, PoseStamped], + seed: JointState, + position_tolerance: float, + orientation_tolerance: float, + max_attempts: int, + config: RobotModelConfig | None = None, + lower_limits: NDArray[np.float64] | None = None, + upper_limits: NDArray[np.float64] | None = None, + target_models: Mapping[PlanningGroup, NDArray[np.float64]] | None = None, + ) -> IKResult: + """Solve one robot's one-or-more frame targets.""" + try: + contexts = [ + self._get_robot_context(world, robot_id, group.tip_link, config) + for group in pose_targets + if group.tip_link is not None + ] + except (FileNotFoundError, ImportError, ValueError) as exc: + return _failure(IKStatus.NO_SOLUTION, f"Pink IK model setup failed: {exc}") + config = config or world.get_robot_config(robot_id) + if lower_limits is None or upper_limits is None: + lower_limits, upper_limits = world.get_joint_limits(robot_id) + target_models_by_group = target_models or self._targets_in_model_frame(config, pose_targets) + target_model_list = [target_models_by_group[group] for group in pose_targets] fallback_result: IKResult | None = None for attempt in range(max_attempts): try: - q0 = self._initial_q(robot_context, seed, lower_limits, upper_limits, attempt) - result = self._solve_single( - robot_context=robot_context, - target_model=target_model, - seed_q=q0, - lower_limits=lower_limits, - upper_limits=upper_limits, - position_tolerance=position_tolerance, - orientation_tolerance=orientation_tolerance, - ) + q0 = self._initial_q(contexts[0], seed, lower_limits, upper_limits, attempt) + if len(contexts) == 1: + result = self._solve_single( + robot_context=contexts[0], + target_model=target_model_list[0], + seed_q=q0, + lower_limits=lower_limits, + upper_limits=upper_limits, + position_tolerance=position_tolerance, + orientation_tolerance=orientation_tolerance, + ) + else: + result = self._solve_multi_frame( + robot_contexts=contexts, + target_models=target_model_list, + seed_q=q0, + lower_limits=lower_limits, + upper_limits=upper_limits, + position_tolerance=position_tolerance, + orientation_tolerance=orientation_tolerance, + ) except ValueError as exc: return _failure(IKStatus.NO_SOLUTION, f"Pink IK mapping failed: {exc}") except Exception as exc: @@ -150,18 +315,10 @@ def solve( if fallback_result is None: fallback_result = result continue - - if check_collision and not world.check_config_collision_free( - robot_id, result.joint_state - ): - fallback_result = _collision_failure(result) - continue - return result if fallback_result is not None: return fallback_result - return _failure(IKStatus.NO_SOLUTION, f"Pink IK failed after {max_attempts} attempts") def _solve_single( @@ -174,21 +331,66 @@ def _solve_single( position_tolerance: float, orientation_tolerance: float, ) -> IKResult: - pink = self._modules.pink - pinocchio = self._modules.pinocchio - - configuration = pink.Configuration(robot_context.model, robot_context.data, seed_q.copy()) - target_se3 = _matrix_to_se3(pinocchio, target_model) - - frame_task = pink.tasks.FrameTask( - robot_context.frame_name, - position_cost=self.config.position_cost, - orientation_cost=self.config.orientation_cost, - lm_damping=self.config.lm_damping, - gain=self.config.gain, + return self._solve_frame_targets( + robot_contexts=[robot_context], + target_models=[target_model], + seed_q=seed_q, + lower_limits=lower_limits, + upper_limits=upper_limits, + position_tolerance=position_tolerance, + orientation_tolerance=orientation_tolerance, + ) + + def _solve_multi_frame( + self, + robot_contexts: Sequence[_PinkRobotContext], + target_models: Sequence[NDArray[np.float64]], + seed_q: NDArray[np.float64], + lower_limits: NDArray[np.float64], + upper_limits: NDArray[np.float64], + position_tolerance: float, + orientation_tolerance: float, + ) -> IKResult: + """Solve multiple frame tasks for one robot model.""" + return self._solve_frame_targets( + robot_contexts=robot_contexts, + target_models=target_models, + seed_q=seed_q, + lower_limits=lower_limits, + upper_limits=upper_limits, + position_tolerance=position_tolerance, + orientation_tolerance=orientation_tolerance, + ) + + def _solve_frame_targets( + self, + robot_contexts: Sequence[_PinkRobotContext], + target_models: Sequence[NDArray[np.float64]], + seed_q: NDArray[np.float64], + lower_limits: NDArray[np.float64], + upper_limits: NDArray[np.float64], + position_tolerance: float, + orientation_tolerance: float, + ) -> IKResult: + """Solve one robot model against one or more frame targets.""" + assert pink is not None + assert pinocchio is not None + primary_context = robot_contexts[0] + configuration = pink.Configuration( + primary_context.model, primary_context.data, seed_q.copy() ) - frame_task.set_target(target_se3) - tasks: list[Any] = [frame_task] + + tasks: list[Any] = [] + for context, target_model in zip(robot_contexts, target_models, strict=True): + frame_task = pink.tasks.FrameTask( + context.frame_name, + position_cost=self.config.position_cost, + orientation_cost=self.config.orientation_cost, + lm_damping=self.config.lm_damping, + gain=self.config.gain, + ) + frame_task.set_target(_matrix_to_se3(pinocchio, target_model)) + tasks.append(frame_task) if self.config.posture_cost > 0.0: posture_task = pink.tasks.PostureTask(cost=self.config.posture_cost) @@ -197,19 +399,23 @@ def _solve_single( final_position_error = float("inf") final_orientation_error = float("inf") - for iteration in range(self.config.max_iterations): - current_pose = self._current_frame_matrix(robot_context, configuration.q) - final_position_error, final_orientation_error = compute_pose_error( - current_pose, target_model - ) + position_errors: list[float] = [] + orientation_errors: list[float] = [] + current_poses = self._current_frame_matrices(robot_contexts, configuration.q) + for current_pose, target_model in zip(current_poses, target_models, strict=True): + position_error, orientation_error = compute_pose_error(current_pose, target_model) + position_errors.append(position_error) + orientation_errors.append(orientation_error) + final_position_error = max(position_errors) + final_orientation_error = max(orientation_errors) if ( final_position_error <= position_tolerance and final_orientation_error <= orientation_tolerance ): return _success( - robot_context.mapping.dimos_joint_names, - self._q_to_dimos_positions(robot_context, configuration.q), + primary_context.mapping.dimos_joint_names, + self._q_to_dimos_positions(primary_context, configuration.q), final_position_error, final_orientation_error, iteration + 1, @@ -225,7 +431,7 @@ def _solve_single( ) configuration.integrate_inplace(velocity, self.config.dt) - joint_positions = self._q_to_dimos_positions(robot_context, configuration.q) + joint_positions = self._q_to_dimos_positions(primary_context, configuration.q) if not _within_limits(joint_positions, lower_limits, upper_limits): return IKResult( status=IKStatus.JOINT_LIMITS, @@ -245,16 +451,45 @@ def _solve_single( message="Pink IK did not converge within the iteration budget", ) - def _get_robot_context(self, world: WorldSpec, robot_id: WorldRobotID) -> _PinkRobotContext: - cache_key = str(robot_id) - if cache_key not in self._robot_contexts: - self._robot_contexts[cache_key] = self._build_robot_context( - world.get_robot_config(robot_id) - ) - return self._robot_contexts[cache_key] + def _get_robot_context( + self, + world: WorldSpec, + robot_id: WorldRobotID, + frame_name: str | None = None, + config: RobotModelConfig | None = None, + ) -> _PinkRobotContext: + config = config or world.get_robot_config(robot_id) + target_frame = frame_name or config.end_effector_link + if target_frame is None: + raise ValueError(f"Robot '{robot_id}' has no end-effector frame configured") + model_context = self._get_robot_model_context(robot_id, config) + frame_id = self._frame_id_for_model_context(model_context, target_frame) + return _PinkRobotContext( + model=model_context.model, + data=model_context.model.createData(), + frame_id=frame_id, + frame_name=target_frame, + mapping=model_context.mapping, + neutral_q=model_context.neutral_q, + ) - def _build_robot_context(self, config: RobotModelConfig) -> _PinkRobotContext: - pinocchio = self._modules.pinocchio + def _get_robot_model_context( + self, robot_id: WorldRobotID, config: RobotModelConfig + ) -> _PinkRobotModelContext: + cache_key = _robot_model_cache_key(robot_id, config) + if cache_key not in self._robot_model_contexts: + self._robot_model_contexts[cache_key] = self._build_robot_model_context(config) + return self._robot_model_contexts[cache_key] + + def _frame_id_for_model_context( + self, model_context: _PinkRobotModelContext, frame_name: str + ) -> int: + if frame_name not in model_context.frame_ids: + model_context.frame_ids[frame_name] = _get_frame_id(model_context.model, frame_name) + return model_context.frame_ids[frame_name] + + def _build_robot_model_context(self, config: RobotModelConfig) -> _PinkRobotModelContext: + assert pinocchio is not None model_path = Path(config.model_path).resolve() if not model_path.exists(): raise FileNotFoundError(f"Robot model not found: {model_path}") @@ -267,18 +502,19 @@ def _build_robot_context(self, config: RobotModelConfig) -> _PinkRobotContext: package_paths=config.package_paths, xacro_args=config.xacro_args, convert_meshes=config.auto_convert_meshes, + strip_world_joint_child_link=config.base_link + if config.strip_model_world_joint + else None, ) model = pinocchio.buildModelFromUrdf(str(prepared_path)) - - data = model.createData() - frame_id = _get_frame_id(model, config.end_effector_link) + model = _lock_uncontrolled_model_joints(pinocchio, model, config) mapping = _build_joint_mapping(model, config) - return _PinkRobotContext( + neutral_q = np.asarray(pinocchio.neutral(model), dtype=np.float64) + return _PinkRobotModelContext( model=model, - data=data, - frame_id=frame_id, - frame_name=config.end_effector_link, mapping=mapping, + neutral_q=neutral_q, + frame_ids={}, ) def _initial_q( @@ -289,63 +525,70 @@ def _initial_q( upper_limits: NDArray[np.float64], attempt: int, ) -> NDArray[np.float64]: - pinocchio = self._modules.pinocchio - neutral = pinocchio.neutral(context.model) - q = np.array(neutral, dtype=np.float64) + assert pinocchio is not None + neutral = context.neutral_q + if neutral is None: + neutral = np.asarray(pinocchio.neutral(context.model), dtype=np.float64) + q = np.array(neutral, dtype=np.float64, copy=True) if attempt == 0: positions = _seed_positions_for_mapping(seed, context.mapping) else: positions = np.random.uniform(lower_limits, upper_limits) - for value, idx_q in zip(positions, context.mapping.idx_q, strict=True): - q[idx_q] = value + q[context.mapping.idx_q_array] = positions return q def _q_to_dimos_positions( self, context: _PinkRobotContext, q: NDArray[np.float64] ) -> NDArray[np.float64]: - return np.array([q[idx_q] for idx_q in context.mapping.idx_q], dtype=np.float64) - - def _current_frame_matrix( - self, context: _PinkRobotContext, q: NDArray[np.float64] - ) -> NDArray[np.float64]: - pinocchio = self._modules.pinocchio - pinocchio.forwardKinematics(context.model, context.data, q) - pinocchio.updateFramePlacements(context.model, context.data) - placement = context.data.oMf[context.frame_id] - matrix: NDArray[np.float64] = np.eye(4) - matrix[:3, :3] = np.asarray(placement.rotation, dtype=np.float64) - matrix[:3, 3] = np.asarray(placement.translation, dtype=np.float64) - return matrix + return np.asarray(q[context.mapping.idx_q_array], dtype=np.float64) + + def _current_frame_matrices( + self, contexts: Sequence[_PinkRobotContext], q: NDArray[np.float64] + ) -> list[NDArray[np.float64]]: + assert pinocchio is not None + primary_context = contexts[0] + pinocchio.forwardKinematics(primary_context.model, primary_context.data, q) + pinocchio.updateFramePlacements(primary_context.model, primary_context.data) + return [ + _placement_to_matrix(primary_context.data.oMf[context.frame_id]) for context in contexts + ] def _target_in_model_frame( self, config: RobotModelConfig, target_pose: PoseStamped + ) -> NDArray[np.float64]: + base_world_inverse = np.linalg.inv(pose_to_matrix(config.base_pose)) + return self._target_in_model_frame_with_base_inverse(target_pose, base_world_inverse) + + def _targets_in_model_frame( + self, + config: RobotModelConfig, + pose_targets: Mapping[PlanningGroup, PoseStamped], + ) -> dict[PlanningGroup, NDArray[np.float64]]: + base_world_inverse = np.linalg.inv(pose_to_matrix(config.base_pose)) + return { + group: self._target_in_model_frame_with_base_inverse(pose, base_world_inverse) + for group, pose in pose_targets.items() + } + + def _target_in_model_frame_with_base_inverse( + self, target_pose: PoseStamped, base_world_inverse: NDArray[np.float64] ) -> NDArray[np.float64]: target_world = pose_to_matrix(target_pose) - base_world = pose_to_matrix(config.base_pose) target_model: NDArray[np.float64] = np.asarray( - np.linalg.inv(base_world) @ target_world, dtype=np.float64 + base_world_inverse @ target_world, dtype=np.float64 ) return target_model -def _load_optional_dependencies(solver: str) -> _PinkModules: - pink = _import_required_module( - "pink", - "Pink IK backend requires Pink. " - f"{_MANIPULATION_EXTRA_HINT} PyPI package: pin-pink; import name: pink.", - ) - pinocchio = _import_required_module( - "pinocchio", - f"Pink IK backend requires Pinocchio (import name 'pinocchio'). {_MANIPULATION_EXTRA_HINT}", - ) - qpsolvers = _import_required_module( - "qpsolvers", - "Pink IK backend requires qpsolvers plus a QP backend such as proxqp. " - f"{_MANIPULATION_EXTRA_HINT}", - ) - +def _check_optional_dependencies(solver: str) -> None: + if _PINK_IMPORT_ERROR is not None or pink is None or pinocchio is None or qpsolvers is None: + raise PinkIKDependencyError( + "Pink IK backend requires Pink, Pinocchio, and qpsolvers plus a QP backend " + f"such as proxqp. {_MANIPULATION_EXTRA_HINT} PyPI package: pin-pink; " + "import names: pink, pinocchio, qpsolvers." + ) from _PINK_IMPORT_ERROR available_solvers = set(getattr(qpsolvers, "available_solvers", [])) if solver not in available_solvers: raise PinkIKDependencyError( @@ -355,22 +598,13 @@ def _load_optional_dependencies(solver: str) -> _PinkModules: "which includes qpsolvers[proxqp]." ) - return _PinkModules(pink=pink, pinocchio=pinocchio) - - -def _import_required_module(name: str, message: str) -> ModuleType: - try: - return importlib.import_module(name) - except ImportError as exc: - raise PinkIKDependencyError(message) from exc - def _build_joint_mapping(model: Any, config: RobotModelConfig) -> _JointMapping: idx_q: list[int] = [] model_joint_names: list[str] = [] for dimos_name in config.joint_names: - model_joint_name = config.get_urdf_joint_name(dimos_name) + model_joint_name = dimos_name joint_id = _get_joint_id(model, model_joint_name) joint = model.joints[joint_id] nq = int(getattr(joint, "nq", 1)) @@ -386,9 +620,52 @@ def _build_joint_mapping(model: Any, config: RobotModelConfig) -> _JointMapping: dimos_joint_names=list(config.joint_names), model_joint_names=model_joint_names, idx_q=idx_q, + idx_q_array=np.asarray(idx_q, dtype=np.int64), ) +def _robot_model_cache_key(robot_id: WorldRobotID, config: RobotModelConfig) -> tuple[object, ...]: + return ( + str(robot_id), + str(Path(config.model_path).resolve()), + tuple(config.joint_names), + config.base_link, + tuple(sorted((name, str(path.resolve())) for name, path in config.package_paths.items())), + tuple(sorted(config.xacro_args.items())), + config.auto_convert_meshes, + config.strip_model_world_joint, + ) + + +def _placement_to_matrix(placement: Any) -> NDArray[np.float64]: + matrix: NDArray[np.float64] = np.eye(4) + matrix[:3, :3] = np.asarray(placement.rotation, dtype=np.float64) + matrix[:3, 3] = np.asarray(placement.translation, dtype=np.float64) + return matrix + + +def _lock_uncontrolled_model_joints(pinocchio: Any, model: Any, config: RobotModelConfig) -> Any: + """Return a Pinocchio model reduced to the joints controlled by config.""" + controlled_joint_names = set(config.joint_names) + lock_joint_ids: list[int] = [] + for joint_id, model_joint_name in enumerate(model.names): + if joint_id == 0 or model_joint_name in controlled_joint_names: + continue + joint = model.joints[joint_id] + if int(getattr(joint, "nq", 1)) > 0: + lock_joint_ids.append(joint_id) + + if not lock_joint_ids: + return model + + logger.debug( + "Reducing Pink IK model '%s' by locking uncontrolled joints: %s", + config.name, + [model.names[joint_id] for joint_id in lock_joint_ids], + ) + return pinocchio.buildReducedModel(model, lock_joint_ids, pinocchio.neutral(model)) + + def _get_joint_id(model: Any, joint_name: str) -> int: if hasattr(model, "existJointName") and not model.existJointName(joint_name): raise ValueError(_missing_joint_message(model, joint_name)) @@ -429,6 +706,10 @@ def _seed_positions_for_mapping(seed: JointState, mapping: _JointMapping) -> NDA values.append(float(positions_by_name[dimos_name])) elif model_name in positions_by_name: values.append(float(positions_by_name[model_name])) + elif ( + global_name := matching_global_joint_name(positions_by_name, dimos_name) + ) is not None: + values.append(float(positions_by_name[global_name])) else: raise ValueError(f"Seed is missing joint '{dimos_name}' (URDF name '{model_name}')") return np.array(values, dtype=np.float64) @@ -440,7 +721,7 @@ def _seed_positions_for_mapping(seed: JointState, mapping: _JointMapping) -> NDA return np.array(seed.position, dtype=np.float64) -def _matrix_to_se3(pinocchio: ModuleType, matrix: NDArray[np.float64]) -> Any: +def _matrix_to_se3(pinocchio: Any, matrix: NDArray[np.float64]) -> Any: return pinocchio.SE3(matrix[:3, :3], matrix[:3, 3]) @@ -465,7 +746,7 @@ def _success( ) -> IKResult: return IKResult( status=IKStatus.SUCCESS, - joint_state=JointState(name=joint_names, position=joint_positions.tolist()), + joint_state=JointState({"name": joint_names, "position": joint_positions.tolist()}), position_error=position_error, orientation_error=orientation_error, iterations=iterations, @@ -477,12 +758,93 @@ def _failure(status: IKStatus, message: str, iterations: int = 0) -> IKResult: return IKResult(status=status, joint_state=None, iterations=iterations, message=message) -def _collision_failure(result: IKResult) -> IKResult: - return IKResult( - status=IKStatus.COLLISION, - joint_state=None, - position_error=result.position_error, - orientation_error=result.orientation_error, - iterations=result.iterations, - message="Pink IK solution rejected by collision check", - ) +def _seed_for_robot_config( + config: RobotModelConfig, + seed: JointState | None, + current_state: JointState | None = None, +) -> JointState: + """Return a full local seed state for one robot from local/global seed input.""" + if seed is None: + if current_state is None: + raise _CurrentStateRequiredError("Current joint state is required when seed is absent") + return JointState(current_state) + if not seed.name: + if len(seed.position) == len(config.joint_names): + return JointState({"name": list(config.joint_names), "position": list(seed.position)}) + raise ValueError( + f"Seed has {len(seed.position)} positions for robot '{config.name}', " + f"expected {len(config.joint_names)}" + ) + if len(seed.name) != len(seed.position): + raise ValueError(f"Seed has {len(seed.name)} names but {len(seed.position)} positions") + seed_by_name = dict(zip(seed.name, seed.position, strict=True)) + positions: list[float] = [] + missing_local_names: list[str] = [] + for local_name in config.joint_names: + global_name = make_global_joint_name(config.name, local_name) + if local_name in seed_by_name: + positions.append(float(seed_by_name[local_name])) + elif global_name in seed_by_name: + positions.append(float(seed_by_name[global_name])) + else: + positions.append(0.0) + missing_local_names.append(local_name) + if missing_local_names: + if current_state is None: + missing = ", ".join(repr(name) for name in missing_local_names) + raise _CurrentStateRequiredError( + f"Current joint state is required for missing joints: {missing}" + ) + current = current_state + current_by_name = dict(zip(current.name, current.position, strict=True)) + for index, local_name in enumerate(config.joint_names): + if local_name not in missing_local_names: + continue + if local_name not in current_by_name: + raise ValueError(f"Seed/current state is missing joint '{local_name}'") + positions[index] = float(current_by_name[local_name]) + return JointState({"name": list(config.joint_names), "position": positions}) + + +def _seed_for_robot_with_world_fallback( + world: WorldSpec, robot_id: WorldRobotID, seed: JointState | None +) -> JointState: + """Normalize a robot seed, reading world state only when the seed is incomplete.""" + config = world.get_robot_config(robot_id) + try: + return _seed_for_robot_config(config, seed) + except _CurrentStateRequiredError: + with world.scratch_context() as ctx: + current = world.get_joint_state(ctx, robot_id) + return _seed_for_robot_config(config, seed, current) + + +def _robot_ids_by_name( + world: WorldSpec, robot_names: tuple[RobotName, ...] +) -> dict[RobotName, WorldRobotID]: + robot_ids_by_name: dict[RobotName, WorldRobotID] = {} + for robot_name in robot_names: + matches = [ + robot_id + for robot_id in world.get_robot_ids() + if world.get_robot_config(robot_id).name == robot_name + ] + if not matches: + raise KeyError(f"Robot '{robot_name}' not found") + if len(matches) > 1: + raise ValueError(f"Robot name '{robot_name}' is not unique in planning world") + robot_ids_by_name[robot_name] = matches[0] + return robot_ids_by_name + + +def _single_pose_group_for_robot(world: WorldSpec, robot_name: RobotName) -> PlanningGroup: + configs = [world.get_robot_config(robot_id) for robot_id in world.get_robot_ids()] + registry = PlanningGroupRegistry(configs) + pose_groups = [ + group for group in registry.groups_for_robot(robot_name) if group.has_pose_target + ] + if len(pose_groups) != 1: + raise ValueError( + f"Robot '{robot_name}' has {len(pose_groups)} pose-targetable planning groups" + ) + return pose_groups[0] diff --git a/dimos/manipulation/planning/kinematics/test_jacobian_ik_selection.py b/dimos/manipulation/planning/kinematics/test_jacobian_ik_selection.py new file mode 100644 index 0000000000..a850b12aab --- /dev/null +++ b/dimos/manipulation/planning/kinematics/test_jacobian_ik_selection.py @@ -0,0 +1,138 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Jacobian IK selected planning group result contracts.""" + +from __future__ import annotations + +from collections.abc import Mapping +from pathlib import Path +from typing import cast + +from dimos.manipulation.planning.groups.models import PlanningGroup +from dimos.manipulation.planning.kinematics.jacobian_ik import JacobianIK +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.manipulation.planning.spec.enums import IKStatus +from dimos.manipulation.planning.spec.models import IKResult +from dimos.manipulation.planning.spec.protocols import WorldSpec +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.JointState import JointState + + +def _pose() -> PoseStamped: + return PoseStamped(position=[0, 0, 0], orientation=[0, 0, 0, 1]) + + +def _joint_state(names: list[str], positions: list[float]) -> JointState: + return JointState({"name": names, "position": positions}) + + +def _group( + group_id: str, joint_names: tuple[str, ...], tip_link: str | None = "tool0" +) -> PlanningGroup: + return PlanningGroup( + id=group_id, + robot_name="arm", + group_name=group_id.split("/", maxsplit=1)[1], + joint_names=joint_names, + local_joint_names=tuple(name.split("/", maxsplit=1)[1] for name in joint_names), + base_link="base_link", + tip_link=tip_link, + ) + + +class _IKWorld: + def __init__(self, groups: Mapping[str, PlanningGroup]) -> None: + self._groups = groups + self._robot_configs = { + "robot_1": RobotModelConfig( + name="arm", + model_path=Path("robot.urdf"), + base_pose=_pose(), + joint_names=["joint1", "joint2", "gripper"], + end_effector_link="tool0", + ) + } + + def get_robot_ids(self) -> list[str]: + return list(self._robot_configs) + + def get_robot_config(self, robot_id: str) -> RobotModelConfig: + return self._robot_configs[robot_id] + + +class _SuccessfulIK(JacobianIK): + def solve( + self, + world: WorldSpec, + robot_id: str, + target_pose: PoseStamped, + seed: JointState | None = None, + position_tolerance: float = 0.001, + orientation_tolerance: float = 0.01, + check_collision: bool = True, + max_attempts: int = 10, + ) -> IKResult: + return IKResult( + status=IKStatus.SUCCESS, + joint_state=_joint_state( + ["arm/joint1", "arm/joint2", "arm/gripper", "arm/unrelated"], + [0.1, 0.2, 0.3, 0.4], + ), + ) + + +def test_solve_pose_targets_filters_result_to_single_group_joints() -> None: + world = _IKWorld( + { + "arm/arm": _group("arm/arm", ("arm/joint1", "arm/joint2")), + } + ) + + result = _SuccessfulIK().solve_pose_targets( + world=cast("WorldSpec", world), + pose_targets={world._groups["arm/arm"]: _pose()}, + seed=_joint_state(["arm/joint1", "arm/joint2", "arm/gripper"], [0.0, 0.0, 0.0]), + ) + + assert result.status == IKStatus.SUCCESS + assert result.joint_state is not None + assert result.joint_state.name == ["arm/joint1", "arm/joint2"] + assert result.joint_state.position == [0.1, 0.2] + + +def test_solve_pose_targets_rejects_auxiliary_groups() -> None: + world = _IKWorld({"arm/arm": _group("arm/arm", ("arm/joint1", "arm/joint2"))}) + + result = _SuccessfulIK().solve_pose_targets( + world=cast("WorldSpec", world), + pose_targets={world._groups["arm/arm"]: _pose()}, + auxiliary_groups=[_group("arm/gripper", ("arm/gripper",))], + seed=_joint_state(["arm/joint1", "arm/joint2"], [0.0, 0.0]), + ) + + assert result.status == IKStatus.NO_SOLUTION + assert "no auxiliary planning groups" in result.message + + +def test_solve_pose_targets_rejects_group_without_pose_target_frame() -> None: + world = _IKWorld({"arm/gripper": _group("arm/gripper", ("arm/gripper",), tip_link=None)}) + + result = JacobianIK().solve_pose_targets( + world=cast("WorldSpec", world), + pose_targets={world._groups["arm/gripper"]: _pose()}, + ) + + assert result.status == IKStatus.NO_SOLUTION + assert "no pose target frame" in result.message diff --git a/dimos/manipulation/planning/kinematics/test_pink_ik.py b/dimos/manipulation/planning/kinematics/test_pink_ik.py index 2c2fd863e2..df1f19738f 100644 --- a/dimos/manipulation/planning/kinematics/test_pink_ik.py +++ b/dimos/manipulation/planning/kinematics/test_pink_ik.py @@ -25,14 +25,18 @@ import pytest from dimos.manipulation.planning.factory import create_kinematics +from dimos.manipulation.planning.groups.models import PlanningGroup +from dimos.manipulation.planning.kinematics import pink_ik as pink_ik_module from dimos.manipulation.planning.kinematics.config import PinkKinematicsConfig from dimos.manipulation.planning.kinematics.pink_ik import ( PinkIK, PinkIKConfig, PinkIKDependencyError, _build_joint_mapping, - _PinkModules, + _lock_uncontrolled_model_joints, _PinkRobotContext, + _PinkRobotModelContext, + _seed_for_robot_config, _seed_positions_for_mapping, ) from dimos.manipulation.planning.spec.config import RobotModelConfig @@ -64,7 +68,7 @@ def __init__(self, translation: np.ndarray) -> None: class _FakeData: def __init__(self) -> None: self.q = np.zeros(3) - self.oMf = [_FakePlacement(np.zeros(3))] + self.oMf = [_FakePlacement(np.zeros(3)), _FakePlacement(np.zeros(3))] class _FakeModel: @@ -73,9 +77,9 @@ class _FakeModel: def __init__(self) -> None: self.names = ["universe", "joint_b", "joint_a", "joint_c"] self.joints = [SimpleNamespace(idx_q=-1, nq=0), _FakeJoint(0), _FakeJoint(1), _FakeJoint(2)] - self.frames = [_FakeFrame("tool")] + self.frames = [_FakeFrame("tool"), _FakeFrame("wrist_tool")] self._joint_ids = {"joint_b": 1, "joint_a": 2, "joint_c": 3} - self._frame_ids = {"tool": 0} + self._frame_ids = {"tool": 0, "wrist_tool": 1} def createData(self) -> _FakeData: return _FakeData() @@ -126,7 +130,7 @@ def set_target_from_configuration(self, configuration: _FakeConfiguration) -> No self.target = configuration.q.copy() -def _fake_modules(converge: bool = True) -> _PinkModules: +def _fake_modules(converge: bool = True) -> tuple[ModuleType, ModuleType]: pinocchio = ModuleType("pinocchio") pinocchio.SE3 = _FakeSE3 # type: ignore[attr-defined] pinocchio.neutral = lambda model: np.zeros(model.nq) # type: ignore[attr-defined] @@ -136,6 +140,7 @@ def forward_kinematics(model: _FakeModel, data: _FakeData, q: np.ndarray) -> Non def update_frame_placements(model: _FakeModel, data: _FakeData) -> None: data.oMf[0] = _FakePlacement(data.q.copy()) + data.oMf[1] = _FakePlacement(data.q.copy()) pinocchio.forwardKinematics = forward_kinematics # type: ignore[attr-defined] pinocchio.updateFramePlacements = update_frame_placements # type: ignore[attr-defined] @@ -158,7 +163,15 @@ def solve_ik( pink.solve_ik = solve_ik # type: ignore[attr-defined] - return _PinkModules(pink=pink, pinocchio=pinocchio) + return pink, pinocchio + + +def _install_fake_modules(converge: bool = True) -> None: + pink, pinocchio = _fake_modules(converge=converge) + pink_ik_module.pink = pink + pink_ik_module.pinocchio = pinocchio + pink_ik_module.qpsolvers = SimpleNamespace(available_solvers=["proxqp"]) + pink_ik_module._PINK_IMPORT_ERROR = None def _robot_config() -> RobotModelConfig: @@ -172,11 +185,20 @@ def _robot_config() -> RobotModelConfig: ) +def _pose_stamped(x: float, y: float, z: float, yaw: float = 0.0) -> PoseStamped: + half_yaw = yaw / 2.0 + return PoseStamped( + frame_id="world", + position=Vector3(x, y, z), + orientation=Quaternion(0.0, 0.0, float(np.sin(half_yaw)), float(np.cos(half_yaw))), + ) + + class _TestPinkIK(PinkIK): def __init__(self, converge: bool = True) -> None: self.config = PinkIKConfig(max_iterations=3) - self._modules = _fake_modules(converge=converge) - self._robot_contexts = {} + _install_fake_modules(converge=converge) + self._robot_model_contexts = {} def _pink_ik(converge: bool = True) -> PinkIK: @@ -195,12 +217,53 @@ def _context() -> _PinkRobotContext: ) +def _patch_robot_contexts( + monkeypatch: pytest.MonkeyPatch, + ik: PinkIK, + contexts_by_frame: dict[str, _PinkRobotContext], +) -> None: + def fake_get_robot_context( + world: object, + robot_id: str, + frame_name: str | None = None, + config: RobotModelConfig | None = None, + ) -> _PinkRobotContext: + del world, robot_id, config + target_frame = frame_name or "tool" + return contexts_by_frame[target_frame] + + monkeypatch.setattr(ik, "_get_robot_context", fake_get_robot_context) + + class _FakeWorld: is_finalized = True def __init__(self, collision_free: bool = True) -> None: self.config = _robot_config() self.collision_free = collision_free + self.groups = { + "arm/wrist": PlanningGroup( + id="arm/wrist", + robot_name="arm", + group_name="wrist", + joint_names=("arm/joint_a", "arm/joint_b"), + local_joint_names=("joint_a", "joint_b"), + base_link="base", + tip_link="wrist_tool", + ), + "arm/gripper": PlanningGroup( + id="arm/gripper", + robot_name="arm", + group_name="gripper", + joint_names=("arm/joint_c",), + local_joint_names=("joint_c",), + base_link="base", + tip_link=None, + ), + } + + def get_robot_ids(self) -> list[str]: + return ["robot"] def get_robot_config(self, robot_id: str) -> RobotModelConfig: return self.config @@ -221,17 +284,92 @@ def check_config_collision_free(self, robot_id: str, joint_state: JointState) -> return self.collision_free +class _CountingWorld(_FakeWorld): + def __init__(self, collision_free: bool = True) -> None: + super().__init__(collision_free=collision_free) + self.scratch_calls = 0 + self.current_state_calls = 0 + self.joint_limit_calls = 0 + + def scratch_context(self) -> nullcontext[None]: + self.scratch_calls += 1 + return nullcontext(None) + + def get_joint_state(self, ctx: object, robot_id: str) -> JointState: + self.current_state_calls += 1 + return super().get_joint_state(ctx, robot_id) + + def get_joint_limits(self, robot_id: str) -> tuple[np.ndarray, np.ndarray]: + self.joint_limit_calls += 1 + return super().get_joint_limits(robot_id) + + +class _FakeMultiRobotWorld: + is_finalized = True + + def __init__(self) -> None: + self.configs = { + "left_robot": RobotModelConfig( + name="left", + model_path=Path("/tmp/left.urdf"), + base_pose=PoseStamped( + position=Vector3(), orientation=Quaternion(0.0, 0.0, 0.0, 1.0) + ), + joint_names=["joint_a", "joint_b"], + end_effector_link="tool", + base_link="base", + ), + "right_robot": RobotModelConfig( + name="right", + model_path=Path("/tmp/right.urdf"), + base_pose=PoseStamped( + position=Vector3(), orientation=Quaternion(0.0, 0.0, 0.0, 1.0) + ), + joint_names=["joint_c"], + end_effector_link="tool", + base_link="base", + ), + } + + def get_robot_ids(self) -> list[str]: + return list(self.configs) + + def get_robot_config(self, robot_id: str) -> RobotModelConfig: + return self.configs[robot_id] + + def scratch_context(self) -> nullcontext[None]: + return nullcontext(None) + + def get_joint_state(self, ctx: object, robot_id: str) -> JointState: + config = self.get_robot_config(robot_id) + return JointState(name=list(config.joint_names), position=[0.0] * len(config.joint_names)) + + def get_joint_limits(self, robot_id: str) -> tuple[np.ndarray, np.ndarray]: + config = self.get_robot_config(robot_id) + count = len(config.joint_names) + return np.full(count, -1.0), np.full(count, 1.0) + + +class _CountingMultiRobotWorld(_FakeMultiRobotWorld): + def __init__(self) -> None: + super().__init__() + self.joint_limit_calls: list[str] = [] + + def get_joint_limits(self, robot_id: str) -> tuple[np.ndarray, np.ndarray]: + self.joint_limit_calls.append(robot_id) + return super().get_joint_limits(robot_id) + + def test_create_kinematics_pink_missing_dependency_is_actionable( monkeypatch: pytest.MonkeyPatch, ) -> None: - from dimos.manipulation.planning.kinematics import pink_ik - - def fake_import_module(name: str) -> ModuleType: - if name == "pink": - raise ImportError("missing pink") - return ModuleType(name) + def missing_dependencies(_solver: str) -> object: + raise PinkIKDependencyError( + "Pink IK backend requires Pink. Install manipulation dependencies with: " + "uv sync --extra manipulation. PyPI package: pin-pink; import name: pink." + ) - monkeypatch.setattr(pink_ik.importlib, "import_module", fake_import_module) + monkeypatch.setattr(pink_ik_module, "_check_optional_dependencies", missing_dependencies) with pytest.raises(PinkIKDependencyError) as exc_info: create_kinematics("pink") @@ -242,24 +380,20 @@ def fake_import_module(name: str) -> ModuleType: def test_create_kinematics_pink_unavailable_solver_mentions_manipulation_extra( monkeypatch: pytest.MonkeyPatch, ) -> None: - from dimos.manipulation.planning.kinematics import pink_ik - - def fake_import_module(name: str) -> ModuleType: - module = ModuleType(name) - if name == "qpsolvers": - module.available_solvers = [] # type: ignore[attr-defined] - return module + def unavailable_solver(_solver: str) -> object: + raise PinkIKDependencyError( + "Pink IK solver 'proxqp' is not available from qpsolvers. " + "Install manipulation dependencies with uv sync --extra manipulation." + ) - monkeypatch.setattr(pink_ik.importlib, "import_module", fake_import_module) + monkeypatch.setattr(pink_ik_module, "_check_optional_dependencies", unavailable_solver) with pytest.raises(PinkIKDependencyError, match="--extra manipulation"): create_kinematics("pink") def test_create_kinematics_pink_returns_backend(monkeypatch: pytest.MonkeyPatch) -> None: - from dimos.manipulation.planning.kinematics import pink_ik - - monkeypatch.setattr(pink_ik, "_load_optional_dependencies", lambda solver: _fake_modules()) + _install_fake_modules() assert isinstance(create_kinematics("pink"), PinkIK) @@ -267,9 +401,7 @@ def test_create_kinematics_pink_returns_backend(monkeypatch: pytest.MonkeyPatch) def test_create_kinematics_pink_config_passes_tuning( monkeypatch: pytest.MonkeyPatch, ) -> None: - from dimos.manipulation.planning.kinematics import pink_ik - - monkeypatch.setattr(pink_ik, "_load_optional_dependencies", lambda solver: _fake_modules()) + _install_fake_modules() ik = create_kinematics(config=PinkKinematicsConfig(max_iterations=7, dt=0.02, posture_cost=0.0)) @@ -280,9 +412,7 @@ def test_create_kinematics_pink_config_passes_tuning( def test_pink_ik_config_overrides_are_applied(monkeypatch: pytest.MonkeyPatch) -> None: - from dimos.manipulation.planning.kinematics import pink_ik - - monkeypatch.setattr(pink_ik, "_load_optional_dependencies", lambda solver: _fake_modules()) + _install_fake_modules() ik = PinkIK(PinkIKConfig(solver="proxqp", dt=0.1), max_iterations=7, posture_cost=0.0) @@ -302,6 +432,81 @@ def test_joint_order_mapping_uses_names_not_positions() -> None: assert _seed_positions_for_mapping(seed, mapping).tolist() == [10.0, 20.0, 30.0] +def test_seed_for_robot_config_uses_complete_global_seed_without_world() -> None: + seed = JointState(name=["arm/joint_a", "arm/joint_b", "arm/joint_c"], position=[1.0, 2.0, 3.0]) + + result = _seed_for_robot_config(_robot_config(), seed) + + assert result.name == ["joint_a", "joint_b", "joint_c"] + assert result.position == [1.0, 2.0, 3.0] + + +def test_solve_pose_targets_complete_seed_does_not_read_world_state( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ik = _pink_ik(converge=True) + context = _context() + context.frame_name = "wrist_tool" + context.frame_id = 1 + _patch_robot_contexts(monkeypatch, ik, {"wrist_tool": context}) + world = _CountingWorld(collision_free=True) + + def fake_solve_single(**_: object) -> IKResult: + return IKResult( + status=IKStatus.SUCCESS, + joint_state=JointState( + name=["joint_a", "joint_b", "joint_c"], position=[0.1, 0.2, 0.3] + ), + ) + + monkeypatch.setattr(ik, "_solve_single", fake_solve_single) + + result = ik.solve_pose_targets( + world=cast("Any", world), + pose_targets={world.groups["arm/wrist"]: _pose_stamped(0.1, 0.0, 0.0)}, + seed=JointState( + name=["arm/joint_a", "arm/joint_b", "arm/joint_c"], position=[1.0, 2.0, 3.0] + ), + max_attempts=1, + ) + + assert result.status == IKStatus.SUCCESS + assert world.scratch_calls == 0 + assert world.current_state_calls == 0 + + +def test_solve_pose_targets_incomplete_seed_reads_world_state_once( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ik = _pink_ik(converge=True) + context = _context() + context.frame_name = "wrist_tool" + context.frame_id = 1 + _patch_robot_contexts(monkeypatch, ik, {"wrist_tool": context}) + world = _CountingWorld(collision_free=True) + + def fake_solve_single(**_: object) -> IKResult: + return IKResult( + status=IKStatus.SUCCESS, + joint_state=JointState( + name=["joint_a", "joint_b", "joint_c"], position=[0.1, 0.2, 0.3] + ), + ) + + monkeypatch.setattr(ik, "_solve_single", fake_solve_single) + + result = ik.solve_pose_targets( + world=cast("Any", world), + pose_targets={world.groups["arm/wrist"]: _pose_stamped(0.1, 0.0, 0.0)}, + seed=JointState(name=["arm/joint_a"], position=[1.0]), + max_attempts=1, + ) + + assert result.status == IKStatus.SUCCESS + assert world.scratch_calls == 1 + assert world.current_state_calls == 1 + + def test_mapping_failure_for_missing_joint() -> None: config = _robot_config() config.joint_names = ["joint_a", "missing", "joint_c"] @@ -310,6 +515,31 @@ def test_mapping_failure_for_missing_joint() -> None: _build_joint_mapping(_FakeModel(), config) +def test_uncontrolled_urdf_joints_are_locked_out_of_pink_model() -> None: + pinocchio = ModuleType("pinocchio") + model = _FakeModel() + model.names.append("gripper_joint") + model.joints.append(_FakeJoint(3)) + reduced_model = _FakeModel() + seen_locked_joint_ids: list[list[int]] = [] + + def build_reduced_model( + input_model: _FakeModel, locked_joint_ids: list[int], reference: np.ndarray + ) -> _FakeModel: + assert input_model is model + np.testing.assert_allclose(reference, np.zeros(model.nq)) + seen_locked_joint_ids.append(list(locked_joint_ids)) + return reduced_model + + pinocchio.neutral = lambda input_model: np.zeros(input_model.nq) # type: ignore[attr-defined] + pinocchio.buildReducedModel = build_reduced_model # type: ignore[attr-defined] + + result = _lock_uncontrolled_model_joints(pinocchio, model, _robot_config()) + + assert result is reduced_model + assert seen_locked_joint_ids == [[4]] + + def test_solve_single_returns_successful_ik_result() -> None: ik = _pink_ik(converge=True) target = np.eye(4) @@ -350,10 +580,10 @@ def test_solve_single_reports_non_convergence() -> None: assert "did not converge" in result.message -def test_solve_rejects_collision_candidate() -> None: +def test_solve_does_not_filter_collision_candidates(monkeypatch: pytest.MonkeyPatch) -> None: ik = _pink_ik(converge=True) context = _context() - ik._robot_contexts = {"robot": context} + _patch_robot_contexts(monkeypatch, ik, {"tool": context}) result = ik.solve( world=cast("Any", _FakeWorld(collision_free=False)), @@ -362,18 +592,434 @@ def test_solve_rejects_collision_candidate() -> None: position=Vector3(0.1, 0.0, 0.0), orientation=Quaternion(0.0, 0.0, 0.0, 1.0), ), - check_collision=True, max_attempts=1, ) - assert result.status == IKStatus.COLLISION - assert result.joint_state is None + assert result.status == IKStatus.SUCCESS + assert result.joint_state is not None + + +def test_solve_pose_targets_returns_selected_resolved_joints_and_group_tip( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ik = _pink_ik(converge=True) + context = _context() + context.frame_name = "wrist_tool" + context.frame_id = 1 + _patch_robot_contexts(monkeypatch, ik, {"wrist_tool": context}) + seen_frame_names: list[str] = [] + + def fake_solve_single(**kwargs: object) -> IKResult: + robot_context = cast("_PinkRobotContext", kwargs["robot_context"]) + seen_frame_names.append(robot_context.frame_name) + return IKResult( + status=IKStatus.SUCCESS, + joint_state=JointState( + {"name": ["joint_a", "joint_b", "joint_c"], "position": [0.1, 0.2, 0.3]} + ), + position_error=0.0, + orientation_error=0.0, + iterations=1, + ) + + monkeypatch.setattr(ik, "_solve_single", fake_solve_single) + + world = _FakeWorld(collision_free=True) + result = ik.solve_pose_targets( + world=cast("Any", world), + pose_targets={ + world.groups["arm/wrist"]: PoseStamped( + position=Vector3(0.1, 0.0, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + }, + auxiliary_groups=[world.groups["arm/gripper"]], + seed=JointState( + {"name": ["arm/joint_a", "arm/joint_b", "arm/joint_c"], "position": [0.0, 0.0, 0.0]} + ), + max_attempts=1, + ) + + assert seen_frame_names == ["wrist_tool"] + assert result.status == IKStatus.SUCCESS + assert result.joint_state is not None + assert result.joint_state.name == ["arm/joint_a", "arm/joint_b", "arm/joint_c"] + assert result.joint_state.position == [0.1, 0.2, 0.3] + + +def test_solve_pose_targets_same_robot_uses_one_multi_frame_solve( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ik = _pink_ik(converge=True) + wrist_context = _context() + wrist_context.frame_name = "wrist_tool" + wrist_context.frame_id = 1 + tool_context = _context() + _patch_robot_contexts(monkeypatch, ik, {"wrist_tool": wrist_context, "tool": tool_context}) + world = _FakeWorld(collision_free=True) + tool_group = PlanningGroup( + id="arm/tool", + robot_name="arm", + group_name="tool", + joint_names=("arm/joint_c",), + local_joint_names=("joint_c",), + base_link="base", + tip_link="tool", + ) + seen_frames: list[list[str]] = [] + + def fake_solve_multi_frame(**kwargs: object) -> IKResult: + contexts = cast("list[_PinkRobotContext]", kwargs["robot_contexts"]) + seen_frames.append([context.frame_name for context in contexts]) + return IKResult( + status=IKStatus.SUCCESS, + joint_state=JointState( + {"name": ["joint_a", "joint_b", "joint_c"], "position": [0.1, 0.2, 0.3]} + ), + position_error=0.0, + orientation_error=0.0, + iterations=2, + ) + + monkeypatch.setattr(ik, "_solve_multi_frame", fake_solve_multi_frame) + + result = ik.solve_pose_targets( + world=cast("Any", world), + pose_targets={ + world.groups["arm/wrist"]: PoseStamped( + position=Vector3(0.1, 0.0, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + tool_group: PoseStamped( + position=Vector3(0.2, 0.0, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + }, + seed=JointState( + {"name": ["arm/joint_a", "arm/joint_b", "arm/joint_c"], "position": [0.0, 0.0, 0.0]} + ), + max_attempts=1, + ) + + assert seen_frames == [["wrist_tool", "tool"]] + assert result.status == IKStatus.SUCCESS + assert result.joint_state is not None + assert result.joint_state.name == ["arm/joint_a", "arm/joint_b", "arm/joint_c"] + assert result.joint_state.position == [0.1, 0.2, 0.3] + + +def test_solve_pose_targets_same_robot_builds_one_model_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ik = _pink_ik(converge=True) + world = _FakeWorld(collision_free=True) + tool_group = PlanningGroup( + id="arm/tool", + robot_name="arm", + group_name="tool", + joint_names=("arm/joint_c",), + local_joint_names=("joint_c",), + base_link="base", + tip_link="tool", + ) + model = _FakeModel() + build_calls = 0 + + def fake_build_model_context(config: RobotModelConfig) -> _PinkRobotModelContext: + nonlocal build_calls + build_calls += 1 + mapping = _build_joint_mapping(model, config) + return _PinkRobotModelContext( + model=model, + mapping=mapping, + neutral_q=np.zeros(model.nq), + frame_ids={}, + ) + + def fake_solve_multi_frame(**_: object) -> IKResult: + return IKResult( + status=IKStatus.SUCCESS, + joint_state=JointState( + name=["joint_a", "joint_b", "joint_c"], position=[0.1, 0.2, 0.3] + ), + ) + + monkeypatch.setattr(ik, "_build_robot_model_context", fake_build_model_context) + monkeypatch.setattr(ik, "_solve_multi_frame", fake_solve_multi_frame) + + result = ik.solve_pose_targets( + world=cast("Any", world), + pose_targets={ + world.groups["arm/wrist"]: _pose_stamped(0.1, 0.0, 0.0), + tool_group: _pose_stamped(0.2, 0.0, 0.0), + }, + seed=JointState( + name=["arm/joint_a", "arm/joint_b", "arm/joint_c"], position=[0.0, 0.0, 0.0] + ), + max_attempts=1, + ) + + assert result.status == IKStatus.SUCCESS + assert build_calls == 1 + + +def test_solve_multi_frame_updates_fk_once_per_iteration( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ik = _pink_ik(converge=True) + ik.config = PinkIKConfig(max_iterations=1) + pinocchio = cast("Any", pink_ik_module.pinocchio) + original_forward_kinematics = pinocchio.forwardKinematics + forward_calls = 0 + + def counting_forward_kinematics(model: _FakeModel, data: _FakeData, q: np.ndarray) -> None: + nonlocal forward_calls + forward_calls += 1 + original_forward_kinematics(model, data, q) + + monkeypatch.setattr(pinocchio, "forwardKinematics", counting_forward_kinematics) + wrist_context = _context() + wrist_context.frame_name = "wrist_tool" + wrist_context.frame_id = 1 + tool_context = _context() + target = np.eye(4) + target[:3, 3] = [0.1, 0.0, 0.0] + + result = ik._solve_multi_frame( + robot_contexts=[wrist_context, tool_context], + target_models=[target, target], + seed_q=np.zeros(3), + lower_limits=np.array([-1.0, -1.0, -1.0]), + upper_limits=np.array([1.0, 1.0, 1.0]), + position_tolerance=0.001, + orientation_tolerance=0.01, + ) + + assert result.status == IKStatus.NO_SOLUTION + assert forward_calls == 1 + + +def test_target_in_model_frame_converts_world_pose_through_robot_base() -> None: + ik = _pink_ik(converge=True) + config = _robot_config() + config.base_pose = _pose_stamped(1.0, 2.0, 0.0, yaw=np.pi / 2.0) + target_world = _pose_stamped(0.8, 2.1, 0.3, yaw=np.pi / 2.0) + + target_model = ik._target_in_model_frame(config, target_world) + + np.testing.assert_allclose(target_model[:3, 3], [0.1, 0.2, 0.3], atol=1e-12) + np.testing.assert_allclose(target_model[:3, :3], np.eye(3), atol=1e-12) + + +def test_solve_pose_targets_passes_world_target_to_solver_in_model_frame( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ik = _pink_ik(converge=True) + context = _context() + context.frame_name = "wrist_tool" + context.frame_id = 1 + _patch_robot_contexts(monkeypatch, ik, {"wrist_tool": context}) + world = _FakeWorld(collision_free=True) + world.config.base_pose = _pose_stamped(1.0, 2.0, 0.0, yaw=np.pi / 2.0) + seen_target_models: list[np.ndarray] = [] + + def fake_solve_single(**kwargs: object) -> IKResult: + target_model = cast("np.ndarray", kwargs["target_model"]) + seen_target_models.append(target_model) + return IKResult( + status=IKStatus.SUCCESS, + joint_state=JointState( + name=["joint_a", "joint_b", "joint_c"], position=[0.1, 0.2, 0.3] + ), + position_error=0.0, + orientation_error=0.0, + iterations=1, + ) + + monkeypatch.setattr(ik, "_solve_single", fake_solve_single) + + result = ik.solve_pose_targets( + world=cast("Any", world), + pose_targets={world.groups["arm/wrist"]: _pose_stamped(0.8, 2.1, 0.3, yaw=np.pi / 2.0)}, + seed=JointState( + {"name": ["arm/joint_a", "arm/joint_b", "arm/joint_c"], "position": [0.0, 0.0, 0.0]} + ), + max_attempts=1, + ) + + assert result.status == IKStatus.SUCCESS + assert len(seen_target_models) == 1 + np.testing.assert_allclose(seen_target_models[0][:3, 3], [0.1, 0.2, 0.3], atol=1e-12) + np.testing.assert_allclose(seen_target_models[0][:3, :3], np.eye(3), atol=1e-12) + + +def test_solve_pose_targets_cross_robot_combines_global_joint_names( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ik = _pink_ik(converge=True) + world = _FakeMultiRobotWorld() + left_group = PlanningGroup( + id="left/arm", + robot_name="left", + group_name="arm", + joint_names=("left/joint_a",), + local_joint_names=("joint_a",), + base_link="base", + tip_link="tool", + ) + right_group = PlanningGroup( + id="right/arm", + robot_name="right", + group_name="arm", + joint_names=("right/joint_c",), + local_joint_names=("joint_c",), + base_link="base", + tip_link="tool", + ) + seen_robot_ids: list[str] = [] + + def fake_solve_pose_targets_for_robot(**kwargs: object) -> IKResult: + robot_id = str(kwargs["robot_id"]) + seen_robot_ids.append(robot_id) + if robot_id == "left_robot": + return IKResult( + status=IKStatus.SUCCESS, + joint_state=JointState(name=["joint_a", "joint_b"], position=[1.0, 9.0]), + ) + return IKResult( + status=IKStatus.SUCCESS, + joint_state=JointState(name=["joint_c"], position=[2.0]), + ) + + monkeypatch.setattr(ik, "_solve_pose_targets_for_robot", fake_solve_pose_targets_for_robot) + + result = ik.solve_pose_targets( + world=cast("Any", world), + pose_targets={ + left_group: PoseStamped( + position=Vector3(0.1, 0.0, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + right_group: PoseStamped( + position=Vector3(0.2, 0.0, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + }, + seed=JointState( + name=["left/joint_a", "left/joint_b", "right/joint_c"], + position=[0.0, 0.0, 0.0], + ), + max_attempts=1, + ) + + assert seen_robot_ids == ["left_robot", "right_robot"] + assert result.status == IKStatus.SUCCESS + assert result.joint_state is not None + assert result.joint_state.name == ["left/joint_a", "right/joint_c"] + assert result.joint_state.position == [1.0, 2.0] + + +def test_solve_pose_targets_returns_first_robot_failure_before_touching_later_limits( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ik = _pink_ik(converge=True) + world = _CountingMultiRobotWorld() + left_group = PlanningGroup( + id="left/arm", + robot_name="left", + group_name="arm", + joint_names=("left/joint_a",), + local_joint_names=("joint_a",), + base_link="base", + tip_link="tool", + ) + right_group = PlanningGroup( + id="right/arm", + robot_name="right", + group_name="arm", + joint_names=("right/joint_c",), + local_joint_names=("joint_c",), + base_link="base", + tip_link="tool", + ) + + def fail_first_robot(**_: object) -> IKResult: + return IKResult(status=IKStatus.NO_SOLUTION, joint_state=None, message="left failed") + + monkeypatch.setattr(ik, "_solve_pose_targets_for_robot", fail_first_robot) + + result = ik.solve_pose_targets( + world=cast("Any", world), + pose_targets={ + left_group: _pose_stamped(0.1, 0.0, 0.0), + right_group: _pose_stamped(0.2, 0.0, 0.0), + }, + seed=JointState( + name=["left/joint_a", "left/joint_b", "right/joint_c"], + position=[0.0, 0.0, 0.0], + ), + max_attempts=1, + ) + + assert result.status == IKStatus.NO_SOLUTION + assert result.message == "left failed" + assert world.joint_limit_calls == ["left_robot"] + + +def test_solve_pose_targets_auxiliary_robot_does_not_read_joint_limits( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ik = _pink_ik(converge=True) + world = _CountingMultiRobotWorld() + left_group = PlanningGroup( + id="left/arm", + robot_name="left", + group_name="arm", + joint_names=("left/joint_a",), + local_joint_names=("joint_a",), + base_link="base", + tip_link="tool", + ) + right_auxiliary = PlanningGroup( + id="right/gripper", + robot_name="right", + group_name="gripper", + joint_names=("right/joint_c",), + local_joint_names=("joint_c",), + base_link="base", + tip_link=None, + ) + + def solve_left_robot(**_: object) -> IKResult: + return IKResult( + status=IKStatus.SUCCESS, + joint_state=JointState(name=["joint_a", "joint_b"], position=[1.0, 9.0]), + ) + + monkeypatch.setattr(ik, "_solve_pose_targets_for_robot", solve_left_robot) + + result = ik.solve_pose_targets( + world=cast("Any", world), + pose_targets={left_group: _pose_stamped(0.1, 0.0, 0.0)}, + auxiliary_groups=[right_auxiliary], + seed=JointState( + name=["left/joint_a", "left/joint_b", "right/joint_c"], + position=[0.0, 0.0, 2.0], + ), + max_attempts=1, + ) + + assert result.status == IKStatus.SUCCESS + assert result.joint_state is not None + assert result.joint_state.name == ["left/joint_a", "right/joint_c"] + assert result.joint_state.position == [1.0, 2.0] + assert world.joint_limit_calls == ["left_robot"] def test_solve_retries_after_joint_limit_failure(monkeypatch: pytest.MonkeyPatch) -> None: ik = _pink_ik(converge=True) context = _context() - ik._robot_contexts = {"robot": context} + _patch_robot_contexts(monkeypatch, ik, {"tool": context}) calls = 0 def fake_solve_single(**_: object) -> IKResult: @@ -405,7 +1051,6 @@ def fake_solve_single(**_: object) -> IKResult: position=Vector3(0.1, 0.0, 0.0), orientation=Quaternion(0.0, 0.0, 0.0, 1.0), ), - check_collision=True, max_attempts=2, ) diff --git a/dimos/manipulation/planning/monitor/robot_state_monitor.py b/dimos/manipulation/planning/monitor/robot_state_monitor.py index 49c6b56366..39492a80ac 100644 --- a/dimos/manipulation/planning/monitor/robot_state_monitor.py +++ b/dimos/manipulation/planning/monitor/robot_state_monitor.py @@ -69,7 +69,6 @@ def __init__( lock: threading.RLock, robot_id: str, joint_names: list[str], - joint_name_mapping: dict[str, str] | None = None, timeout: float = 1.0, ) -> None: """Create a world state monitor. @@ -78,10 +77,7 @@ def __init__( world: WorldSpec instance to sync state to lock: Shared lock for thread-safe access robot_id: ID of the robot to monitor - joint_names: Ordered list of joint names for this robot (URDF names) - joint_name_mapping: Maps coordinator joint names to URDF joint names. - Example: {"left/joint1": "joint1"} means messages with "left/joint1" - will be mapped to URDF "joint1". If None, names must match exactly. + joint_names: Ordered list of local model joint names for this robot timeout: Timeout for waiting for initial state (seconds) """ self._world = world @@ -90,11 +86,6 @@ def __init__( self._joint_names = joint_names self._timeout = timeout - # Joint name mapping: coordinator name -> URDF name - self._joint_name_mapping = joint_name_mapping or {} - # Build reverse mapping: URDF name -> coordinator name - self._reverse_mapping = {v: k for k, v in self._joint_name_mapping.items()} - # Latest state self._latest_positions: NDArray[np.float64] | None = None self._latest_velocities: NDArray[np.float64] | None = None @@ -190,30 +181,19 @@ def on_joint_state(self, msg: JointState) -> None: def _extract_positions(self, msg: JointState) -> NDArray[np.float64] | None: """Extract positions for our joints from JointState message. - Handles joint name translation from coordinator namespace to URDF namespace. - If joint_name_mapping is set, message names are looked up via the reverse mapping. - Args: - msg: JointState message (may use coordinator joint names) + msg: Robot-scoped JointState message with local model joint names Returns: Array of joint positions or None if any joint is missing """ - # Build name->index map from message (coordinator names) name_to_idx = {name: i for i, name in enumerate(msg.name)} positions = [] - for urdf_joint_name in self._joint_names: - # Try direct match first (when no mapping or names already match) - if urdf_joint_name in name_to_idx: - idx = name_to_idx[urdf_joint_name] - else: - # Try reverse mapping: URDF name -> coordinator name -> msg index - orch_name = self._reverse_mapping.get(urdf_joint_name) - if orch_name is None or orch_name not in name_to_idx: - return None # Missing joint - idx = name_to_idx[orch_name] - + for local_joint_name in self._joint_names: + idx = name_to_idx.get(local_joint_name) + if idx is None: + return None if idx >= len(msg.position): return None # Position not available positions.append(msg.position[idx]) @@ -223,7 +203,7 @@ def _extract_positions(self, msg: JointState) -> NDArray[np.float64] | None: def _extract_velocities(self, msg: JointState) -> NDArray[np.float64] | None: """Extract velocities for our joints. - Uses same name translation as _extract_positions. + Uses the same local-name lookup as _extract_positions. """ if not msg.velocity or len(msg.velocity) == 0: return None @@ -231,17 +211,10 @@ def _extract_velocities(self, msg: JointState) -> NDArray[np.float64] | None: name_to_idx = {name: i for i, name in enumerate(msg.name)} velocities = [] - for urdf_joint_name in self._joint_names: - # Try direct match first - if urdf_joint_name in name_to_idx: - idx = name_to_idx[urdf_joint_name] - else: - # Try reverse mapping - orch_name = self._reverse_mapping.get(urdf_joint_name) - if orch_name is None or orch_name not in name_to_idx: - return None - idx = name_to_idx[orch_name] - + for local_joint_name in self._joint_names: + idx = name_to_idx.get(local_joint_name) + if idx is None: + return None if idx >= len(msg.velocity): return None velocities.append(msg.velocity[idx]) diff --git a/dimos/manipulation/planning/monitor/test_world_monitor.py b/dimos/manipulation/planning/monitor/test_world_monitor.py index d46eb73a6c..4a447baacd 100644 --- a/dimos/manipulation/planning/monitor/test_world_monitor.py +++ b/dimos/manipulation/planning/monitor/test_world_monitor.py @@ -14,31 +14,71 @@ from __future__ import annotations +from collections.abc import Sequence from pathlib import Path from typing import Any from dimos.manipulation.planning import factory as planning_factory from dimos.manipulation.planning.monitor import world_monitor as world_monitor_module from dimos.manipulation.planning.spec.config import RobotModelConfig -from dimos.manipulation.planning.spec.models import PlanningSceneInfo +from dimos.manipulation.planning.spec.models import ( + GeneratedPlan, + PlanningGroupID, + PlanningSceneInfo, +) from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState + + +class _VectorLike(list[float]): + def tolist(self) -> list[float]: + return list(self) + + +class _FakeStateMonitor: + def __init__(self, positions: Sequence[float], stale: bool = False) -> None: + self._positions = _VectorLike(float(position) for position in positions) + self._stale = stale + + def get_current_positions(self) -> _VectorLike: + return self._positions + + def get_current_velocities(self) -> None: + return None + + def is_state_stale(self, max_age: float) -> bool: + del max_age + return self._stale + + +class _ScratchContext: + def __enter__(self) -> str: + return "scratch" + + def __exit__(self, exc_type: object, exc: object, traceback: object) -> bool: + return False class FakeWorld: def __init__(self) -> None: - self.calls: list[tuple[str, Any]] = [] + self.calls: list[tuple[Any, ...]] = [] + self.configs: dict[str, RobotModelConfig] = {} + self.collision_free_by_robot: dict[str, bool] = {} def add_robot(self, config): + robot_id = f"robot-{len(self.configs) + 1}" + self.configs[robot_id] = config + self.collision_free_by_robot[robot_id] = True self.calls.append(("add_robot", config)) - return "robot-1" + return robot_id def get_robot_ids(self): - return [] + return list(self.configs) def get_robot_config(self, robot_id): - return None + return self.configs[robot_id] def get_joint_limits(self, robot_id): return ([], []) @@ -66,22 +106,27 @@ def is_finalized(self): return True def get_live_context(self): + self.calls.append(("get_live_context", None)) return None def scratch_context(self): - return self + self.calls.append(("scratch_context", None)) + return _ScratchContext() def sync_from_joint_state(self, robot_id, joint_state) -> None: return None def set_joint_state(self, ctx, robot_id, joint_state) -> None: + self.calls.append(("set_joint_state", ctx, robot_id, joint_state)) return None def get_joint_state(self, ctx, robot_id): + self.calls.append(("get_joint_state", ctx, robot_id)) return None def is_collision_free(self, ctx, robot_id): - return True + self.calls.append(("is_collision_free", ctx, robot_id)) + return self.collision_free_by_robot[robot_id] def get_min_distance(self, ctx, robot_id): return 0.0 @@ -110,13 +155,13 @@ def initialize_scene(self, scene: PlanningSceneInfo) -> None: def publish_visualization(self, ctx=None) -> None: return None - def show_preview(self, robot_id) -> None: + def show_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: return None - def hide_preview(self, robot_id) -> None: + def hide_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: return None - def animate_path(self, robot_id, path, duration: float = 3.0) -> None: + def animate_plan(self, plan: GeneratedPlan, duration: float = 3.0) -> None: return None def close(self) -> None: @@ -136,13 +181,13 @@ def initialize_scene(self, scene: PlanningSceneInfo) -> None: def publish_visualization(self, ctx=None) -> None: return None - def show_preview(self, robot_id) -> None: - self.calls.append(("show_preview", robot_id)) + def show_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: + self.calls.append(("show_preview", tuple(group_ids))) - def hide_preview(self, robot_id) -> None: - self.calls.append(("hide_preview", robot_id)) + def hide_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: + self.calls.append(("hide_preview", tuple(group_ids))) - def animate_path(self, robot_id, path, duration: float = 3.0) -> None: + def animate_plan(self, plan: GeneratedPlan, duration: float = 3.0) -> None: return None def close(self) -> None: @@ -160,6 +205,17 @@ def _robot_config() -> RobotModelConfig: ) +def _robot_config_named(name: str, joint_names: list[str]) -> RobotModelConfig: + return RobotModelConfig( + name=name, + model_path=Path(f"/tmp/{name}.urdf"), + base_pose=PoseStamped(position=Vector3(), orientation=Quaternion([0, 0, 0, 1])), + joint_names=joint_names, + end_effector_link="ee", + base_link="base", + ) + + def test_world_monitor_add_robot_records_scene_without_visualization_probe() -> None: fake_world = FakeWorld() fake_viz = FakeViz() @@ -204,3 +260,70 @@ def test_create_planning_specs_wraps_existing_world(monkeypatch) -> None: assert planning_specs.world_monitor.visualization is None assert planning_specs.kinematics is fake_kinematics assert planning_specs.planner is fake_planner + + +def test_current_global_joint_state_uses_fresh_monitored_state_only() -> None: + fake_world = FakeWorld() + monitor = world_monitor_module.WorldMonitor(world=fake_world) # type: ignore[arg-type] + robot_id = monitor.add_robot(_robot_config_named("arm", ["j1", "j2"])) + monitor._state_monitors[robot_id] = _FakeStateMonitor([0.1, 0.2]) # pyright: ignore[reportPrivateUsage] + + current = monitor.current_global_joint_state(max_age=0.5) + + assert current is not None + assert current.name == ["arm/j1", "arm/j2"] + assert current.position == [0.1, 0.2] + + monitor._state_monitors[robot_id] = _FakeStateMonitor([0.1, 0.2], stale=True) # pyright: ignore[reportPrivateUsage] + fake_world.calls.clear() + + assert monitor.current_global_joint_state(max_age=0.5) is None + assert not any(call[0] in {"get_live_context", "get_joint_state"} for call in fake_world.calls) + + +def test_check_collision_fills_unmentioned_joints_in_one_world_context() -> None: + fake_world = FakeWorld() + monitor = world_monitor_module.WorldMonitor(world=fake_world) # type: ignore[arg-type] + left_id = monitor.add_robot(_robot_config_named("left", ["j1", "j2"])) + right_id = monitor.add_robot(_robot_config_named("right", ["j3"])) + monitor._state_monitors[left_id] = _FakeStateMonitor([0.0, 9.0]) # pyright: ignore[reportPrivateUsage] + monitor._state_monitors[right_id] = _FakeStateMonitor([8.0]) # pyright: ignore[reportPrivateUsage] + + result = monitor.check_collision(JointState(name=["left/j1"], position=[1.0])) + + assert result.status == "VALID" + set_joint_calls = [call for call in fake_world.calls if call[0] == "set_joint_state"] + assert len(set_joint_calls) == 2 + assert {call[1] for call in set_joint_calls} == {"scratch"} + left_state = set_joint_calls[0][3] + right_state = set_joint_calls[1][3] + assert left_state.name == ["j1", "j2"] + assert left_state.position == [1.0, 9.0] + assert right_state.name == ["j3"] + assert right_state.position == [8.0] + + +def test_check_collision_reports_expected_statuses() -> None: + fake_world = FakeWorld() + monitor = world_monitor_module.WorldMonitor(world=fake_world) # type: ignore[arg-type] + robot_id = monitor.add_robot(_robot_config_named("arm", ["j1"])) + monitor._state_monitors[robot_id] = _FakeStateMonitor([0.0]) # pyright: ignore[reportPrivateUsage] + + duplicate = monitor.check_collision(JointState(name=["arm/j1", "arm/j1"], position=[1.0, 2.0])) + assert duplicate.status == "INVALID" + + local_name = monitor.check_collision(JointState(name=["j1"], position=[1.0])) + assert local_name.status == "INVALID" + + unknown = monitor.check_collision(JointState(name=["arm/missing"], position=[1.0])) + assert unknown.status == "INVALID" + + monitor._state_monitors[robot_id] = _FakeStateMonitor([0.0], stale=True) # pyright: ignore[reportPrivateUsage] + stale = monitor.check_collision(JointState(name=["arm/j1"], position=[1.0])) + assert stale.status == "STALE_STATE" + + monitor._state_monitors[robot_id] = _FakeStateMonitor([0.0]) # pyright: ignore[reportPrivateUsage] + fake_world.collision_free_by_robot[robot_id] = False + collision = monitor.check_collision(JointState(name=["arm/j1"], position=[1.0])) + assert collision.status == "COLLISION" + assert collision.collision_free is False diff --git a/dimos/manipulation/planning/monitor/world_monitor.py b/dimos/manipulation/planning/monitor/world_monitor.py index 5e12568874..111d33ea74 100644 --- a/dimos/manipulation/planning/monitor/world_monitor.py +++ b/dimos/manipulation/planning/monitor/world_monitor.py @@ -21,16 +21,22 @@ from typing import TYPE_CHECKING, Any from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT +from dimos.manipulation.planning.groups.identifiers import ( + is_global_joint_name, + make_global_joint_names, +) +from dimos.manipulation.planning.groups.registry import PlanningGroupRegistry from dimos.manipulation.planning.monitor.robot_state_monitor import RobotStateMonitor from dimos.manipulation.planning.monitor.world_obstacle_monitor import WorldObstacleMonitor -from dimos.manipulation.planning.spec.models import PlanningSceneInfo +from dimos.manipulation.planning.spec.models import CollisionCheckResult, PlanningSceneInfo from dimos.manipulation.planning.spec.protocols import VisualizationSpec, WorldSpec from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Sequence import numpy as np from numpy.typing import NDArray @@ -38,8 +44,10 @@ from dimos.manipulation.planning.spec.config import RobotModelConfig from dimos.manipulation.planning.spec.models import ( CollisionObjectMessage, + GeneratedPlan, JointPath, Obstacle, + PlanningGroupID, WorldRobotID, ) from dimos.msgs.vision_msgs.Detection3D import Detection3D @@ -60,6 +68,8 @@ def __init__( self._visualization = visualization self._lock = threading.RLock() self._robot_joints: dict[WorldRobotID, list[str]] = {} + self._robot_ids_by_name: dict[str, WorldRobotID] = {} + self._planning_groups = PlanningGroupRegistry() self._robot_configs: dict[WorldRobotID, RobotModelConfig] = {} self._state_monitors: dict[WorldRobotID, RobotStateMonitor] = {} self._obstacle_monitor: WorldObstacleMonitor | None = None @@ -74,6 +84,10 @@ def add_robot(self, config: RobotModelConfig) -> WorldRobotID: with self._lock: robot_id = self._world.add_robot(config) self._robot_joints[robot_id] = config.joint_names + if config.name in self._robot_ids_by_name: + raise ValueError(f"Robot name '{config.name}' is already registered") + self._robot_ids_by_name[config.name] = robot_id + self._planning_groups.add_robot(config) self._robot_configs[robot_id] = config logger.info(f"Added robot '{config.name}' as '{robot_id}'") return robot_id @@ -90,6 +104,11 @@ def sync_visualization_scene(self) -> None: return visualization.initialize_scene(self.planning_scene_info()) + @property + def planning_groups(self) -> PlanningGroupRegistry: + """Backend-independent planning-group registry for added robots.""" + return self._planning_groups + def get_robot_ids(self) -> list[WorldRobotID]: """Get all robot IDs.""" with self._lock: @@ -130,7 +149,6 @@ def start_state_monitor( self, robot_id: WorldRobotID, joint_names: list[str] | None = None, - joint_name_mapping: dict[str, str] | None = None, ) -> None: """Start monitoring joint states. Uses config defaults if args are None.""" with self._lock: @@ -148,16 +166,11 @@ def start_state_monitor( else: joint_names = config.joint_names - # Get joint name mapping from config if not provided - if joint_name_mapping is None and config.joint_name_mapping: - joint_name_mapping = config.joint_name_mapping - monitor = RobotStateMonitor( world=self._world, lock=self._lock, robot_id=robot_id, joint_names=joint_names, - joint_name_mapping=joint_name_mapping, ) monitor.start() self._state_monitors[robot_id] = monitor @@ -294,6 +307,34 @@ def get_current_joint_state(self, robot_id: WorldRobotID) -> JointState | None: ctx = self._world.get_live_context() return self._world.get_joint_state(ctx, robot_id) + def current_global_joint_state(self, max_age: float = 1.0) -> JointState | None: + """Return fresh monitored state for every robot using global joint names. + + This method is intentionally stricter than ``get_current_joint_state``: + it never falls back to backend live/default state. Callers use it when + stale or absent telemetry should stop a planning-world operation. + """ + names: list[str] = [] + positions: list[float] = [] + for robot_id, config in self._robot_configs.items(): + monitor = self._state_monitors.get(robot_id) + if monitor is None or monitor.is_state_stale(max_age): + return None + current = self.get_current_joint_state(robot_id) + if current is None or len(current.name) != len(current.position): + return None + positions_by_local_name = dict(zip(current.name, current.position, strict=True)) + for local_name, global_name in zip( + config.joint_names, + make_global_joint_names(config.name, config.joint_names), + strict=True, + ): + if local_name not in positions_by_local_name: + return None + names.append(global_name) + positions.append(float(positions_by_local_name[local_name])) + return JointState(name=names, position=positions) + def get_current_velocities(self, robot_id: WorldRobotID) -> JointState | None: """Get current joint velocities as JointState. Returns None if not available.""" if robot_id in self._state_monitors: @@ -333,6 +374,94 @@ def is_state_valid(self, robot_id: WorldRobotID, joint_state: JointState) -> boo """Check if configuration is collision-free.""" return self._world.check_config_collision_free(robot_id, joint_state) + def check_collision( + self, + target_joints: JointState, + max_age: float = 1.0, + ) -> CollisionCheckResult: + """Check a partial global target against the full planning-world state.""" + if not self._world.is_finalized: + return CollisionCheckResult( + status="UNAVAILABLE", + collision_free=None, + message="Planning world is not finalized", + ) + if not target_joints.name: + return CollisionCheckResult( + status="INVALID", + collision_free=None, + message="Collision target must include global joint names", + ) + if len(target_joints.name) != len(target_joints.position): + return CollisionCheckResult( + status="INVALID", + collision_free=None, + message="Collision target name and position lengths must match", + ) + if len(set(target_joints.name)) != len(target_joints.name): + return CollisionCheckResult( + status="INVALID", + collision_free=None, + message="Collision target contains duplicate joint names", + ) + invalid_names = [name for name in target_joints.name if not is_global_joint_name(name)] + if invalid_names: + return CollisionCheckResult( + status="INVALID", + collision_free=None, + message=f"Expected global joint names; got invalid names: {invalid_names}", + ) + + current = self.current_global_joint_state(max_age=max_age) + if current is None: + return CollisionCheckResult( + status="STALE_STATE", + collision_free=None, + message="Fresh monitored planning-world joint state is unavailable", + ) + + positions_by_global_name = dict(zip(current.name, current.position, strict=True)) + unknown = [name for name in target_joints.name if name not in positions_by_global_name] + if unknown: + return CollisionCheckResult( + status="INVALID", + collision_free=None, + message=f"Unknown global joint names: {unknown}", + ) + for name, position in zip(target_joints.name, target_joints.position, strict=True): + positions_by_global_name[name] = float(position) + + try: + with self._world.scratch_context() as ctx: + for robot_id, config in self._robot_configs.items(): + global_names = make_global_joint_names(config.name, config.joint_names) + local_state = JointState( + name=list(config.joint_names), + position=[positions_by_global_name[name] for name in global_names], + ) + self._world.set_joint_state(ctx, robot_id, local_state) + collision_free = all( + self._world.is_collision_free(ctx, robot_id) for robot_id in self._robot_configs + ) + except Exception as exc: + return CollisionCheckResult( + status="UNAVAILABLE", + collision_free=None, + message=f"Collision check failed: {exc}", + ) + + if collision_free: + return CollisionCheckResult( + status="VALID", + collision_free=True, + message="Target is collision-free", + ) + return CollisionCheckResult( + status="COLLISION", + collision_free=False, + message="Target is in collision", + ) + def is_path_valid( self, robot_id: WorldRobotID, path: JointPath, step_size: float = 0.05 ) -> bool: @@ -366,16 +495,24 @@ def get_min_distance(self, robot_id: WorldRobotID) -> float: def get_ee_pose( self, robot_id: WorldRobotID, joint_state: JointState | None = None ) -> PoseStamped: - """Get end-effector pose. Uses current state if joint_state is None.""" + """Get end-effector pose for the robot's unique pose-targetable group.""" + return self.get_group_ee_pose( + self._unique_pose_group_id_for_robot(robot_id), + joint_state=joint_state, + ) + + def get_group_ee_pose( + self, group_id: PlanningGroupID, joint_state: JointState | None = None + ) -> PoseStamped: + """Get planning group target-frame pose using current state by default.""" + robot_id = self._robot_id_for_group(group_id) with self._world.scratch_context() as ctx: - # If no state provided, fetch current from state monitor if joint_state is None: joint_state = self.get_current_joint_state(robot_id) - if joint_state is not None: self._world.set_joint_state(ctx, robot_id, joint_state) - return self._world.get_ee_pose(ctx, robot_id) + return self._world.get_group_ee_pose(ctx, group_id) def get_link_pose( self, robot_id: WorldRobotID, link_name: str, joint_state: JointState | None = None @@ -387,8 +524,6 @@ def get_link_pose( link_name: Name of the link in the URDF joint_state: Joint state to use (uses current if None) """ - from dimos.msgs.geometry_msgs.Quaternion import Quaternion - with self._world.scratch_context() as ctx: if joint_state is None: joint_state = self.get_current_joint_state(robot_id) @@ -410,10 +545,44 @@ def get_link_pose( ) def get_jacobian(self, robot_id: WorldRobotID, joint_state: JointState) -> NDArray[np.float64]: - """Get 6xN Jacobian matrix.""" + """Get 6xN Jacobian for the robot's unique pose-targetable group.""" + return self.get_group_jacobian( + self._unique_pose_group_id_for_robot(robot_id), + joint_state=joint_state, + ) + + def get_group_jacobian( + self, group_id: PlanningGroupID, joint_state: JointState + ) -> NDArray[np.float64]: + """Get planning group target-frame 6xN Jacobian matrix.""" + self._planning_groups.get(group_id) + robot_id = self._robot_id_for_group(group_id) with self._world.scratch_context() as ctx: self._world.set_joint_state(ctx, robot_id, joint_state) - return self._world.get_jacobian(ctx, robot_id) + return self._world.get_group_jacobian(ctx, group_id) + + def _unique_pose_group_id_for_robot(self, robot_id: WorldRobotID) -> PlanningGroupID: + robot_name = self._world.get_robot_config(robot_id).name + pose_group_ids = [ + group.id + for group in self._planning_groups.groups_for_robot(robot_name) + if group.has_pose_target + ] + if len(pose_group_ids) != 1: + raise ValueError( + f"Robot '{robot_name}' has {len(pose_group_ids)} pose-targetable planning groups; " + "call get_group_ee_pose/get_group_jacobian with an explicit planning group ID" + ) + return pose_group_ids[0] + + def _robot_id_for_group(self, group_id: PlanningGroupID) -> WorldRobotID: + group = self._planning_groups.get(group_id) + try: + return self._robot_ids_by_name[group.robot_name] + except KeyError as exc: + raise KeyError( + f"Robot '{group.robot_name}' not found for planning group {group_id}" + ) from exc # Lifecycle @@ -442,25 +611,20 @@ def publish_visualization(self) -> None: if self._visualization is not None: self._visualization.publish_visualization() - def show_preview(self, robot_id: WorldRobotID) -> None: - """Show the preview representation for a robot if visualization is available.""" + def show_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: + """Show preview representation for planning groups if visualization is available.""" if self._visualization is not None: - self._visualization.show_preview(robot_id) + self._visualization.show_preview(group_ids) - def hide_preview(self, robot_id: WorldRobotID) -> None: - """Hide the preview representation for a robot if visualization is available.""" + def hide_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: + """Hide preview representation for planning groups if visualization is available.""" if self._visualization is not None: - self._visualization.hide_preview(robot_id) + self._visualization.hide_preview(group_ids) - def animate_path( - self, - robot_id: WorldRobotID, - path: JointPath, - duration: float = 3.0, - ) -> None: - """Animate a path if visualization is available.""" + def animate_plan(self, plan: GeneratedPlan, duration: float = 3.0) -> None: + """Animate a generated plan if visualization is available.""" if self._visualization is not None: - self._visualization.animate_path(robot_id, path, duration) + self._visualization.animate_plan(plan, duration) def start_visualization_thread(self, rate_hz: float = 10.0) -> None: """Start background thread for visualization updates at given rate.""" diff --git a/dimos/manipulation/planning/planners/rrt_planner.py b/dimos/manipulation/planning/planners/rrt_planner.py index 3ca19eb099..a100376053 100644 --- a/dimos/manipulation/planning/planners/rrt_planner.py +++ b/dimos/manipulation/planning/planners/rrt_planner.py @@ -26,8 +26,18 @@ import numpy as np +from dimos.manipulation.planning.groups.identifiers import ( + local_joint_name_from_global, + make_global_joint_names, +) +from dimos.manipulation.planning.groups.models import PlanningGroup, PlanningGroupSelection from dimos.manipulation.planning.spec.enums import PlanningStatus -from dimos.manipulation.planning.spec.models import JointPath, PlanningResult, WorldRobotID +from dimos.manipulation.planning.spec.models import ( + JointPath, + PlanningResult, + RobotName, + WorldRobotID, +) from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.path_utils import compute_path_length from dimos.msgs.sensor_msgs.JointState import JointState @@ -98,6 +108,13 @@ def plan_joint_path( if error is not None: return error + if world.check_edge_collision_free(robot_id, start, goal, self._collision_step_size): + return _create_success_result( + [start, goal], + time.time() - start_time, + 0, + ) + lower, upper = world.get_joint_limits(robot_id) start_tree = [TreeNode(config=q_start.copy())] goal_tree = [TreeNode(config=q_goal.copy())] @@ -147,6 +164,286 @@ def get_name(self) -> str: """Get planner name.""" return "RRTConnect" + def plan_selected_joint_path( + self, + world: WorldSpec, + selection: PlanningGroupSelection, + start: JointState, + goal: JointState, + timeout: float = 10.0, + ) -> PlanningResult: + """Plan a collision-free path for an explicit planning-group selection.""" + selected_joint_names = [ + joint_name for group in selection.groups for joint_name in group.joint_names + ] + exact_error = _validate_exact_joint_keys(start, selected_joint_names, "start") + if exact_error is not None: + return exact_error + exact_error = _validate_exact_joint_keys(goal, selected_joint_names, "goal") + if exact_error is not None: + return exact_error + + try: + robot_ids_by_name = _robot_ids_by_name(world, selection.robot_names) + except (KeyError, ValueError) as exc: + return _create_failure_result(PlanningStatus.INVALID_GOAL, str(exc)) + + robot_ids = set(robot_ids_by_name.values()) + if len(robot_ids) != 1: + return self._plan_multi_robot_selected_joint_path( + world=world, + groups=selection.groups, + robot_ids_by_name=robot_ids_by_name, + start=start, + goal=goal, + timeout=timeout, + ) + + robot_id = next(iter(robot_ids)) + robot_config = world.get_robot_config(robot_id) + full_global_joint_names = make_global_joint_names( + robot_config.name, robot_config.joint_names + ) + if selected_joint_names != full_global_joint_names: + return _create_failure_result( + PlanningStatus.UNSUPPORTED, + "RRTConnectPlanner currently requires the selected groups to cover " + "the robot controllable joint set exactly", + ) + + local_start = _global_joint_state_to_local( + start, + robot_config.name, + list(robot_config.joint_names), + selected_joint_names, + ) + local_goal = _global_joint_state_to_local( + goal, + robot_config.name, + list(robot_config.joint_names), + selected_joint_names, + ) + result = self.plan_joint_path( + world=world, + robot_id=robot_id, + start=local_start, + goal=local_goal, + timeout=timeout, + ) + if not result.is_success(): + return result + return PlanningResult( + status=result.status, + path=_local_path_to_global(result.path, robot_config.name, selected_joint_names), + planning_time=result.planning_time, + path_length=result.path_length, + iterations=result.iterations, + message=result.message, + timestamps=result.timestamps, + ) + + def _plan_multi_robot_selected_joint_path( + self, + world: WorldSpec, + groups: tuple[PlanningGroup, ...], + robot_ids_by_name: dict[RobotName, WorldRobotID], + start: JointState, + goal: JointState, + timeout: float, + ) -> PlanningResult: + """Plan over one coupled configuration vector for all selected robots.""" + start_time = time.time() + + if not world.is_finalized: + return _create_failure_result( + PlanningStatus.NO_SOLUTION, + "World must be finalized before planning", + ) + + selected_joint_names = [joint for group in groups for joint in group.joint_names] + q_start = np.array( + _order_joint_state(start, selected_joint_names).position, dtype=np.float64 + ) + q_goal = np.array(_order_joint_state(goal, selected_joint_names).position, dtype=np.float64) + + try: + robot_order, robot_joint_names = _validate_full_robot_groups( + world, groups, robot_ids_by_name + ) + except KeyError as exc: + return _create_failure_result(PlanningStatus.NO_SOLUTION, str(exc)) + if not robot_order: + return _create_failure_result( + PlanningStatus.INVALID_GOAL, "No planning groups selected" + ) + + unsupported = _validate_selected_groups_cover_full_robots( + world, robot_order, robot_joint_names + ) + if unsupported is not None: + return unsupported + + lower, upper = _combined_joint_limits(world, robot_order) + + if not _coupled_config_collision_free( + world, robot_order, robot_joint_names, selected_joint_names, q_start + ): + return _create_failure_result( + PlanningStatus.COLLISION_AT_START, + "Start configuration is in collision", + ) + if not _coupled_config_collision_free( + world, robot_order, robot_joint_names, selected_joint_names, q_goal + ): + return _create_failure_result( + PlanningStatus.COLLISION_AT_GOAL, + "Goal configuration is in collision", + ) + + if np.any(q_start < lower) or np.any(q_start > upper): + return _create_failure_result( + PlanningStatus.INVALID_START, + "Start configuration is outside joint limits", + ) + if np.any(q_goal < lower) or np.any(q_goal > upper): + return _create_failure_result( + PlanningStatus.INVALID_GOAL, + "Goal configuration is outside joint limits", + ) + + if _coupled_edge_collision_free( + world, + robot_order, + robot_joint_names, + selected_joint_names, + q_start, + q_goal, + self._collision_step_size, + ): + return _create_success_result( + [start, goal], + time.time() - start_time, + 0, + ) + + start_tree = [TreeNode(config=q_start.copy())] + goal_tree = [TreeNode(config=q_goal.copy())] + trees_swapped = False + + max_iterations = 5000 + for iteration in range(max_iterations): + if time.time() - start_time > timeout: + return _create_failure_result( + PlanningStatus.TIMEOUT, + f"Timeout after {iteration} iterations", + time.time() - start_time, + iteration, + ) + + sample = np.random.uniform(lower, upper) + extended = self._extend_coupled_tree( + world, + robot_order, + robot_joint_names, + start_tree, + sample, + self._step_size, + selected_joint_names, + ) + + if extended is not None: + connected = self._connect_coupled_tree( + world, + robot_order, + robot_joint_names, + goal_tree, + extended.config, + self._connect_step_size, + selected_joint_names, + ) + if connected is not None: + path = self._extract_path(extended, connected, selected_joint_names) + if trees_swapped: + path = list(reversed(path)) + path = _simplify_coupled_path( + world, + robot_order, + robot_joint_names, + path, + self._collision_step_size, + ) + return _create_success_result(path, time.time() - start_time, iteration + 1) + + start_tree, goal_tree = goal_tree, start_tree + trees_swapped = not trees_swapped + + return _create_failure_result( + PlanningStatus.NO_SOLUTION, + f"No path found after {max_iterations} iterations", + time.time() - start_time, + max_iterations, + ) + + def _extend_coupled_tree( + self, + world: WorldSpec, + robot_order: list[WorldRobotID], + robot_joint_names: dict[WorldRobotID, list[str]], + tree: list[TreeNode], + target: NDArray[np.float64], + step_size: float, + selected_joint_names: list[str], + ) -> TreeNode | None: + """Extend a tree in the coupled selected-joint configuration space.""" + nearest = min(tree, key=lambda node: float(np.linalg.norm(node.config - target))) + diff = target - nearest.config + dist = float(np.linalg.norm(diff)) + if dist <= step_size: + new_config = target.copy() + else: + new_config = nearest.config + step_size * (diff / dist) + + if _coupled_edge_collision_free( + world, + robot_order, + robot_joint_names, + selected_joint_names, + nearest.config, + new_config, + self._collision_step_size, + ): + new_node = TreeNode(config=new_config, parent=nearest) + nearest.children.append(new_node) + tree.append(new_node) + return new_node + return None + + def _connect_coupled_tree( + self, + world: WorldSpec, + robot_order: list[WorldRobotID], + robot_joint_names: dict[WorldRobotID, list[str]], + tree: list[TreeNode], + target: NDArray[np.float64], + step_size: float, + selected_joint_names: list[str], + ) -> TreeNode | None: + """Try to connect a coupled tree to a target configuration.""" + while True: + result = self._extend_coupled_tree( + world, + robot_order, + robot_joint_names, + tree, + target, + step_size, + selected_joint_names, + ) + if result is None: + return None + if float(np.linalg.norm(result.config - target)) < self._goal_tolerance: + return result + def _validate_inputs( self, world: WorldSpec, @@ -344,3 +641,251 @@ def _create_failure_result( iterations=iterations, message=message, ) + + +def _validate_full_robot_groups( + world: WorldSpec, + groups: tuple[PlanningGroup, ...], + robot_ids_by_name: dict[RobotName, WorldRobotID], +) -> tuple[list[WorldRobotID], dict[WorldRobotID, list[str]]]: + robot_order: list[WorldRobotID] = [] + robot_joint_names: dict[WorldRobotID, list[str]] = {} + known_robot_ids = set(world.get_robot_ids()) + for group in groups: + robot_id = robot_ids_by_name[group.robot_name] + if robot_id not in known_robot_ids: + raise KeyError(f"Robot '{robot_id}' not found") + if robot_id not in robot_joint_names: + robot_joint_names[robot_id] = [] + robot_order.append(robot_id) + robot_joint_names[robot_id].extend(group.joint_names) + return robot_order, robot_joint_names + + +def _robot_ids_by_name( + world: WorldSpec, robot_names: tuple[RobotName, ...] +) -> dict[RobotName, WorldRobotID]: + robot_ids_by_name: dict[RobotName, WorldRobotID] = {} + for robot_name in robot_names: + matches = [ + robot_id + for robot_id in world.get_robot_ids() + if world.get_robot_config(robot_id).name == robot_name + ] + if not matches: + raise KeyError(f"Robot '{robot_name}' not found") + if len(matches) > 1: + raise ValueError(f"Robot name '{robot_name}' is not unique in planning world") + robot_ids_by_name[robot_name] = matches[0] + return robot_ids_by_name + + +def _validate_selected_groups_cover_full_robots( + world: WorldSpec, + robot_order: list[WorldRobotID], + robot_joint_names: dict[WorldRobotID, list[str]], +) -> PlanningResult | None: + for robot_id in robot_order: + robot_config = world.get_robot_config(robot_id) + full_global_joint_names = make_global_joint_names( + robot_config.name, robot_config.joint_names + ) + if robot_joint_names[robot_id] != full_global_joint_names: + return _create_failure_result( + PlanningStatus.UNSUPPORTED, + "RRTConnectPlanner currently requires selected groups to cover " + "each affected robot's controllable joint set exactly", + ) + return None + + +def _combined_joint_limits( + world: WorldSpec, + robot_order: list[WorldRobotID], +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + lower_parts: list[NDArray[np.float64]] = [] + upper_parts: list[NDArray[np.float64]] = [] + for robot_id in robot_order: + lower, upper = world.get_joint_limits(robot_id) + lower_parts.append(lower) + upper_parts.append(upper) + return np.concatenate(lower_parts), np.concatenate(upper_parts) + + +def _robot_joint_state_from_combined( + combined_joint_names: list[str], + combined_positions: NDArray[np.float64], + robot_name: str, + robot_joint_names: list[str], +) -> JointState: + position_by_name = dict(zip(combined_joint_names, combined_positions.tolist(), strict=True)) + return JointState( + name=[local_joint_name_from_global(robot_name, name) for name in robot_joint_names], + position=[position_by_name[name] for name in robot_joint_names], + ) + + +def _global_joint_state_to_local( + joint_state: JointState, + robot_name: str, + robot_joint_names: list[str], + global_joint_names: list[str], +) -> JointState: + position_by_name = dict(zip(joint_state.name, joint_state.position, strict=True)) + local_joint_names = [ + local_joint_name_from_global(robot_name, name) for name in global_joint_names + ] + if local_joint_names != robot_joint_names: + raise ValueError("Global selected joints do not match robot joint order") + return JointState( + name=robot_joint_names, + position=[position_by_name[global_name] for global_name in global_joint_names], + ) + + +def _local_path_to_global( + path: JointPath, + robot_name: str, + global_joint_names: list[str], +) -> JointPath: + local_joint_names = [ + local_joint_name_from_global(robot_name, name) for name in global_joint_names + ] + global_path: JointPath = [] + for waypoint in path: + position_by_name = dict(zip(waypoint.name, waypoint.position, strict=True)) + global_path.append( + JointState( + name=global_joint_names, + position=[position_by_name[local_name] for local_name in local_joint_names], + ) + ) + return global_path + + +def _coupled_config_collision_free( + world: WorldSpec, + robot_order: list[WorldRobotID], + robot_joint_names: dict[WorldRobotID, list[str]], + selected_joint_names: list[str], + q: NDArray[np.float64], +) -> bool: + with world.scratch_context() as ctx: + for robot_id in robot_order: + world.set_joint_state( + ctx, + robot_id, + _robot_joint_state_from_combined( + selected_joint_names, + q, + world.get_robot_config(robot_id).name, + robot_joint_names[robot_id], + ), + ) + return all(world.is_collision_free(ctx, robot_id) for robot_id in robot_order) + + +def _coupled_edge_collision_free( + world: WorldSpec, + robot_order: list[WorldRobotID], + robot_joint_names: dict[WorldRobotID, list[str]], + selected_joint_names: list[str], + q_start: NDArray[np.float64], + q_end: NDArray[np.float64], + step_size: float, +) -> bool: + dist = float(np.linalg.norm(q_end - q_start)) + if dist < 1e-8: + return _coupled_config_collision_free( + world, + robot_order, + robot_joint_names, + selected_joint_names, + q_start, + ) + + n_steps = max(2, int(np.ceil(dist / step_size)) + 1) + with world.scratch_context() as ctx: + for i in range(n_steps): + t = i / (n_steps - 1) + q = q_start + t * (q_end - q_start) + for robot_id in robot_order: + world.set_joint_state( + ctx, + robot_id, + _robot_joint_state_from_combined( + selected_joint_names, + q, + world.get_robot_config(robot_id).name, + robot_joint_names[robot_id], + ), + ) + if not all(world.is_collision_free(ctx, robot_id) for robot_id in robot_order): + return False + return True + + +def _simplify_coupled_path( + world: WorldSpec, + robot_order: list[WorldRobotID], + robot_joint_names: dict[WorldRobotID, list[str]], + path: JointPath, + collision_step_size: float, + max_iterations: int = 100, +) -> JointPath: + if len(path) <= 2: + return path + + simplified = list(path) + selected_joint_names = list(path[0].name) + for _ in range(max_iterations): + if len(simplified) <= 2: + break + i = np.random.randint(0, len(simplified) - 2) + j = np.random.randint(i + 2, len(simplified)) + q_start = np.array(simplified[i].position, dtype=np.float64) + q_end = np.array(simplified[j].position, dtype=np.float64) + if _coupled_edge_collision_free( + world, + robot_order, + robot_joint_names, + selected_joint_names, + q_start, + q_end, + collision_step_size, + ): + simplified = simplified[: i + 1] + simplified[j:] + return simplified + + +def _validate_exact_joint_keys( + joint_state: JointState, selected_joint_names: list[str], state_name: str +) -> PlanningResult | None: + actual_names = list(joint_state.name) + expected_names = selected_joint_names + if set(actual_names) != set(expected_names): + missing = [name for name in expected_names if name not in actual_names] + extra = [name for name in actual_names if name not in expected_names] + details: list[str] = [] + if missing: + details.append(f"missing={missing}") + if extra: + details.append(f"extra={extra}") + return _create_failure_result( + PlanningStatus.INVALID_START if state_name == "start" else PlanningStatus.INVALID_GOAL, + f"{state_name} joint names must exactly match selected joints ({', '.join(details)})", + ) + if len(joint_state.position) != len(joint_state.name): + return _create_failure_result( + PlanningStatus.INVALID_START if state_name == "start" else PlanningStatus.INVALID_GOAL, + f"{state_name} joint name and position lengths must match", + ) + return None + + +def _order_joint_state(joint_state: JointState, joint_names: list[str]) -> JointState: + position_by_name = dict(zip(joint_state.name, joint_state.position, strict=False)) + return JointState( + name=joint_names, + position=[position_by_name[name] for name in joint_names], + ) diff --git a/dimos/manipulation/planning/planners/test_rrt_planner_selection.py b/dimos/manipulation/planning/planners/test_rrt_planner_selection.py new file mode 100644 index 0000000000..4436aba595 --- /dev/null +++ b/dimos/manipulation/planning/planners/test_rrt_planner_selection.py @@ -0,0 +1,247 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for selected-joint RRT planning group contracts.""" + +from __future__ import annotations + +from collections.abc import Callable +from contextlib import nullcontext +from pathlib import Path +from typing import cast + +import numpy as np + +from dimos.manipulation.planning.groups.models import PlanningGroup, PlanningGroupSelection +from dimos.manipulation.planning.planners.rrt_planner import RRTConnectPlanner +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.manipulation.planning.spec.enums import PlanningStatus +from dimos.manipulation.planning.spec.protocols import WorldSpec +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.JointState import JointState + + +def _pose() -> PoseStamped: + return PoseStamped(position=[0, 0, 0], orientation=[0, 0, 0, 1]) + + +def _robot_config(name: str, joint_names: list[str]) -> RobotModelConfig: + return RobotModelConfig( + name=name, + model_path=Path("robot.urdf"), + base_pose=_pose(), + joint_names=joint_names, + end_effector_link="tool0", + ) + + +def _joint_state(names: list[str], positions: list[float]) -> JointState: + return JointState({"name": names, "position": positions}) + + +def _group( + group_id: str, + robot_name: str, + joint_names: tuple[str, ...], +) -> PlanningGroup: + return PlanningGroup( + id=group_id, + robot_name=robot_name, + group_name=group_id.split("/", maxsplit=1)[1], + joint_names=joint_names, + local_joint_names=tuple(name.split("/", maxsplit=1)[1] for name in joint_names), + base_link="base_link", + tip_link="tool0", + ) + + +def _selection(*groups: PlanningGroup) -> PlanningGroupSelection: + return PlanningGroupSelection.from_groups(tuple(groups)) + + +class _SelectionWorld: + is_finalized = True + + def __init__( + self, + robot_configs: dict[str, RobotModelConfig], + coupled_collision_predicate: Callable[[dict[str, JointState]], bool] | None = None, + ) -> None: + self._robot_configs = robot_configs + self._coupled_collision_predicate = coupled_collision_predicate + self.coupled_collision_checks = 0 + self.config_collision_names: list[list[str]] = [] + self.edge_collision_names: list[tuple[list[str], list[str]]] = [] + + def get_robot_config(self, robot_id: str) -> RobotModelConfig: + return self._robot_configs[robot_id] + + def get_robot_ids(self) -> list[str]: + return list(self._robot_configs) + + def check_config_collision_free(self, robot_id: str, joint_state: JointState) -> bool: + self.config_collision_names.append(list(joint_state.name)) + return True + + def get_joint_limits(self, robot_id: str) -> tuple[np.ndarray, np.ndarray]: + joint_count = len(self._robot_configs[robot_id].joint_names) + return -np.ones(joint_count), np.ones(joint_count) + + def check_edge_collision_free( + self, + robot_id: str, + start: JointState, + goal: JointState, + step_size: float, + ) -> bool: + self.edge_collision_names.append((list(start.name), list(goal.name))) + return True + + def scratch_context(self) -> nullcontext[dict[str, JointState]]: + return nullcontext({}) + + def set_joint_state( + self, ctx: dict[str, JointState], robot_id: str, joint_state: JointState + ) -> None: + assert joint_state.name == self._robot_configs[robot_id].joint_names + ctx[robot_id] = joint_state + + def is_collision_free(self, ctx: dict[str, JointState], robot_id: str) -> bool: + self.coupled_collision_checks += 1 + if self._coupled_collision_predicate is None: + return True + return self._coupled_collision_predicate(ctx) + + +def test_plan_selected_joint_path_rejects_missing_and_extra_start_names() -> None: + group = _group("arm/arm", "arm", ("arm/joint1", "arm/joint2")) + world = _SelectionWorld(robot_configs={"robot_1": _robot_config("arm", ["joint1", "joint2"])}) + + result = RRTConnectPlanner().plan_selected_joint_path( + cast("WorldSpec", world), + _selection(group), + start=_joint_state(["arm/joint1", "arm/extra"], [0.0, 0.0]), + goal=_joint_state(["arm/joint1", "arm/joint2"], [0.0, 0.0]), + ) + + assert result.status == PlanningStatus.INVALID_START + assert "missing" in result.message + assert "extra" in result.message + + +def test_plan_selected_joint_path_rejects_missing_and_extra_goal_names() -> None: + group = _group("arm/arm", "arm", ("arm/joint1", "arm/joint2")) + world = _SelectionWorld(robot_configs={"robot_1": _robot_config("arm", ["joint1", "joint2"])}) + + result = RRTConnectPlanner().plan_selected_joint_path( + cast("WorldSpec", world), + _selection(group), + start=_joint_state(["arm/joint1", "arm/joint2"], [0.0, 0.0]), + goal=_joint_state(["arm/joint1", "arm/extra"], [0.0, 0.0]), + ) + + assert result.status == PlanningStatus.INVALID_GOAL + assert "missing" in result.message + assert "extra" in result.message + + +def test_plan_selected_joint_path_plans_cross_robot_full_group_selection() -> None: + left_group = _group("left/arm", "left", ("left/joint1",)) + right_group = _group("right/arm", "right", ("right/joint1",)) + world = _SelectionWorld( + robot_configs={ + "left_robot": _robot_config("left", ["joint1"]), + "right_robot": _robot_config("right", ["joint1"]), + }, + ) + joint_state = _joint_state(["left/joint1", "right/joint1"], [0.0, 0.0]) + + result = RRTConnectPlanner().plan_selected_joint_path( + cast("WorldSpec", world), + _selection(left_group, right_group), + start=joint_state, + goal=_joint_state(["left/joint1", "right/joint1"], [0.1, -0.1]), + ) + + assert result.status == PlanningStatus.SUCCESS + assert len(result.path) == 2 + assert result.path[0].name == ["left/joint1", "right/joint1"] + assert result.path[-1].position == [0.1, -0.1] + assert world.coupled_collision_checks > 0 + + +def test_plan_selected_joint_path_converts_single_robot_backend_boundary_to_local() -> None: + group = _group("arm/manipulator", "arm", ("arm/joint1", "arm/joint2")) + world = _SelectionWorld(robot_configs={"robot_1": _robot_config("arm", ["joint1", "joint2"])}) + + result = RRTConnectPlanner().plan_selected_joint_path( + cast("WorldSpec", world), + _selection(group), + start=_joint_state(["arm/joint2", "arm/joint1"], [0.2, 0.1]), + goal=_joint_state(["arm/joint1", "arm/joint2"], [0.3, 0.4]), + ) + + assert result.status == PlanningStatus.SUCCESS + assert [waypoint.name for waypoint in result.path] == [ + ["arm/joint1", "arm/joint2"], + ["arm/joint1", "arm/joint2"], + ] + assert result.path[0].position == [0.1, 0.2] + assert result.path[-1].position == [0.3, 0.4] + assert world.config_collision_names == [["joint1", "joint2"], ["joint1", "joint2"]] + assert world.edge_collision_names == [(["joint1", "joint2"], ["joint1", "joint2"])] + + +def test_plan_selected_joint_path_rejects_cross_robot_coupled_goal_collision() -> None: + def coupled_free(ctx: dict[str, JointState]) -> bool: + if {"left_robot", "right_robot"} - set(ctx): + return True + left = ctx["left_robot"].position[0] + right = ctx["right_robot"].position[0] + return not (left > 0.04 and right > 0.04) + + left_group = _group("left/arm", "left", ("left/joint1",)) + right_group = _group("right/arm", "right", ("right/joint1",)) + world = _SelectionWorld( + robot_configs={ + "left_robot": _robot_config("left", ["joint1"]), + "right_robot": _robot_config("right", ["joint1"]), + }, + coupled_collision_predicate=coupled_free, + ) + + result = RRTConnectPlanner().plan_selected_joint_path( + cast("WorldSpec", world), + _selection(left_group, right_group), + start=_joint_state(["left/joint1", "right/joint1"], [0.0, 0.0]), + goal=_joint_state(["left/joint1", "right/joint1"], [0.1, 0.1]), + ) + + assert result.status == PlanningStatus.COLLISION_AT_GOAL + assert world.coupled_collision_checks > 0 + + +def test_plan_selected_joint_path_rejects_single_robot_subset_selection() -> None: + group = _group("arm/wrist", "arm", ("arm/joint2",)) + world = _SelectionWorld(robot_configs={"robot_1": _robot_config("arm", ["joint1", "joint2"])}) + joint_state = _joint_state(["arm/joint2"], [0.0]) + + result = RRTConnectPlanner().plan_selected_joint_path( + cast("WorldSpec", world), + _selection(group), + start=joint_state, + goal=joint_state, + ) + + assert result.status == PlanningStatus.UNSUPPORTED diff --git a/dimos/manipulation/planning/spec/config.py b/dimos/manipulation/planning/spec/config.py index 74dc3bd69b..7ecca4d191 100644 --- a/dimos/manipulation/planning/spec/config.py +++ b/dimos/manipulation/planning/spec/config.py @@ -21,6 +21,12 @@ from pydantic import Field from dimos.core.module import ModuleConfig +from dimos.manipulation.planning.groups.discovery import FALLBACK_PLANNING_GROUP_NAME +from dimos.manipulation.planning.groups.identifiers import ( + assert_local_joint_names, + assert_valid_robot_name, +) +from dimos.manipulation.planning.groups.models import PlanningGroupDefinition from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped @@ -30,10 +36,22 @@ class RobotModelConfig(ModuleConfig): Attributes: name: Human-readable robot name model_path: Path to robot model file (.urdf, .xacro, or .xml/MJCF) - base_pose: Pose of robot base in world frame (position + orientation) - joint_names: Ordered list of controlled joint names (in URDF namespace) - end_effector_link: Name of the end-effector link for FK/IK - base_link: Name of the base link (default: "base_link") + srdf_path: Optional path to SRDF file containing planning group definitions + base_pose: Compatibility placement transform used by current Drake + world loading/welding. This is the canonical world placement for + robot instances; model-authored world/base attach joints are + stripped when strip_model_world_joint is true. + joint_names: Ordered list of controllable joints in the local model + namespace. This is not a planning group. + joint_name_mapping: Optional mapping from external/coordinator joint + names to local model joint names for hardware adapters that publish + scoped joint names. + end_effector_link: Compatibility robot-scoped end-effector link used by + legacy helpers. New pose-targeted planning should use planning + group target frames instead. + base_link: Compatibility robot-scoped base link used by current Drake + weld/placement behavior. Planning groups own chain base links. + TODO: should remove package_paths: Dict mapping package names to filesystem Paths joint_limits_lower: Lower joint limits (radians) joint_limits_upper: Upper joint limits (radians) @@ -45,19 +63,20 @@ class RobotModelConfig(ModuleConfig): links may legitimately overlap (e.g., mimic joints). max_velocity: Maximum joint velocity for trajectory generation (rad/s) max_acceleration: Maximum joint acceleration for trajectory generation (rad/s^2) - joint_name_mapping: Maps coordinator joint names to URDF joint names. - Example: {"left/joint1": "joint1"} means coordinator's "left/joint1" - corresponds to URDF's "joint1". If empty, names are assumed to match. coordinator_task_name: Task name for executing trajectories via coordinator RPC. If set, trajectories can be executed via execute_trajectory() RPC. """ name: str model_path: Path - base_pose: PoseStamped + srdf_path: Path | None = None + base_pose: PoseStamped = Field(default_factory=PoseStamped) + strip_model_world_joint: bool = False joint_names: list[str] - end_effector_link: str + joint_name_mapping: dict[str, str] = Field(default_factory=dict) + end_effector_link: str | None = None base_link: str = "base_link" + planning_groups: list[PlanningGroupDefinition] = Field(default_factory=list) package_paths: dict[str, Path] = Field(default_factory=dict) joint_limits_lower: list[float] | None = None joint_limits_upper: list[float] | None = None @@ -69,7 +88,6 @@ class RobotModelConfig(ModuleConfig): max_velocity: float = 1.0 max_acceleration: float = 2.0 # Coordinator integration - joint_name_mapping: dict[str, str] = Field(default_factory=dict) coordinator_task_name: str | None = None gripper_hardware_id: str | None = None # TF publishing for extra links (e.g., camera mount) @@ -79,19 +97,17 @@ class RobotModelConfig(ModuleConfig): # Pre-grasp offset distance in meters (along approach direction) pre_grasp_offset: float = 0.10 - def get_urdf_joint_name(self, coordinator_name: str) -> str: - """Translate coordinator joint name to URDF joint name.""" - return self.joint_name_mapping.get(coordinator_name, coordinator_name) - - def get_coordinator_joint_name(self, urdf_name: str) -> str: - """Translate URDF joint name to coordinator joint name.""" - for coord_name, u_name in self.joint_name_mapping.items(): - if u_name == urdf_name: - return coord_name - return urdf_name - - def get_coordinator_joint_names(self) -> list[str]: - """Get joint names in coordinator namespace.""" - if not self.joint_name_mapping: - return self.joint_names - return [self.get_coordinator_joint_name(j) for j in self.joint_names] + def model_post_init(self, __context: object) -> None: + """Validate delimiter-based naming constraints.""" + assert_valid_robot_name(self.name) + assert_local_joint_names(self.joint_names) + if not self.planning_groups: + self.planning_groups = [ + PlanningGroupDefinition( + name=FALLBACK_PLANNING_GROUP_NAME, + joint_names=tuple(self.joint_names), + base_link=self.base_link, + tip_link=self.end_effector_link, + source="fallback", + ) + ] diff --git a/dimos/manipulation/planning/spec/enums.py b/dimos/manipulation/planning/spec/enums.py index 66a17ee199..e1c7c1a735 100644 --- a/dimos/manipulation/planning/spec/enums.py +++ b/dimos/manipulation/planning/spec/enums.py @@ -47,3 +47,4 @@ class PlanningStatus(Enum): INVALID_GOAL = auto() COLLISION_AT_START = auto() COLLISION_AT_GOAL = auto() + UNSUPPORTED = auto() diff --git a/dimos/manipulation/planning/spec/models.py b/dimos/manipulation/planning/spec/models.py index d412e9f766..f9f87e3398 100644 --- a/dimos/manipulation/planning/spec/models.py +++ b/dimos/manipulation/planning/spec/models.py @@ -18,7 +18,7 @@ from collections.abc import Mapping from dataclasses import dataclass, field -from typing import TYPE_CHECKING, TypeAlias +from typing import TYPE_CHECKING, Literal, TypeAlias from dimos.manipulation.planning.spec.enums import ( IKStatus, @@ -41,6 +41,15 @@ WorldRobotID: TypeAlias = str """Internal Drake world robot ID""" +PlanningGroupID: TypeAlias = str +"""Public planning group ID of the form {robot_name}/{group_name}.""" + +LocalModelJointName: TypeAlias = str +"""Joint name as it appears in URDF/SRDF before world binding.""" + +GlobalJointName: TypeAlias = str +"""Public joint name of the form {robot_name}/{local_joint_name}.""" + JointPath: TypeAlias = "list[JointState]" """List of joint states forming a path (each waypoint has names + positions)""" @@ -60,6 +69,62 @@ class PlanningSceneInfo: Jacobian: TypeAlias = "NDArray[np.float64]" """6 x n Jacobian matrix (rows: [vx, vy, vz, wx, wy, wz])""" +CollisionCheckStatus: TypeAlias = Literal[ + "VALID", + "COLLISION", + "INVALID", + "UNAVAILABLE", + "STALE_STATE", +] +"""Status for a planning-world collision target check.""" + +ForwardKinematicsStatus: TypeAlias = Literal[ + "VALID", + "INVALID", + "UNAVAILABLE", + "STALE_STATE", +] +"""Status for a group-scoped forward-kinematics query.""" + + +@dataclass(frozen=True) +class CollisionCheckResult: + """Result of a planning-world collision target check.""" + + status: CollisionCheckStatus + collision_free: bool | None + message: str + + +@dataclass(frozen=True) +class ForwardKinematicsResult: + """Result of a group-scoped forward-kinematics query.""" + + status: ForwardKinematicsStatus + pose: PoseStamped | None + message: str + + +@dataclass +class GeneratedPlan: + """Canonical generated planning artifact. + + The path uses global joint names and contains exactly the selected joints. + Downstream preview/execution projections are computed lazily from this data. + """ + + group_ids: tuple[PlanningGroupID, ...] + path: list[JointState] = field(default_factory=list) + status: PlanningStatus = PlanningStatus.NO_SOLUTION + planning_time: float = 0.0 + path_length: float = 0.0 + iterations: int = 0 + message: str = "" + + def is_success(self) -> bool: + """Check if planning was successful.""" + return self.status == PlanningStatus.SUCCESS + @dataclass class Obstacle: diff --git a/dimos/manipulation/planning/spec/protocols.py b/dimos/manipulation/planning/spec/protocols.py index c7ee95ee0a..3c0dca3cf5 100644 --- a/dimos/manipulation/planning/spec/protocols.py +++ b/dimos/manipulation/planning/spec/protocols.py @@ -23,16 +23,19 @@ from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: + from collections.abc import Sequence from contextlib import AbstractContextManager import numpy as np from numpy.typing import NDArray + from dimos.manipulation.planning.groups.models import PlanningGroup, PlanningGroupSelection from dimos.manipulation.planning.spec.config import RobotModelConfig from dimos.manipulation.planning.spec.models import ( + GeneratedPlan, IKResult, - JointPath, Obstacle, + PlanningGroupID, PlanningResult, PlanningSceneInfo, WorldRobotID, @@ -156,8 +159,19 @@ def check_edge_collision_free( ... # Forward Kinematics (require context) + def get_group_ee_pose(self, ctx: Any, group_id: PlanningGroupID) -> PoseStamped: + """Get pose for a planning group's target frame.""" + ... + + def get_group_jacobian(self, ctx: Any, group_id: PlanningGroupID) -> NDArray[np.float64]: + """Get planning group target-frame Jacobian over the group's selected joints.""" + ... + def get_ee_pose(self, ctx: Any, robot_id: WorldRobotID) -> PoseStamped: - """Get end-effector pose.""" + """Get pose for a robot's unique pose-targetable planning group. + + TODO: deprecate this. + """ ... def get_link_pose( @@ -167,7 +181,7 @@ def get_link_pose( ... def get_jacobian(self, ctx: Any, robot_id: WorldRobotID) -> NDArray[np.float64]: - """Get end-effector Jacobian (6 x n_joints).""" + """Get Jacobian for a robot's unique pose-targetable planning group.""" ... @@ -194,16 +208,16 @@ def publish_visualization(self, ctx: Any | None = None) -> None: """Publish current state to visualization.""" ... - def show_preview(self, robot_id: WorldRobotID) -> None: - """Show the preview representation for a robot.""" + def show_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: + """Show preview representations for the selected planning groups.""" ... - def hide_preview(self, robot_id: WorldRobotID) -> None: - """Hide the preview representation for a robot.""" + def hide_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: + """Hide preview representations for the selected planning groups.""" ... - def animate_path(self, robot_id: WorldRobotID, path: JointPath, duration: float = 3.0) -> None: - """Animate a path in visualization.""" + def animate_plan(self, plan: GeneratedPlan, duration: float = 3.0) -> None: + """Animate a generated plan in visualization.""" ... def close(self) -> None: @@ -213,7 +227,7 @@ def close(self) -> None: @runtime_checkable class KinematicsSpec(Protocol): - """Protocol for inverse kinematics solvers. Stateless, uses WorldSpec for FK/collision.""" + """Protocol for inverse kinematics solvers. Stateless and IK-only.""" def solve( self, @@ -223,10 +237,22 @@ def solve( seed: JointState | None = None, position_tolerance: float = 0.001, orientation_tolerance: float = 0.01, - check_collision: bool = True, max_attempts: int = 10, ) -> IKResult: - """Solve IK with optional collision checking.""" + """Solve a single robot-scoped IK target.""" + ... + + def solve_pose_targets( + self, + world: WorldSpec, + pose_targets: dict[PlanningGroup, PoseStamped], + auxiliary_groups: list[PlanningGroup] | tuple[PlanningGroup, ...] = (), + seed: JointState | None = None, + position_tolerance: float = 0.001, + orientation_tolerance: float = 0.01, + max_attempts: int = 10, + ) -> IKResult: + """Solve pose targets over planning groups plus request-scoped auxiliaries.""" ... @@ -254,6 +280,17 @@ def plan_joint_path( """Plan a collision-free joint-space path.""" ... + def plan_selected_joint_path( + self, + world: WorldSpec, + selection: PlanningGroupSelection, + start: JointState, + goal: JointState, + timeout: float = 10.0, + ) -> PlanningResult: + """Plan over an explicit planning-group selection.""" + ... + def get_name(self) -> str: """Get planner name.""" ... diff --git a/dimos/manipulation/planning/test_planning_group_identifiers.py b/dimos/manipulation/planning/test_planning_group_identifiers.py new file mode 100644 index 0000000000..d728123e8f --- /dev/null +++ b/dimos/manipulation/planning/test_planning_group_identifiers.py @@ -0,0 +1,39 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for planning group identifier grammar.""" + +from __future__ import annotations + +import pytest + +from dimos.manipulation.planning.groups.identifiers import ( + local_joint_name_from_global, + parse_global_joint_name, +) + + +def test_parse_global_joint_name_returns_robot_and_local_joint() -> None: + assert parse_global_joint_name("left_arm/joint1") == ("left_arm", "joint1") + + +@pytest.mark.parametrize("name", ["joint1", "/joint1", "left_arm/", "left_arm/foo/bar"]) +def test_parse_global_joint_name_rejects_invalid_names(name: str) -> None: + with pytest.raises(ValueError, match="Invalid global joint name"): + parse_global_joint_name(name) + + +def test_local_joint_name_from_global_requires_matching_robot() -> None: + with pytest.raises(ValueError, match="does not belong to robot"): + local_joint_name_from_global("right_arm", "left_arm/joint1") diff --git a/dimos/manipulation/planning/test_planning_group_joints.py b/dimos/manipulation/planning/test_planning_group_joints.py new file mode 100644 index 0000000000..0391b1b136 --- /dev/null +++ b/dimos/manipulation/planning/test_planning_group_joints.py @@ -0,0 +1,63 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for planning group joint-name normalization.""" + +from __future__ import annotations + +import pytest + +from dimos.manipulation.planning.groups.joints import joint_target_to_global_names +from dimos.manipulation.planning.groups.models import PlanningGroup +from dimos.msgs.sensor_msgs.JointState import JointState + + +def _make_group() -> PlanningGroup: + return PlanningGroup( + id="left/arm", + robot_name="left", + group_name="arm", + joint_names=("left/j1", "left/j2", "left/j3"), + local_joint_names=("j1", "j2", "j3"), + base_link="base", + tip_link="ee", + ) + + +def test_joint_target_to_global_names_accepts_named_global_targets_in_group_order() -> None: + group = _make_group() + target = JointState(name=["left/j3", "left/j1", "left/j2"], position=[3.0, 1.0, 2.0]) + + normalized = joint_target_to_global_names(group, target) + + assert normalized.name == ["left/j1", "left/j2", "left/j3"] + assert normalized.position == [1.0, 2.0, 3.0] + + +def test_joint_target_to_global_names_accepts_named_local_targets_in_group_order() -> None: + group = _make_group() + target = JointState(name=["j2", "j3", "j1"], position=[2.0, 3.0, 1.0]) + + normalized = joint_target_to_global_names(group, target) + + assert normalized.name == ["left/j1", "left/j2", "left/j3"] + assert normalized.position == [1.0, 2.0, 3.0] + + +def test_joint_target_to_global_names_rejects_mixed_global_and_local_target_names() -> None: + group = _make_group() + target = JointState(name=["left/j1", "j2", "left/j3"], position=[1.0, 2.0, 3.0]) + + with pytest.raises(ValueError, match="mixes global and local joint names"): + joint_target_to_global_names(group, target) diff --git a/dimos/manipulation/planning/test_planning_groups.py b/dimos/manipulation/planning/test_planning_groups.py new file mode 100644 index 0000000000..dd95d1b961 --- /dev/null +++ b/dimos/manipulation/planning/test_planning_groups.py @@ -0,0 +1,224 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for planning group discovery.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from dimos.manipulation.planning.groups.discovery import ( + FALLBACK_PLANNING_GROUP_NAME, + PlanningGroupDiscoveryError, + discover_planning_group_definitions, + generate_fallback_planning_group, + parse_srdf_planning_groups, +) +from dimos.robot.model_parser import JointDescription, ModelDescription + + +def _serial_model(*joint_types: str) -> ModelDescription: + joints = [ + JointDescription( + name=f"joint{i + 1}", + type=joint_type, + parent_link=f"link{i}", + child_link=f"link{i + 1}", + ) + for i, joint_type in enumerate(joint_types) + ] + return ModelDescription( + joints=joints, + root_link="link0", + links=[f"link{i}" for i in range(len(joint_types) + 1)], + ) + + +def _branching_model() -> ModelDescription: + return ModelDescription( + joints=[ + JointDescription( + name="left_joint", + type="revolute", + parent_link="base", + child_link="left_link", + ), + JointDescription( + name="right_joint", + type="revolute", + parent_link="base", + child_link="right_link", + ), + ], + root_link="base", + links=["base", "left_link", "right_link"], + ) + + +def _write_srdf(tmp_path: Path, body: str) -> Path: + srdf_path = tmp_path / "robot.srdf" + srdf_path.write_text(f"{body}") + return srdf_path + + +def test_parse_srdf_chain_group(tmp_path: Path) -> None: + model = _serial_model("revolute", "revolute", "revolute") + srdf_path = _write_srdf( + tmp_path, + "", + ) + + groups = parse_srdf_planning_groups( + srdf_path, + model=model, + controllable_joint_names=["joint1", "joint2", "joint3"], + ) + + assert len(groups) == 1 + assert groups[0].name == "arm" + assert groups[0].joint_names == ("joint1", "joint2", "joint3") + assert groups[0].base_link == "link0" + assert groups[0].tip_link == "link3" + assert groups[0].source == "srdf" + + +def test_parse_srdf_ordered_joint_list_group(tmp_path: Path) -> None: + model = _serial_model("revolute", "prismatic", "revolute") + srdf_path = _write_srdf( + tmp_path, + """ + + + + + + """, + ) + + groups = parse_srdf_planning_groups( + srdf_path, + model=model, + controllable_joint_names=["joint1", "joint2", "joint3"], + ) + + assert len(groups) == 1 + assert groups[0].joint_names == ("joint1", "joint2", "joint3") + assert groups[0].base_link == "link0" + assert groups[0].tip_link == "link3" + + +def test_parse_srdf_skips_unsupported_groups_and_ignores_end_effector( + tmp_path: Path, +) -> None: + model = _serial_model("revolute", "revolute") + srdf_path = _write_srdf( + tmp_path, + """ + + + + + """, + ) + + with pytest.warns(UserWarning) as warnings: + groups = parse_srdf_planning_groups( + srdf_path, + model=model, + controllable_joint_names=["joint1", "joint2"], + ) + + assert [group.name for group in groups] == ["arm"] + warning_text = "\n".join(str(warning.message) for warning in warnings) + assert "Skipping unsupported SRDF planning group links" in warning_text + assert "Skipping unsupported SRDF planning group nested" in warning_text + + +def test_fallback_generates_manipulator_for_unambiguous_serial_chain() -> None: + model = _serial_model("revolute", "prismatic", "revolute") + + group = generate_fallback_planning_group( + model=model, + controllable_joint_names=["joint2", "joint1", "joint3"], + ) + + assert group.name == FALLBACK_PLANNING_GROUP_NAME + assert group.joint_names == ("joint1", "joint2", "joint3") + assert group.base_link == "link0" + assert group.tip_link == "link3" + assert group.source == "fallback" + + +def test_fallback_strips_terminal_prismatic_joints() -> None: + model = _serial_model("revolute", "revolute", "prismatic") + + group = generate_fallback_planning_group( + model=model, + controllable_joint_names=["joint1", "joint2", "joint3"], + ) + + assert group.joint_names == ("joint1", "joint2") + assert group.tip_link == "link2" + + +def test_fallback_rejects_branching_model() -> None: + with pytest.raises(PlanningGroupDiscoveryError, match="branch"): + generate_fallback_planning_group( + model=_branching_model(), + controllable_joint_names=["left_joint", "right_joint"], + ) + + +def test_discovery_prefers_explicit_srdf_over_fallback(tmp_path: Path) -> None: + model = _serial_model("revolute", "revolute") + model_path = tmp_path / "robot.urdf" + model_path.write_text("") + srdf_path = _write_srdf( + tmp_path, + "", + ) + + groups = discover_planning_group_definitions( + robot_name="robot", + model_path=model_path, + model=model, + controllable_joint_names=["joint1", "joint2"], + srdf_path=srdf_path, + ) + + assert [group.name for group in groups] == ["srdf_arm"] + + +def test_discovery_auto_discovers_srdf_with_warning( + tmp_path: Path, +) -> None: + model = _serial_model("revolute") + model_path = tmp_path / "robot.urdf" + model_path.write_text("") + _write_srdf( + tmp_path, + "", + ) + + with pytest.warns(UserWarning, match="Auto-discovered SRDF"): + groups = discover_planning_group_definitions( + robot_name="robot", + model_path=model_path, + model=model, + controllable_joint_names=["joint1"], + ) + + assert [group.name for group in groups] == ["auto_arm"] diff --git a/dimos/manipulation/planning/utils/mesh_utils.py b/dimos/manipulation/planning/utils/mesh_utils.py index 988a4e5e8e..7ab3941c7a 100644 --- a/dimos/manipulation/planning/utils/mesh_utils.py +++ b/dimos/manipulation/planning/utils/mesh_utils.py @@ -37,6 +37,7 @@ import shutil import tempfile from typing import TYPE_CHECKING +import xml.etree.ElementTree as ET from dimos.utils.logging_config import setup_logger @@ -55,6 +56,7 @@ def prepare_urdf_for_drake( package_paths: dict[str, Path] | None = None, xacro_args: dict[str, str] | None = None, convert_meshes: bool = False, + strip_world_joint_child_link: str | None = None, ) -> str: """Prepare a URDF/xacro file for use with Drake. @@ -68,6 +70,9 @@ def prepare_urdf_for_drake( package_paths: Dict mapping package names to filesystem paths xacro_args: Arguments to pass to xacro processor convert_meshes: Convert DAE/STL meshes to OBJ for Drake compatibility + strip_world_joint_child_link: If set, remove a fixed URDF joint from + world to this child link so callers can apply instance placement via + RobotModelConfig.base_pose instead of model-authored placement. Returns: Path to the prepared URDF file (may be cached) @@ -77,7 +82,9 @@ def prepare_urdf_for_drake( xacro_args = xacro_args or {} # Generate cache key - cache_key = _generate_cache_key(urdf_path, package_paths, xacro_args, convert_meshes) + cache_key = _generate_cache_key( + urdf_path, package_paths, xacro_args, convert_meshes, strip_world_joint_child_link + ) cache_path = _CACHE_DIR / cache_key / urdf_path.stem cache_path.mkdir(parents=True, exist_ok=True) cached_urdf = cache_path / f"{urdf_path.stem}.urdf" @@ -96,6 +103,9 @@ def prepare_urdf_for_drake( # Strip transmission blocks (Drake doesn't need them, and they can cause issues) urdf_content = _strip_transmission_blocks(urdf_content) + if strip_world_joint_child_link is not None: + urdf_content = _strip_fixed_world_joint(urdf_content, strip_world_joint_child_link) + # Resolve package:// URIs urdf_content = _resolve_package_uris(urdf_content, package_paths, cache_path) @@ -115,6 +125,7 @@ def _generate_cache_key( package_paths: dict[str, Path], xacro_args: dict[str, str], convert_meshes: bool, + strip_world_joint_child_link: str | None, ) -> str: """Generate a cache key for the URDF configuration. @@ -125,9 +136,12 @@ def _generate_cache_key( # Version number to invalidate cache when processing logic changes # Increment this when adding new processing steps (e.g., stripping transmission blocks) - processing_version = "v2" + processing_version = "v3" - key_data = f"{processing_version}:{urdf_path}:{mtime}:{sorted(package_paths.items())}:{sorted(xacro_args.items())}:{convert_meshes}" + key_data = ( + f"{processing_version}:{urdf_path}:{mtime}:{sorted(package_paths.items())}:" + f"{sorted(xacro_args.items())}:{convert_meshes}:{strip_world_joint_child_link}" + ) return hashlib.md5(key_data.encode()).hexdigest()[:16] @@ -175,6 +189,52 @@ def _strip_transmission_blocks(urdf_content: str) -> str: return result +def _strip_fixed_world_joint(urdf_content: str, child_link: str) -> str: + """Remove a fixed world-to-base joint so base_pose can own placement. + + ``RobotModelConfig.base_pose`` is the canonical planning-world placement. + Some URDF/xacro models also include a fixed ``world -> base`` joint; if Drake + loads that joint and the caller applies ``base_pose``, placement can be + double-applied or constrained by a model-authored weld. Strip only the fixed + world joint to the configured child link and then remove an unreferenced + ``world`` link. + """ + try: + root = ET.fromstring(urdf_content) + except ET.ParseError: + logger.warning("Could not parse URDF while stripping world joint", exc_info=True) + return urdf_content + + removed = False + for joint in list(root.findall("joint")): + if joint.attrib.get("type") != "fixed": + continue + parent = joint.find("parent") + child = joint.find("child") + if parent is None or child is None: + continue + if parent.attrib.get("link") == "world" and child.attrib.get("link") == child_link: + root.remove(joint) + removed = True + + if not removed: + return urdf_content + + referenced_links = set() + for joint in root.findall("joint"): + parent = joint.find("parent") + child = joint.find("child") + if parent is not None: + referenced_links.add(parent.attrib.get("link")) + if child is not None: + referenced_links.add(child.attrib.get("link")) + for link in list(root.findall("link")): + if link.attrib.get("name") == "world" and "world" not in referenced_links: + root.remove(link) + + return ET.tostring(root, encoding="unicode") + + def _resolve_package_uris( urdf_content: str, package_paths: dict[str, Path], diff --git a/dimos/manipulation/planning/utils/test_mesh_utils.py b/dimos/manipulation/planning/utils/test_mesh_utils.py new file mode 100644 index 0000000000..fa3004c7ce --- /dev/null +++ b/dimos/manipulation/planning/utils/test_mesh_utils.py @@ -0,0 +1,47 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import xml.etree.ElementTree as ET + +from dimos.manipulation.planning.utils.mesh_utils import prepare_urdf_for_drake + + +def test_prepare_urdf_can_strip_fixed_world_joint_for_base_pose_placement( + tmp_path, +) -> None: + urdf_path = tmp_path / "robot.urdf" + urdf_path.write_text( + """ + + + + + + + + + +""".strip() + ) + + prepared_path = prepare_urdf_for_drake( + urdf_path, + strip_world_joint_child_link="link_base", + ) + + root = ET.parse(prepared_path).getroot() + assert [joint.attrib["name"] for joint in root.findall("joint")] == [] + assert [link.attrib["name"] for link in root.findall("link")] == ["link_base"] diff --git a/dimos/manipulation/planning/world/drake_world.py b/dimos/manipulation/planning/world/drake_world.py index ca426ba340..b639674c4c 100644 --- a/dimos/manipulation/planning/world/drake_world.py +++ b/dimos/manipulation/planning/world/drake_world.py @@ -25,11 +25,17 @@ import numpy as np +from dimos.manipulation.planning.groups.identifiers import ( + assert_local_joint_names, + make_global_joint_name, +) +from dimos.manipulation.planning.groups.registry import PlanningGroupRegistry from dimos.manipulation.planning.spec.config import RobotModelConfig from dimos.manipulation.planning.spec.enums import ObstacleType from dimos.manipulation.planning.spec.models import ( - JointPath, + GeneratedPlan, Obstacle, + PlanningGroupID, PlanningSceneInfo, WorldRobotID, ) @@ -38,7 +44,7 @@ from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Sequence from numpy.typing import NDArray @@ -84,6 +90,22 @@ logger = setup_logger() +def _pose_stamped_from_drake_transform(transform: RigidTransform) -> PoseStamped: + """Convert a Drake RigidTransform-like object to a world-frame pose.""" + position = transform.translation() + quaternion = transform.rotation().ToQuaternion() + return PoseStamped( + frame_id="world", + position=[float(position[0]), float(position[1]), float(position[2])], + orientation=[ + float(quaternion.x()), + float(quaternion.y()), + float(quaternion.z()), + float(quaternion.w()), + ], + ) + + @dataclass class _RobotData: """Internal data for tracking a robot in the world.""" @@ -92,8 +114,8 @@ class _RobotData: config: RobotModelConfig model_instance: Any # ModelInstanceIndex joint_indices: list[int] # Indices into plant's position vector - ee_frame: Any # BodyFrame for end-effector - base_frame: Any # BodyFrame for base + ee_frame: Any | None # Compatibility robot-scoped end-effector frame + base_frame: Any # Compatibility robot-scoped base frame preview_model_instance: Any = None # ModelInstanceIndex for preview (yellow) robot preview_joint_indices: list[int] = field(default_factory=list) @@ -184,6 +206,7 @@ def __init__(self, time_step: float = 0.0, enable_viz: bool = False) -> None: # Tracking data self._robots: dict[WorldRobotID, _RobotData] = {} + self._planning_groups = PlanningGroupRegistry() self._obstacles: dict[str, _ObstacleData] = {} self._robot_counter = 0 self._obstacle_counter = 0 @@ -202,6 +225,9 @@ def add_robot(self, config: RobotModelConfig) -> WorldRobotID: """Add a robot to the world. Returns robot_id. Same model_path + base_pose reuses the model instance (e.g. two arms in one URDF). + base_pose/base_link/end_effector_link remain compatibility fields for + placement and robot-scoped helpers; group-aware planning should use + planning group base/tip links. """ if self._finalized: raise RuntimeError("Cannot add robot after world is finalized") @@ -215,9 +241,7 @@ def add_robot(self, config: RobotModelConfig) -> WorldRobotID: self._validate_joints(config, model_instance) - ee_frame = self._plant.GetBodyByName( - config.end_effector_link, model_instance - ).body_frame() + ee_frame = self._legacy_ee_frame(config, model_instance) base_frame = self._plant.GetBodyByName(config.base_link, model_instance).body_frame() # Preview (yellow ghost) — always a separate instance per robot @@ -235,10 +259,17 @@ def add_robot(self, config: RobotModelConfig) -> WorldRobotID: base_frame=base_frame, preview_model_instance=preview_model_instance, ) + self._planning_groups.add_robot(config) logger.info(f"Added robot '{robot_id}' ({config.name})") return robot_id + def _legacy_ee_frame(self, config: RobotModelConfig, model_instance: Any) -> Any | None: + """Resolve compatibility robot-scoped EE frame, if available.""" + if config.end_effector_link is None: + return None + return self._plant.GetBodyByName(config.end_effector_link, model_instance).body_frame() + def _load_model(self, config: RobotModelConfig) -> Any: """Load robot model (URDF/xacro/MJCF) and return model instance.""" original_path = config.model_path.resolve() @@ -255,6 +286,9 @@ def _load_model(self, config: RobotModelConfig) -> Any: package_paths=config.package_paths, xacro_args=config.xacro_args, convert_meshes=config.auto_convert_meshes, + strip_world_joint_child_link=config.base_link + if config.strip_model_world_joint + else None, ) prepared_path_obj = Path(prepared_path) @@ -317,6 +351,18 @@ def get_robot_config(self, robot_id: WorldRobotID) -> RobotModelConfig: raise KeyError(f"Robot '{robot_id}' not found") return self._robots[robot_id].config + def _get_robot_data_by_name(self, robot_name: str) -> _RobotData: + matches = [ + robot_data + for robot_data in self._robots.values() + if robot_data.config.name == robot_name + ] + if not matches: + raise KeyError(f"Robot '{robot_name}' not found for planning group resolution") + if len(matches) > 1: + raise ValueError(f"Robot name '{robot_name}' is not unique in planning world") + return matches[0] + def get_joint_limits( self, robot_id: WorldRobotID ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: @@ -686,7 +732,7 @@ def finalize(self) -> None: self.publish_visualization() # Hide all preview robots initially for robot_id in self._robots: - self.hide_preview(robot_id) + self._hide_preview_robot(robot_id) @property def is_finalized(self) -> bool: @@ -768,8 +814,7 @@ def sync_from_joint_state(self, robot_id: WorldRobotID, joint_state: JointState) if not self._finalized or self._plant_context is None: return # Silently ignore before finalization - # Extract positions as numpy array for internal use - positions = np.array(joint_state.position, dtype=np.float64) + positions = self._positions_for_robot_state(robot_id, joint_state) with self._lock: self._set_positions_internal(self._plant_context, robot_id, positions) @@ -787,8 +832,7 @@ def set_joint_state( if not self._finalized: raise RuntimeError("World must be finalized first") - # Extract positions as numpy array for internal use - positions = np.array(joint_state.position, dtype=np.float64) + positions = self._positions_for_robot_state(robot_id, joint_state) # Get plant context from diagram context plant_ctx = self._diagram.GetMutableSubsystemContext(self._plant, ctx) @@ -809,6 +853,37 @@ def _set_positions_internal( self._plant.SetPositions(plant_ctx, full_positions) + def _positions_for_robot_state( + self, robot_id: WorldRobotID, joint_state: JointState + ) -> NDArray[np.float64]: + if robot_id not in self._robots: + raise KeyError(f"Robot '{robot_id}' not found") + robot_data = self._robots[robot_id] + local_joint_names = robot_data.config.joint_names + + if not joint_state.name: + if len(joint_state.position) != len(local_joint_names): + raise ValueError( + f"JointState position length {len(joint_state.position)} does not match " + f"robot {robot_data.config.name} joint count {len(local_joint_names)}" + ) + return np.array(joint_state.position, dtype=np.float64) + + state_by_local_name: dict[str, float] = {} + if len(joint_state.name) != len(joint_state.position): + raise ValueError("JointState name and position lengths must match") + + assert_local_joint_names(joint_state.name) + for name, position in zip(joint_state.name, joint_state.position, strict=False): + state_by_local_name[name] = float(position) + + missing = [name for name in local_joint_names if name not in state_by_local_name] + if missing: + raise ValueError( + f"JointState for robot {robot_data.config.name} is missing joints: {missing}" + ) + return np.array([state_by_local_name[name] for name in local_joint_names], dtype=np.float64) + def get_joint_state(self, ctx: Context, robot_id: WorldRobotID) -> JointState: """Get robot joint state from given context.""" if not self._finalized: @@ -822,7 +897,10 @@ def get_joint_state(self, ctx: Context, robot_id: WorldRobotID) -> JointState: full_positions = self._plant.GetPositions(plant_ctx) positions = [float(full_positions[idx]) for idx in robot_data.joint_indices] - return JointState(name=robot_data.config.joint_names, position=positions) + return JointState( + name=list(robot_data.config.joint_names), + position=positions, + ) # Collision Checking (context-based) @@ -903,8 +981,28 @@ def check_edge_collision_free( # Forward Kinematics (context-based) + def get_group_ee_pose(self, ctx: Context, group_id: PlanningGroupID) -> PoseStamped: + """Get pose for a planning group's target frame.""" + if not self._finalized: + raise RuntimeError("World must be finalized first") + + group = self._planning_groups.get(group_id) + if group.tip_link is None: + raise ValueError(f"Planning group '{group_id}' has no pose target frame") + + robot_data = self._get_robot_data_by_name(group.robot_name) + plant_ctx = self._diagram.GetSubsystemContext(self._plant, ctx) + + try: + tip_body = self._plant.GetBodyByName(group.tip_link, robot_data.model_instance) + except RuntimeError: + raise KeyError(f"Planning group '{group_id}' target link '{group.tip_link}' not found") + + tip_pose = self._plant.EvalBodyPoseInWorld(plant_ctx, tip_body) + return _pose_stamped_from_drake_transform(tip_pose) + def get_ee_pose(self, ctx: Context, robot_id: WorldRobotID) -> PoseStamped: - """Get end-effector pose.""" + """Get pose for a robot's compatibility end-effector frame.""" if not self._finalized: raise RuntimeError("World must be finalized first") @@ -912,20 +1010,16 @@ def get_ee_pose(self, ctx: Context, robot_id: WorldRobotID) -> PoseStamped: raise KeyError(f"Robot '{robot_id}' not found") robot_data = self._robots[robot_id] + if robot_data.ee_frame is None: + raise ValueError( + f"Robot '{robot_id}' has no robot-scoped end-effector link; " + "use get_group_ee_pose() with an explicit planning group ID" + ) plant_ctx = self._diagram.GetSubsystemContext(self._plant, ctx) ee_body = robot_data.ee_frame.body() X_WE = self._plant.EvalBodyPoseInWorld(plant_ctx, ee_body) - - # Extract position and quaternion from Drake transform - pos = X_WE.translation() - quat = X_WE.rotation().ToQuaternion() # Drake returns [w, x, y, z] - - return PoseStamped( - frame_id="world", - position=[float(pos[0]), float(pos[1]), float(pos[2])], - orientation=[float(quat.x()), float(quat.y()), float(quat.z()), float(quat.w())], - ) + return _pose_stamped_from_drake_transform(X_WE) def get_link_pose( self, ctx: Context, robot_id: WorldRobotID, link_name: str @@ -951,7 +1045,7 @@ def get_link_pose( return result # type: ignore[no-any-return] def get_jacobian(self, ctx: Context, robot_id: WorldRobotID) -> NDArray[np.float64]: - """Get geometric Jacobian (6 x n_joints). + """Get robot-scoped geometric Jacobian for the compatibility EE frame. Rows: [vx, vy, vz, wx, wy, wz] (linear, then angular) """ @@ -986,6 +1080,52 @@ def get_jacobian(self, ctx: Context, robot_id: WorldRobotID) -> NDArray[np.float return J_reordered + def get_group_jacobian(self, ctx: Context, group_id: PlanningGroupID) -> NDArray[np.float64]: + """Get geometric Jacobian for a planning group's target frame. + + Rows: [vx, vy, vz, wx, wy, wz] (linear, then angular). Columns follow + the resolved planning group's joint order. + """ + if not self._finalized: + raise RuntimeError("World must be finalized first") + + group = self._planning_groups.get(group_id) + if group.tip_link is None: + raise ValueError(f"Planning group '{group_id}' has no pose target frame") + + robot_data = self._get_robot_data_by_name(group.robot_name) + plant_ctx = self._diagram.GetSubsystemContext(self._plant, ctx) + + try: + tip_body = self._plant.GetBodyByName(group.tip_link, robot_data.model_instance) + except RuntimeError: + raise KeyError(f"Planning group '{group_id}' target link '{group.tip_link}' not found") + + jacobian_full = self._plant.CalcJacobianSpatialVelocity( + plant_ctx, + JacobianWrtVariable.kQDot, + tip_body.body_frame(), + np.array([0.0, 0.0, 0.0]), # type: ignore[arg-type] + self._plant.world_frame(), + self._plant.world_frame(), + ) + + joint_indices_by_name = dict( + zip(robot_data.config.joint_names, robot_data.joint_indices, strict=False) + ) + jacobian_group = np.zeros((6, len(group.local_joint_names))) + for index, local_joint_name in enumerate(group.local_joint_names): + try: + joint_index = joint_indices_by_name[local_joint_name] + except KeyError: + raise ValueError( + f"Planning group '{group_id}' references non-controllable joint " + f"'{local_joint_name}'" + ) + jacobian_group[:, index] = jacobian_full[:, joint_index] + + return np.vstack([jacobian_group[3:6, :], jacobian_group[0:3, :]]) + # Visualization def initialize_scene(self, scene: PlanningSceneInfo) -> None: @@ -1021,8 +1161,19 @@ def _set_preview_positions( full_positions[idx] = positions[i] self._plant.SetPositions(plant_ctx, full_positions) - def show_preview(self, robot_id: WorldRobotID) -> None: - """Show the preview (yellow ghost) robot in Meshcat.""" + def _preview_robot_ids_for_groups( + self, group_ids: Sequence[PlanningGroupID] + ) -> list[WorldRobotID]: + """Resolve planning groups to stable preview robot IDs.""" + robot_ids: list[WorldRobotID] = [] + for group in self._planning_groups.select(tuple(group_ids)).groups: + robot_id = self._get_robot_data_by_name(group.robot_name).robot_id + if robot_id not in robot_ids: + robot_ids.append(robot_id) + return robot_ids + + def _show_preview_robot(self, robot_id: WorldRobotID) -> None: + """Show one preview (yellow ghost) robot in Meshcat.""" if self._meshcat is None: return robot_data = self._robots.get(robot_id) @@ -1031,8 +1182,13 @@ def show_preview(self, robot_id: WorldRobotID) -> None: model_name = self._plant.GetModelInstanceName(robot_data.preview_model_instance) self._meshcat.SetProperty(f"visualizer/{model_name}", "visible", True) - def hide_preview(self, robot_id: WorldRobotID) -> None: - """Hide the preview (yellow ghost) robot in Meshcat.""" + def show_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: + """Show preview robots affected by planning groups.""" + for robot_id in self._preview_robot_ids_for_groups(group_ids): + self._show_preview_robot(robot_id) + + def _hide_preview_robot(self, robot_id: WorldRobotID) -> None: + """Hide one preview (yellow ghost) robot in Meshcat.""" if self._meshcat is None: return robot_data = self._robots.get(robot_id) @@ -1041,32 +1197,63 @@ def hide_preview(self, robot_id: WorldRobotID) -> None: model_name = self._plant.GetModelInstanceName(robot_data.preview_model_instance) self._meshcat.SetProperty(f"visualizer/{model_name}", "visible", False) - def animate_path( + def hide_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: + """Hide preview robots affected by planning groups.""" + for robot_id in self._preview_robot_ids_for_groups(group_ids): + self._hide_preview_robot(robot_id) + + def _preview_positions_for_waypoint( self, robot_id: WorldRobotID, - path: JointPath, - duration: float = 3.0, - ) -> None: - """Animate a path using the preview (yellow ghost) robot. + selected_positions_by_name: dict[str, float], + current_positions: NDArray[np.float64], + ) -> NDArray[np.float64]: + """Build full local preview positions for one robot from selected globals.""" + robot_data = self._robots[robot_id] + positions = current_positions.copy() + for index, local_name in enumerate(robot_data.config.joint_names): + global_name = make_global_joint_name(robot_data.config.name, local_name) + if global_name in selected_positions_by_name: + positions[index] = selected_positions_by_name[global_name] + return positions + + def animate_plan(self, plan: GeneratedPlan, duration: float = 3.0) -> None: + """Animate a generated plan using preview (yellow ghost) robots. The preview stays visible after animation completes. """ - if self._meshcat is None or len(path) < 2: + if self._meshcat is None or len(plan.path) < 2: return - robot_data = self._robots.get(robot_id) - if robot_data is None or robot_data.preview_model_instance is None: + robot_ids = [ + robot_id + for robot_id in self._preview_robot_ids_for_groups(plan.group_ids) + if self._robots[robot_id].preview_model_instance is not None + ] + if not robot_ids: return import time - self.show_preview(robot_id) - dt = duration / (len(path) - 1) - for joint_state in path: - positions = np.array(joint_state.position, dtype=np.float64) + self.show_preview(plan.group_ids) + dt = duration / (len(plan.path) - 1) + for joint_state in plan.path: + selected_positions_by_name = dict( + zip(joint_state.name, joint_state.position, strict=True) + ) with self._lock: assert self._plant_context is not None - self._set_preview_positions(self._plant_context, robot_id, positions) + for robot_id in robot_ids: + robot_data = self._robots[robot_id] + current_positions = self._plant.GetPositions( + self._plant_context, robot_data.model_instance + ) + positions = self._preview_positions_for_waypoint( + robot_id, + selected_positions_by_name, + np.array(current_positions, dtype=np.float64), + ) + self._set_preview_positions(self._plant_context, robot_id, positions) self.publish_visualization() time.sleep(dt) diff --git a/dimos/manipulation/planning/world/test_drake_world_planning_groups.py b/dimos/manipulation/planning/world/test_drake_world_planning_groups.py new file mode 100644 index 0000000000..49d3a3211e --- /dev/null +++ b/dimos/manipulation/planning/world/test_drake_world_planning_groups.py @@ -0,0 +1,210 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DrakeWorld planning group name/world resolution.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +from dimos.manipulation.planning.groups.models import PlanningGroupDefinition +from dimos.manipulation.planning.groups.registry import PlanningGroupRegistry +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.manipulation.planning.world.drake_world import DrakeWorld, _RobotData +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.JointState import JointState + + +def _pose() -> PoseStamped: + return PoseStamped(position=[0, 0, 0], orientation=[0, 0, 0, 1]) + + +def _config( + name: str, + joint_names: list[str], + groups: list[PlanningGroupDefinition], +) -> RobotModelConfig: + return RobotModelConfig( + name=name, + model_path=Path("robot.urdf"), + base_pose=_pose(), + joint_names=joint_names, + end_effector_link="tool0", + base_link="base_link", + planning_groups=groups, + ) + + +def _world(*configs: RobotModelConfig) -> DrakeWorld: + world = DrakeWorld.__new__(DrakeWorld) + world._robots = { + f"robot_{index}": _RobotData( + robot_id=f"robot_{index}", + config=config, + model_instance=None, + joint_indices=[], + ee_frame=None, + base_frame=None, + ) + for index, config in enumerate(configs, start=1) + } + world._planning_groups = PlanningGroupRegistry(configs) + return world + + +def _arm_group(*joint_names: str) -> PlanningGroupDefinition: + return PlanningGroupDefinition( + name="arm", + joint_names=joint_names, + base_link="base_link", + tip_link="tool0", + source="srdf", + ) + + +def test_planning_group_registry_returns_stable_ids_and_global_joint_names() -> None: + config = _config("left", ["joint1", "joint2"], [_arm_group("joint1", "joint2")]) + registry = PlanningGroupRegistry([config]) + + groups = registry.list() + + assert len(groups) == 1 + assert groups[0].id == "left/arm" + assert groups[0].robot_name == "left" + assert groups[0].group_name == "arm" + assert groups[0].joint_names == ("left/joint1", "left/joint2") + assert groups[0].local_joint_names == ("joint1", "joint2") + + +def test_robot_model_config_allows_planning_groups_without_robot_scoped_ee() -> None: + config = RobotModelConfig( + name="left", + model_path=Path("robot.urdf"), + joint_names=["joint1"], + planning_groups=[_arm_group("joint1")], + ) + registry = PlanningGroupRegistry([config]) + + groups = registry.list() + + assert config.end_effector_link is None + assert groups[0].id == "left/arm" + + +def test_duplicate_local_joint_names_across_robots_are_disambiguated() -> None: + registry = PlanningGroupRegistry( + [ + _config("left", ["joint1"], [_arm_group("joint1")]), + _config("right", ["joint1"], [_arm_group("joint1")]), + ] + ) + + groups = registry.list() + + assert [group.id for group in groups] == ["left/arm", "right/arm"] + assert [group.joint_names for group in groups] == [("left/joint1",), ("right/joint1",)] + + +def test_planning_group_selection_returns_ordered_global_joint_names() -> None: + registry = PlanningGroupRegistry( + [ + _config("left", ["joint1", "joint2"], [_arm_group("joint1", "joint2")]), + _config("right", ["joint1", "joint2"], [_arm_group("joint2")]), + ] + ) + + selection = registry.select(("left/arm", "right/arm")) + + assert list(selection.group_ids) == ["left/arm", "right/arm"] + assert list(selection.robot_names) == ["left", "right"] + assert [group.joint_names for group in selection.groups] == [ + ("left/joint1", "left/joint2"), + ("right/joint2",), + ] + assert list(selection.joint_names) == ["left/joint1", "left/joint2", "right/joint2"] + assert [group.local_joint_names for group in selection.groups] == [ + ("joint1", "joint2"), + ("joint2",), + ] + + +def test_planning_group_registry_unknown_group_raises_key_error() -> None: + registry = PlanningGroupRegistry([_config("left", ["joint1"], [_arm_group("joint1")])]) + + with pytest.raises(KeyError, match="Unknown planning group ID: left/gripper"): + registry.select(("left/gripper",)) + + +def test_planning_group_selection_overlapping_same_robot_groups_raise_value_error() -> None: + registry = PlanningGroupRegistry( + [ + _config( + "left", + ["joint1", "joint2"], + [ + _arm_group("joint1", "joint2"), + PlanningGroupDefinition( + name="wrist", + joint_names=("joint2",), + base_link="link1", + tip_link="tool0", + ), + ], + ) + ] + ) + + with pytest.raises(ValueError, match="overlap.*left/joint2"): + registry.select(("left/arm", "left/wrist")) + + +def test_positions_for_robot_state_accepts_local_joint_names_in_config_order() -> None: + world = _world(_config("left", ["joint1", "joint2"], [_arm_group("joint1", "joint2")])) + joint_state = JointState({"name": ["joint2", "joint1"], "position": [2.0, 1.0]}) + + positions = world._positions_for_robot_state("robot_1", joint_state) + + np.testing.assert_allclose(positions, np.array([1.0, 2.0])) + + +def test_positions_for_robot_state_rejects_global_joint_names() -> None: + world = _world(_config("left", ["joint1", "joint2"], [_arm_group("joint1", "joint2")])) + joint_state = JointState({"name": ["left/joint2", "left/joint1"], "position": [2.0, 1.0]}) + + with pytest.raises(ValueError, match="Invalid local joint name: 'left/joint2'"): + world._positions_for_robot_state("robot_1", joint_state) + + +def test_group_pose_rejects_group_without_target_frame() -> None: + world = _world( + _config( + "left", + ["joint1"], + [ + PlanningGroupDefinition( + name="waist", + joint_names=("joint1",), + base_link="base_link", + tip_link=None, + ) + ], + ) + ) + world._finalized = True + + with pytest.raises(ValueError, match="left/waist.*no pose target frame"): + world.get_group_ee_pose(None, "left/waist") diff --git a/dimos/manipulation/test_manipulation_module.py b/dimos/manipulation/test_manipulation_module.py index f8e914e3b3..f3ff29ace0 100644 --- a/dimos/manipulation/test_manipulation_module.py +++ b/dimos/manipulation/test_manipulation_module.py @@ -69,15 +69,6 @@ def _get_xarm7_config() -> RobotModelConfig: auto_convert_meshes=True, max_velocity=1.0, max_acceleration=2.0, - joint_name_mapping={ - "arm/joint1": "joint1", - "arm/joint2": "joint2", - "arm/joint3": "joint3", - "arm/joint4": "joint4", - "arm/joint5": "joint5", - "arm/joint6": "joint6", - "arm/joint7": "joint7", - }, coordinator_task_name="traj_arm", ) @@ -92,13 +83,13 @@ def joint_state_zeros(): """Create a JointState message with zeros for XArm7.""" return JointState( name=[ - "arm/joint1", - "arm/joint2", - "arm/joint3", - "arm/joint4", - "arm/joint5", - "arm/joint6", - "arm/joint7", + "test_arm/joint1", + "test_arm/joint2", + "test_arm/joint3", + "test_arm/joint4", + "test_arm/joint5", + "test_arm/joint6", + "test_arm/joint7", ], position=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], @@ -160,11 +151,8 @@ def test_plan_to_joints(self, module, joint_state_zeros): assert success is True assert module._state == ManipulationState.COMPLETED assert module.has_planned_path() is True - - assert "test_arm" in module._planned_trajectories - traj = module._planned_trajectories["test_arm"] - assert len(traj.points) > 1 - assert traj.duration > 0 + assert module._last_plan is not None + assert len(module._last_plan.path) > 1 def test_add_and_remove_obstacle(self, module, joint_state_zeros): """Test adding and removing obstacles.""" @@ -191,7 +179,6 @@ def test_robot_info(self, module): assert len(info["joint_names"]) == 7 assert info["end_effector_link"] == "link7" assert info["coordinator_task_name"] == "traj_arm" - assert info["has_joint_name_mapping"] is True def test_ee_pose(self, module, joint_state_zeros): """Test getting end-effector pose.""" @@ -204,20 +191,21 @@ def test_ee_pose(self, module, joint_state_zeros): assert hasattr(pose, "y") assert hasattr(pose, "z") - def test_trajectory_name_translation(self, module, joint_state_zeros): - """Test that trajectory joint names are translated for coordinator.""" + def test_planned_trajectory_uses_global_joint_names(self, module, joint_state_zeros): + """Test that planned trajectory joint names are global for coordinator.""" module._on_joint_state(joint_state_zeros) success = module.plan_to_joints(JointState(position=[0.05] * 7)) assert success is True - traj = module._planned_trajectories["test_arm"] - robot_config = module._robots["test_arm"][1] + mock_client = MagicMock() + mock_client.task_invoke.return_value = True + module._coordinator_client = mock_client - translated = module._translate_trajectory_to_coordinator(traj, robot_config) + assert module.execute() is True - for name in translated.joint_names: - assert name.startswith("arm_") # Should have arm_ prefix + trajectory = mock_client.task_invoke.call_args.args[2]["trajectory"] + assert trajectory.joint_names == [f"test_arm/joint{i}" for i in range(1, 8)] @pytest.mark.skipif(not _drake_available(), reason="Drake not installed") @@ -251,8 +239,7 @@ def test_execute_with_mock_coordinator(self, module, joint_state_zeros): assert method_name == "execute" trajectory = kwargs["trajectory"] assert len(trajectory.points) > 1 - # Joint names should be translated - assert all(n.startswith("arm_") for n in trajectory.joint_names) + assert trajectory.joint_names == [f"test_arm/joint{i}" for i in range(1, 8)] def test_execute_rejected_by_coordinator(self, module, joint_state_zeros): """Test handling of coordinator rejection.""" diff --git a/dimos/manipulation/test_manipulation_unit.py b/dimos/manipulation/test_manipulation_unit.py index 4c83e9c3ad..33ebb4d267 100644 --- a/dimos/manipulation/test_manipulation_unit.py +++ b/dimos/manipulation/test_manipulation_unit.py @@ -28,11 +28,17 @@ ManipulationModuleConfig, ManipulationState, ) +from dimos.manipulation.planning.groups.models import PlanningGroup, PlanningGroupSelection from dimos.manipulation.planning.kinematics.config import PinkKinematicsConfig from dimos.manipulation.planning.monitor.world_monitor import WorldMonitor from dimos.manipulation.planning.spec.config import RobotModelConfig -from dimos.manipulation.planning.spec.enums import IKStatus -from dimos.manipulation.planning.spec.models import IKResult, PlanningSceneInfo +from dimos.manipulation.planning.spec.enums import IKStatus, PlanningStatus +from dimos.manipulation.planning.spec.models import ( + CollisionCheckResult, + GeneratedPlan, + IKResult, + PlanningSceneInfo, +) from dimos.manipulation.planning.spec.protocols import VisualizationSpec from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped @@ -60,8 +66,8 @@ def robot_config(): @pytest.fixture -def robot_config_with_mapping(): - """Create a robot config with joint name mapping (dual-arm scenario).""" +def left_robot_config(): + """Create a robot config for a scoped left arm.""" return RobotModelConfig( name="left_arm", model_path=Path("/path/to/robot.urdf"), @@ -69,11 +75,6 @@ def robot_config_with_mapping(): joint_names=["joint1", "joint2", "joint3"], end_effector_link="link_tcp", base_link="link_base", - joint_name_mapping={ - "left/joint1": "joint1", - "left/joint2": "joint2", - "left/joint3": "joint3", - }, coordinator_task_name="traj_left", ) @@ -101,12 +102,11 @@ def __init__(self) -> None: self._error_message = "" self._planning_epoch = 0 self._robots = {} - self._planned_paths = {} - self._planned_trajectories = {} self._world_monitor = None self._planner = None self._kinematics = None self._coordinator_client = None + self._last_plan = None self.config = MagicMock(planning_timeout=10.0) @@ -165,19 +165,46 @@ def test_begin_planning_state_checks(self, robot_config): module = _make_module() module._world_monitor = MagicMock() module._robots = {"test_arm": ("robot_id", robot_config, MagicMock())} - # From IDLE - OK module._state = ManipulationState.IDLE - assert module._begin_planning() == ("test_arm", "robot_id") + assert module._begin_planning() is True assert module._state == ManipulationState.PLANNING # From COMPLETED - OK module._state = ManipulationState.COMPLETED - assert module._begin_planning() == ("test_arm", "robot_id") + assert module._begin_planning() is True # From EXECUTING - Fail module._state = ManipulationState.EXECUTING - assert module._begin_planning() is None + assert module._begin_planning() is False + + def test_empty_plan_targets_do_not_fault_before_planning(self): + """Pre-flight validation errors do not corrupt the planning state machine.""" + module = _make_module() + module._world_monitor = MagicMock() + module._kinematics = MagicMock() + module._planner = MagicMock() + + assert module.plan_to_pose_targets({}) is False + assert module._state == ManipulationState.IDLE + assert module._error_message == "" + + assert module.plan_to_joint_targets({}) is False + assert module._state == ManipulationState.IDLE + assert module._error_message == "" + + def test_plan_to_pose_targets_requires_planner_before_planning(self): + """Pose-target planning requires both IK and path planner backends.""" + module = _make_module() + module._world_monitor = MagicMock() + module._kinematics = MagicMock() + module._planner = None + + pose = Pose(position=Vector3(), orientation=Quaternion()) + + assert module.plan_to_pose_targets({"test_arm/manipulator": pose}) is False + assert module._state == ManipulationState.IDLE + assert module._error_message == "" class TestRobotSelection: @@ -300,11 +327,9 @@ def test_nested_kinematics_config_parses_cli_override_shape(self) -> None: assert config.kinematics.dt == 0.02 assert config.kinematics.posture_cost == 0.0 - def test_solve_ik_rpc_calls_configured_backend(self, robot_config): - """solve_ik returns the backend IKResult without path planning.""" - module = _make_module() - module._robots = {"test_arm": ("robot_id", robot_config, MagicMock())} - module._world_monitor = MagicMock() + def test_inverse_kinematics_single_calls_configured_backend(self, robot_config): + """inverse_kinematics_single returns the backend IKResult without path planning.""" + module = _make_module_with_monitor(robot_config) module._world_monitor.world = MagicMock() current = JointState(name=robot_config.joint_names, position=[0.0, 0.0, 0.0]) module._world_monitor.get_current_joint_state.return_value = current @@ -317,44 +342,42 @@ def test_solve_ik_rpc_calls_configured_backend(self, robot_config): message="ok", ) module._kinematics = MagicMock() - module._kinematics.solve.return_value = expected + module._kinematics.solve_pose_targets.return_value = expected pose = Pose(position=Vector3(x=0.45, y=0.0, z=0.25), orientation=Quaternion()) - result = module.solve_ik(pose) + result = module.inverse_kinematics_single(pose) assert result is expected - assert module._state == ManipulationState.COMPLETED - assert module._planned_paths == {} - module._kinematics.solve.assert_called_once() - _, kwargs = module._kinematics.solve.call_args + module._kinematics.solve_pose_targets.assert_called_once() + _, kwargs = module._kinematics.solve_pose_targets.call_args assert kwargs["world"] is module._world_monitor.world - assert kwargs["robot_id"] == "robot_id" - assert kwargs["seed"] is current - assert kwargs["check_collision"] is True - assert kwargs["target_pose"].frame_id == "world" - assert kwargs["target_pose"].position.x == 0.45 - - def test_solve_ik_rpc_returns_failure_without_joint_state(self, robot_config): - """solve_ik reports a failed IKResult when no seed state is available.""" - module = _make_module() - module._robots = {"test_arm": ("robot_id", robot_config, MagicMock())} - module._world_monitor = MagicMock() + assert kwargs["seed"].name == [ + "test_arm/joint1", + "test_arm/joint2", + "test_arm/joint3", + ] + assert kwargs["seed"].position == current.position + target_group, target_pose = next(iter(kwargs["pose_targets"].items())) + assert target_group.id == "test_arm/manipulator" + assert target_pose.frame_id == "world" + assert target_pose.position.x == 0.45 + + def test_inverse_kinematics_single_returns_failure_without_joint_state(self, robot_config): + """inverse_kinematics_single reports failure when no seed state is available.""" + module = _make_module_with_monitor(robot_config) module._world_monitor.get_current_joint_state.return_value = None module._kinematics = MagicMock() pose = Pose(position=Vector3(x=0.45, y=0.0, z=0.25), orientation=Quaternion()) - result = module.solve_ik(pose) + result = module.inverse_kinematics_single(pose) assert result.status == IKStatus.NO_SOLUTION assert result.message == "No joint state" - assert module._state == ManipulationState.IDLE - module._kinematics.solve.assert_not_called() + module._kinematics.solve_pose_targets.assert_not_called() - def test_solve_ik_rpc_uses_explicit_seed(self, robot_config): - """solve_ik initializes the backend from an explicit seed when provided.""" - module = _make_module() - module._robots = {"test_arm": ("robot_id", robot_config, MagicMock())} - module._world_monitor = MagicMock() + def test_inverse_kinematics_single_uses_explicit_seed(self, robot_config): + """inverse_kinematics_single initializes the backend from an explicit seed.""" + module = _make_module_with_monitor(robot_config) module._world_monitor.world = MagicMock() module._world_monitor.get_current_joint_state.return_value = JointState( name=robot_config.joint_names, position=[0.0, 0.0, 0.0] @@ -362,36 +385,48 @@ def test_solve_ik_rpc_uses_explicit_seed(self, robot_config): explicit_seed = JointState(name=robot_config.joint_names, position=[0.2, 0.1, 0.0]) expected = IKResult(status=IKStatus.SUCCESS, joint_state=explicit_seed) module._kinematics = MagicMock() - module._kinematics.solve.return_value = expected + module._kinematics.solve_pose_targets.return_value = expected pose = Pose(position=Vector3(x=0.45, y=0.0, z=0.25), orientation=Quaternion()) - result = module.solve_ik(pose, seed=explicit_seed) + result = module.inverse_kinematics_single(pose, seed=explicit_seed) assert result is expected - _, kwargs = module._kinematics.solve.call_args + _, kwargs = module._kinematics.solve_pose_targets.call_args assert kwargs["seed"] is explicit_seed module._world_monitor.get_current_joint_state.assert_not_called() + def test_forward_kinematics_accepts_extra_global_joints_and_requires_group_joints( + self, robot_config + ): + """forward_kinematics is group-centric and ignores non-group target joints.""" + module = _make_module_with_monitor(robot_config) + group = _make_global_group("test_arm", "wrist", ["joint1"]) + module._world_monitor.planning_groups = _FakePlanningGroups([group]) + module._world_monitor.get_state_monitor.return_value = MagicMock( + is_state_stale=lambda max_age: False + ) + module._world_monitor.get_current_joint_state.return_value = JointState( + name=["joint1", "joint2", "joint3"], position=[0.0, 2.0, 3.0] + ) + pose = PoseStamped(position=Vector3(x=1.0), orientation=Quaternion()) + module._world_monitor.get_group_ee_pose.return_value = pose -class TestJointNameTranslation: - """Test trajectory joint name translation for coordinator.""" - - def test_no_mapping_returns_original(self, robot_config, simple_trajectory): - """Without mapping, trajectory is returned unchanged.""" - module = _make_module() - - result = module._translate_trajectory_to_coordinator(simple_trajectory, robot_config) - assert result is simple_trajectory # Same object + result = module.forward_kinematics( + "test_arm/wrist", + JointState(name=["test_arm/joint1", "test_arm/joint2"], position=[1.0, 9.0]), + ) - def test_mapping_translates_names(self, robot_config_with_mapping, simple_trajectory): - """With mapping, joint names are translated.""" - module = _make_module() + assert result.status == "VALID" + assert result.pose is pose + resolved_state = module._world_monitor.get_group_ee_pose.call_args.args[1] + assert resolved_state.name == ["joint1", "joint2", "joint3"] + assert resolved_state.position == [1.0, 9.0, 3.0] - result = module._translate_trajectory_to_coordinator( - simple_trajectory, robot_config_with_mapping + missing = module.forward_kinematics( + "test_arm/wrist", JointState(name=["test_arm/joint2"], position=[9.0]) ) - assert result.joint_names == ["left/joint1", "left/joint2", "left/joint3"] - assert len(result.points) == 2 # Points preserved + assert missing.status == "INVALID" + assert "missing group joints" in missing.message class TestExecute: @@ -401,13 +436,11 @@ def test_execute_requires_trajectory(self, robot_config): """Execute fails without planned trajectory.""" module = _make_module() module._robots = {"test_arm": ("id", robot_config, MagicMock())} - module._planned_trajectories = {} assert module.execute() is False def test_execute_requires_task_name(self): """Execute fails without coordinator_task_name.""" - module = _make_module() config_no_task = RobotModelConfig( name="arm", model_path=Path("/path"), @@ -415,16 +448,47 @@ def test_execute_requires_task_name(self): joint_names=["j1"], end_effector_link="ee", ) - module._robots = {"arm": ("id", config_no_task, MagicMock())} - module._planned_trajectories = {"arm": MagicMock()} + module = _make_module_with_monitor(config_no_task) + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("arm", "manipulator", ["j1"])] + ) + module._world_monitor.get_current_joint_state.return_value = JointState( + name=["j1"], position=[0.0] + ) + module._last_plan = GeneratedPlan( + group_ids=("arm/manipulator",), + path=[JointState(name=["arm/j1"], position=[1.0])], + status=PlanningStatus.SUCCESS, + ) assert module.execute() is False def test_execute_success(self, robot_config, simple_trajectory): """Successful execute calls coordinator via task_invoke.""" - module = _make_module() - module._robots = {"test_arm": ("id", robot_config, MagicMock())} - module._planned_trajectories = {"test_arm": simple_trajectory} + module = _make_module_with_monitor(robot_config) + generator = MagicMock() + generator.generate.return_value = simple_trajectory + module._robots = {"test_arm": ("id", robot_config, generator)} + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("test_arm", "manipulator", ["joint1", "joint2", "joint3"])] + ) + module._world_monitor.get_current_joint_state.return_value = JointState( + name=["joint1", "joint2", "joint3"], position=[0.0, 0.0, 0.0] + ) + module._last_plan = GeneratedPlan( + group_ids=("test_arm/manipulator",), + path=[ + JointState( + name=["test_arm/joint1", "test_arm/joint2", "test_arm/joint3"], + position=[0.0, 0.0, 0.0], + ), + JointState( + name=["test_arm/joint1", "test_arm/joint2", "test_arm/joint3"], + position=[0.5, 0.5, 0.5], + ), + ], + status=PlanningStatus.SUCCESS, + ) mock_client = MagicMock() mock_client.task_invoke.return_value = True @@ -432,15 +496,42 @@ def test_execute_success(self, robot_config, simple_trajectory): assert module.execute() is True assert module._state == ManipulationState.COMPLETED - mock_client.task_invoke.assert_called_once_with( - "traj_arm", "execute", {"trajectory": simple_trajectory} - ) + mock_client.task_invoke.assert_called_once() + assert mock_client.task_invoke.call_args.args[:2] == ("traj_arm", "execute") + trajectory = mock_client.task_invoke.call_args.args[2]["trajectory"] + assert trajectory.joint_names == [ + "test_arm/joint1", + "test_arm/joint2", + "test_arm/joint3", + ] + assert trajectory.points == simple_trajectory.points def test_execute_rejected(self, robot_config, simple_trajectory): """Rejected execution sets FAULT state.""" - module = _make_module() - module._robots = {"test_arm": ("id", robot_config, MagicMock())} - module._planned_trajectories = {"test_arm": simple_trajectory} + module = _make_module_with_monitor(robot_config) + generator = MagicMock() + generator.generate.return_value = simple_trajectory + module._robots = {"test_arm": ("id", robot_config, generator)} + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("test_arm", "manipulator", ["joint1", "joint2", "joint3"])] + ) + module._world_monitor.get_current_joint_state.return_value = JointState( + name=["joint1", "joint2", "joint3"], position=[0.0, 0.0, 0.0] + ) + module._last_plan = GeneratedPlan( + group_ids=("test_arm/manipulator",), + path=[ + JointState( + name=["test_arm/joint1", "test_arm/joint2", "test_arm/joint3"], + position=[0.0, 0.0, 0.0], + ), + JointState( + name=["test_arm/joint1", "test_arm/joint2", "test_arm/joint3"], + position=[0.5, 0.5, 0.5], + ), + ], + status=PlanningStatus.SUCCESS, + ) mock_client = MagicMock() mock_client.task_invoke.return_value = False @@ -449,21 +540,47 @@ def test_execute_rejected(self, robot_config, simple_trajectory): assert module.execute() is False assert module._state == ManipulationState.FAULT + def test_execute_times_out_when_coordinator_rpc_does_not_respond( + self, robot_config, simple_trajectory + ): + """Coordinator RPC timeout fails execution instead of hanging silently.""" + module = _make_module_with_monitor(robot_config) + module.config.coordinator_rpc_timeout = 0.01 + generator = MagicMock() + generator.generate.return_value = simple_trajectory + module._robots = {"test_arm": ("id", robot_config, generator)} + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("test_arm", "manipulator", ["joint1", "joint2", "joint3"])] + ) + module._world_monitor.get_current_joint_state.return_value = JointState( + name=["joint1", "joint2", "joint3"], position=[0.0, 0.0, 0.0] + ) + module._last_plan = GeneratedPlan( + group_ids=("test_arm/manipulator",), + path=[ + JointState( + name=["test_arm/joint1", "test_arm/joint2", "test_arm/joint3"], + position=[0.0, 0.0, 0.0], + ), + JointState( + name=["test_arm/joint1", "test_arm/joint2", "test_arm/joint3"], + position=[0.5, 0.5, 0.5], + ), + ], + status=PlanningStatus.SUCCESS, + ) + mock_client = MagicMock() + mock_client.remote_name = "ControlCoordinator" + mock_client._unsub_fns = [] + mock_client.rpc.call_sync.side_effect = TimeoutError("no response") + module._coordinator_client = mock_client -class TestRobotModelConfigMapping: - """Test RobotModelConfig joint name mapping helpers.""" - - def test_bidirectional_mapping(self, robot_config_with_mapping): - """Test URDF <-> coordinator name translation.""" - config = robot_config_with_mapping - - # Coordinator -> URDF - assert config.get_urdf_joint_name("left/joint1") == "joint1" - assert config.get_urdf_joint_name("unknown") == "unknown" + assert module.execute() is False - # URDF -> Coordinator - assert config.get_coordinator_joint_name("joint1") == "left/joint1" - assert config.get_coordinator_joint_name("unknown") == "unknown" + assert module._state == ManipulationState.FAULT + assert "timed out" in module._error_message + mock_client.rpc.call_sync.assert_called_once() + mock_client.task_invoke.assert_not_called() def _make_module_with_monitor(*configs: RobotModelConfig) -> ManipulationModule: @@ -474,6 +591,41 @@ def _make_module_with_monitor(*configs: RobotModelConfig) -> ManipulationModule: for config in configs: robot_id = f"robot_{config.name}" module._robots[config.name] = (robot_id, config, MagicMock()) + module._world_monitor.planning_groups = _FakePlanningGroups( + [ + _make_global_group(config.name, "manipulator", list(config.joint_names)) + for config in configs + ] + ) + module._world_monitor.is_state_valid.return_value = True + + def current_global_joint_state(max_age: float = 1.0) -> JointState | None: + del max_age + names: list[str] = [] + positions: list[float] = [] + for config in configs: + current = module._world_monitor.get_current_joint_state(f"robot_{config.name}") + if current is None: + return None + current_by_name = dict(zip(current.name, current.position, strict=True)) + for local_name in config.joint_names: + if local_name not in current_by_name: + return None + names.append(f"{config.name}/{local_name}") + positions.append(float(current_by_name[local_name])) + return JointState(name=names, position=positions) + + def check_collision(target_joints: JointState, max_age: float = 1.0) -> CollisionCheckResult: + del target_joints, max_age + collision_free = bool(module._world_monitor.is_state_valid.return_value) + return CollisionCheckResult( + status="VALID" if collision_free else "COLLISION", + collision_free=collision_free, + message="Target is collision-free" if collision_free else "Target is in collision", + ) + + module._world_monitor.current_global_joint_state.side_effect = current_global_joint_state + module._world_monitor.check_collision.side_effect = check_collision return module @@ -481,19 +633,84 @@ def _make_joint_state(positions: list[float], name: list[str] | None = None) -> return JointState(name=name or [f"j{i}" for i in range(len(positions))], position=positions) -def _make_path(*points: list[float]) -> list[JointState]: - return [_make_joint_state(list(point)) for point in points] +def _make_robot_config( + name: str, + joints: list[str], + task_name: str, +) -> RobotModelConfig: + return RobotModelConfig( + name=name, + model_path=Path("/path/to/robot.urdf"), + base_pose=PoseStamped(position=Vector3(), orientation=Quaternion()), + joint_names=joints, + end_effector_link="ee", + base_link="base", + coordinator_task_name=task_name, + ) -def _make_trajectory(*points: tuple[float, list[float]]) -> JointTrajectory: - joint_names = [f"j{i}" for i in range(len(points[0][1]))] if points else [] - return JointTrajectory( - joint_names=joint_names, +def _make_global_group(robot_name: str, group_name: str, joints: list[str]) -> PlanningGroup: + return PlanningGroup( + id=f"{robot_name}/{group_name}", + robot_name=robot_name, + group_name=group_name, + joint_names=tuple(f"{robot_name}/{joint}" for joint in joints), + local_joint_names=tuple(joints), + base_link="base", + tip_link="ee", + ) + + +class _FakePlanningGroups: + def __init__(self, groups: list[PlanningGroup]) -> None: + self._groups = {group.id: group for group in groups} + + def get(self, group_id: str) -> PlanningGroup: + return self._groups[group_id] + + def select(self, group_ids: tuple[str, ...]) -> PlanningGroupSelection: + return PlanningGroupSelection.from_groups( + tuple(self._groups[group_id] for group_id in group_ids) + ) + + def groups_for_robot(self, robot_name: str) -> tuple[PlanningGroup, ...]: + return tuple(group for group in self._groups.values() if group.robot_name == robot_name) + + def default_group_id_for_robot(self, robot_name: str) -> str | None: + group_id = f"{robot_name}/manipulator" + return group_id if group_id in self._groups else None + + def primary_pose_group_id_for_robot(self, robot_name: str) -> str | None: + for group in self.groups_for_robot(robot_name): + if group.has_pose_target: + return group.id + return None + + +def _make_generated_plan(group_ids: tuple[str, ...], *points: list[float]) -> GeneratedPlan: + return GeneratedPlan( + group_ids=group_ids, + path=[ + JointState( + name=["left/j1", "left/j2", "right/j1"], + position=list(point), + ) + for point in points + ], + status=PlanningStatus.SUCCESS, + ) + + +def _trajectory_generator() -> MagicMock: + generator = MagicMock() + generator.generate.side_effect = lambda positions: JointTrajectory( + joint_names=[], points=[ - TrajectoryPoint(time_from_start=time_from_start, positions=positions) - for time_from_start, positions in points + TrajectoryPoint(time_from_start=float(index), positions=list(position)) + for index, position in enumerate(positions) ], ) + return generator def _make_world_monitor_with_viz(viz: VisualizationSpec | None) -> WorldMonitor: @@ -508,9 +725,9 @@ class FakeVisualization: def __init__(self) -> None: self.close_count = 0 self.published = False - self.preview_shown: list[str] = [] - self.preview_hidden: list[str] = [] - self.animations: list[tuple[str, list[JointState], float]] = [] + self.preview_shown: list[tuple[str, ...]] = [] + self.preview_hidden: list[tuple[str, ...]] = [] + self.animations: list[tuple[GeneratedPlan, float]] = [] def initialize_scene(self, scene: PlanningSceneInfo) -> None: pass @@ -521,14 +738,14 @@ def get_visualization_url(self) -> str | None: def publish_visualization(self, ctx: object | None = None) -> None: self.published = True - def show_preview(self, robot_id: str) -> None: - self.preview_shown.append(robot_id) + def show_preview(self, group_ids: Sequence[str]) -> None: + self.preview_shown.append(tuple(group_ids)) - def hide_preview(self, robot_id: str) -> None: - self.preview_hidden.append(robot_id) + def hide_preview(self, group_ids: Sequence[str]) -> None: + self.preview_hidden.append(tuple(group_ids)) - def animate_path(self, robot_id: str, path: list[JointState], duration: float = 3.0) -> None: - self.animations.append((robot_id, path, duration)) + def animate_plan(self, plan: GeneratedPlan, duration: float = 3.0) -> None: + self.animations.append((plan, duration)) def close(self) -> None: self.close_count += 1 @@ -537,12 +754,12 @@ def close(self) -> None: class TestOnJointState: """Test _on_joint_state routing, splitting, and init capture.""" - def test_routes_positions_to_monitor(self, robot_config_with_mapping): + def test_routes_positions_to_monitor(self, left_robot_config): """Joint positions from aggregated message are routed to the correct monitor.""" - module = _make_module_with_monitor(robot_config_with_mapping) + module = _make_module_with_monitor(left_robot_config) msg = JointState( - name=["left/joint1", "left/joint2", "left/joint3"], + name=["left_arm/joint1", "left_arm/joint2", "left_arm/joint3"], position=[0.1, 0.2, 0.3], velocity=[1.0, 2.0, 3.0], ) @@ -552,13 +769,14 @@ def test_routes_positions_to_monitor(self, robot_config_with_mapping): module._world_monitor.on_joint_state.assert_called_once() call_args = module._world_monitor.on_joint_state.call_args sub_msg = call_args[0][0] + assert sub_msg.name == ["joint1", "joint2", "joint3"] assert sub_msg.position == [0.1, 0.2, 0.3] assert sub_msg.velocity == [1.0, 2.0, 3.0] assert call_args[1]["robot_id"] == "robot_left_arm" - def test_skips_robot_with_missing_joints(self, robot_config_with_mapping): + def test_skips_robot_with_missing_joints(self, left_robot_config): """Robots whose joints are absent from the message are skipped.""" - module = _make_module_with_monitor(robot_config_with_mapping) + module = _make_module_with_monitor(left_robot_config) # Message has none of left_arm's joints msg = JointState( @@ -569,12 +787,12 @@ def test_skips_robot_with_missing_joints(self, robot_config_with_mapping): module._world_monitor.on_joint_state.assert_not_called() - def test_captures_init_joints_on_first_call(self, robot_config_with_mapping): + def test_captures_init_joints_on_first_call(self, left_robot_config): """First joint state is stored as init joints; subsequent calls don't overwrite.""" - module = _make_module_with_monitor(robot_config_with_mapping) + module = _make_module_with_monitor(left_robot_config) first_msg = JointState( - name=["left/joint1", "left/joint2", "left/joint3"], + name=["left_arm/joint1", "left_arm/joint2", "left_arm/joint3"], position=[0.1, 0.2, 0.3], ) module._on_joint_state(first_msg) @@ -583,7 +801,7 @@ def test_captures_init_joints_on_first_call(self, robot_config_with_mapping): # Second call should NOT overwrite second_msg = JointState( - name=["left/joint1", "left/joint2", "left/joint3"], + name=["left_arm/joint1", "left_arm/joint2", "left_arm/joint3"], position=[0.9, 0.8, 0.7], ) module._on_joint_state(second_msg) @@ -598,7 +816,6 @@ def test_multi_robot_splits_correctly(self): joint_names=["j1", "j2"], end_effector_link="ee", base_link="base", - joint_name_mapping={"left/j1": "j1", "left/j2": "j2"}, coordinator_task_name="traj_left", ) right_config = RobotModelConfig( @@ -608,7 +825,6 @@ def test_multi_robot_splits_correctly(self): joint_names=["j1", "j2"], end_effector_link="ee", base_link="base", - joint_name_mapping={"right/j1": "j1", "right/j2": "j2"}, coordinator_task_name="traj_right", ) module = _make_module_with_monitor(left_config, right_config) @@ -632,15 +848,15 @@ def test_multi_robot_splits_correctly(self): assert calls["robot_left"].velocity == [0.1, 0.2] assert calls["robot_right"].velocity == [0.3, 0.4] - def test_no_monitor_returns_early(self, robot_config_with_mapping): + def test_no_monitor_returns_early(self, left_robot_config): """When world_monitor is None, _on_joint_state returns without error.""" module = _make_module() - module._robots = {"left_arm": ("id", robot_config_with_mapping, MagicMock())} + module._robots = {"left_arm": ("id", left_robot_config, MagicMock())} module._world_monitor = None # Should not raise msg = JointState( - name=["left/joint1", "left/joint2", "left/joint3"], + name=["left_arm/joint1", "left_arm/joint2", "left_arm/joint3"], position=[0.1, 0.2, 0.3], ) module._on_joint_state(msg) @@ -659,15 +875,20 @@ def test_visualization_routing_and_stop_all_monitors(self): assert monitor.get_visualization_url() == "123" monitor.publish_visualization() - monitor.show_preview("robot") - monitor.hide_preview("robot") - path = _make_path([1.0], [2.0], [3.0]) - monitor.animate_path("robot", path, 4.5) + group_ids = ("robot/manipulator",) + plan = GeneratedPlan( + group_ids=group_ids, + path=[JointState(name=["robot/j1"], position=[0.0])], + status=PlanningStatus.SUCCESS, + ) + monitor.show_preview(group_ids) + monitor.hide_preview(group_ids) + monitor.animate_plan(plan, 4.5) assert monitor.visualization is viz assert viz.published is True - assert viz.preview_shown == ["robot"] - assert viz.preview_hidden == ["robot"] - assert viz.animations == [("robot", path, 4.5)] + assert viz.preview_shown == [group_ids] + assert viz.preview_hidden == [group_ids] + assert viz.animations == [(plan, 4.5)] monitor.stop_all_monitors() @@ -680,9 +901,14 @@ def test_visualization_none_is_noop(self): assert monitor.get_visualization_url() is None monitor.publish_visualization() - monitor.show_preview("robot") - monitor.hide_preview("robot") - monitor.animate_path("robot", [1], 1.0) + plan = GeneratedPlan( + group_ids=("robot/manipulator",), + path=[JointState(name=["robot/j1"], position=[0.0])], + status=PlanningStatus.SUCCESS, + ) + monitor.show_preview(("robot/manipulator",)) + monitor.hide_preview(("robot/manipulator",)) + monitor.animate_plan(plan, 1.0) monitor.start_visualization_thread() assert monitor._viz_thread is None @@ -691,90 +917,233 @@ class TestManipulationPreview: def test_dismiss_preview_noop_without_monitor(self): module = _make_module() - module._dismiss_preview("robot_id") + module._dismiss_preview(("arm/manipulator",)) def test_dismiss_preview_routes_to_monitor(self): module = _make_module() module._world_monitor = MagicMock() - module._dismiss_preview("robot_id") + group_ids = ("arm/manipulator",) + module._dismiss_preview(group_ids) - module._world_monitor.hide_preview.assert_called_once_with("robot_id") + module._world_monitor.hide_preview.assert_called_once_with(group_ids) module._world_monitor.publish_visualization.assert_called_once_with() - def test_preview_path_uses_trajectory_duration_and_interpolates(self): + def test_preview_plan_uses_safe_default_duration(self): module = _make_module() module._world_monitor = MagicMock() - module._robots = {"arm": ("robot_id", MagicMock(), MagicMock())} - module._planned_paths = {"arm": _make_path([0.0], [2.0])} - module._planned_trajectories = {"arm": _make_trajectory((0.0, [0.0]), (2.0, [2.0]))} + module._last_plan = GeneratedPlan( + group_ids=("arm/manipulator",), + path=[JointState(name=["arm/j1"], position=[0.0])], + status=PlanningStatus.SUCCESS, + ) - assert module.preview_path(robot_name="arm", target_fps=2.0) is True + assert module.preview_plan() is True - module._world_monitor.animate_path.assert_called_once() - robot_id, preview_path, duration = module._world_monitor.animate_path.call_args.args - assert robot_id == "robot_id" - assert duration == 2.0 - assert [state.position for state in preview_path] == [[0.0], [0.5], [1.0], [1.5], [2.0]] + module._world_monitor.animate_plan.assert_called_once_with(module._last_plan, 1.0) - def test_preview_path_explicit_duration_overrides_and_fps_densifies(self): + def test_preview_plan_explicit_duration_overrides_default(self): module = _make_module() module._world_monitor = MagicMock() - module._robots = {"arm": ("robot_id", MagicMock(), MagicMock())} - module._planned_paths = {"arm": _make_path([0.0], [9.0])} - module._planned_trajectories = {"arm": _make_trajectory((0.0, [0.0]), (9.0, [9.0]))} + module._last_plan = GeneratedPlan( + group_ids=("arm/manipulator",), + path=[JointState(name=["arm/j1"], position=[0.0])], + status=PlanningStatus.SUCCESS, + ) - assert module.preview_path(duration=1.5, robot_name="arm", target_fps=2.0) is True + assert module.preview_plan(duration=1.5) is True - module._world_monitor.animate_path.assert_called_once() - robot_id, preview_path, duration = module._world_monitor.animate_path.call_args.args - assert robot_id == "robot_id" - assert duration == 1.5 - assert [state.position for state in preview_path] == [[0.0], [3.0], [6.0], [9.0]] + module._world_monitor.animate_plan.assert_called_once_with(module._last_plan, 1.5) - def test_preview_path_missing_trajectory_uses_default_duration(self): + def test_preview_plan_respects_robot_filter(self): module = _make_module() module._world_monitor = MagicMock() - module._robots = {"arm": ("robot_id", MagicMock(), MagicMock())} - module._planned_paths = {"arm": _make_path([0.0], [1.0])} - module._planned_trajectories = {} + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("arm", "manipulator", ["j1"])] + ) + module._last_plan = GeneratedPlan( + group_ids=("arm/manipulator",), + path=[JointState(name=["arm/j1"], position=[0.0])], + status=PlanningStatus.SUCCESS, + ) + + assert module.preview_plan(robot_name="arm") is True - assert module.preview_path(robot_name="arm", target_fps=10.0) is True + module._world_monitor.animate_plan.assert_called_once_with(module._last_plan, 1.0) - module._world_monitor.animate_path.assert_called_once_with( - "robot_id", module._planned_paths["arm"], 3.0 + def test_preview_plan_rejects_unaffected_robot_filter(self): + module = _make_module() + module._world_monitor = MagicMock() + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("arm", "manipulator", ["j1"])] ) + module._last_plan = GeneratedPlan( + group_ids=("arm/manipulator",), + path=[JointState(name=["arm/j1"], position=[0.0])], + status=PlanningStatus.SUCCESS, + ) + + assert module.preview_plan(robot_name="other") is False - def test_preview_path_skips_interpolation_for_nonpositive_fps_or_duration(self): + module._world_monitor.animate_plan.assert_not_called() + + def test_preview_plan_returns_false_for_missing_inputs(self): module = _make_module() + + assert module.preview_plan() is False + module._world_monitor = MagicMock() - module._robots = {"arm": ("robot_id", MagicMock(), MagicMock())} - module._planned_paths = {"arm": _make_path([0.0], [1.0])} - module._planned_trajectories = {"arm": _make_trajectory((0.0, [0.0]), (2.0, [1.0]))} + assert module.preview_plan() is False - assert module.preview_path(robot_name="arm", target_fps=0.0) is True - assert module.preview_path(duration=0.0, robot_name="arm", target_fps=20.0) is True - assert ( - module._world_monitor.animate_path.call_args_list[0].args[1] - == module._planned_paths["arm"] +class TestGeneratedPlanProjection: + def test_selected_joint_state_accepts_local_current_state_names(self): + config = _make_robot_config("left", ["j1", "j2"], "task") + module = _make_module_with_monitor(config) + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("left", "arm", ["j1", "j2"])] ) - assert ( - module._world_monitor.animate_path.call_args_list[1].args[1] - == module._planned_paths["arm"] + module._world_monitor.get_current_joint_state.return_value = JointState( + name=["j1", "j2"], position=[1.0, 2.0] ) - def test_preview_path_returns_false_for_missing_inputs(self): - module = _make_module() - module._planned_paths = {"arm": _make_path([0.0], [1.0])} - module._robots = {"arm": ("robot_id", MagicMock(), MagicMock())} + selected = module._selected_joint_state(("left/arm",)) - assert module.preview_path(robot_name="arm") is False + assert selected is not None + assert selected.name == ["left/j1", "left/j2"] + assert selected.position == [1.0, 2.0] - module._world_monitor = MagicMock() - module._robots = {} - assert module.preview_path(robot_name="arm") is False + def test_selected_joint_state_rejects_mixed_current_state_names(self): + config = _make_robot_config("left", ["j1", "j2"], "task") + module = _make_module_with_monitor(config) + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("left", "arm", ["j1", "j2"])] + ) + module._world_monitor.get_current_joint_state.return_value = JointState( + name=["left/j1", "j2"], position=[1.0, 2.0] + ) - module._robots = {"arm": ("robot_id", MagicMock(), MagicMock())} - module._planned_paths = {"arm": []} - assert module.preview_path(robot_name="arm") is False + assert module._selected_joint_state(("left/arm",)) is None + + def test_execute_plan_dispatches_one_trajectory_per_affected_robot(self): + left_config = _make_robot_config( + "left", + ["j1", "j2", "j3"], + "left_task", + ) + right_config = _make_robot_config("right", ["j1", "j2"], "right_task") + module = _make_module_with_monitor(left_config, right_config) + left_gen = _trajectory_generator() + right_gen = _trajectory_generator() + module._robots["left"] = ("robot_left", left_config, left_gen) + module._robots["right"] = ("robot_right", right_config, right_gen) + module._world_monitor.planning_groups = _FakePlanningGroups( + [ + _make_global_group("left", "arm", ["j1", "j2"]), + _make_global_group("right", "arm", ["j1"]), + ] + ) + module._world_monitor.get_current_joint_state.side_effect = [ + JointState(name=["j1", "j2", "j3"], position=[0.0, 0.0, 9.0]), + JointState(name=["j1", "j2"], position=[0.0, 8.0]), + ] + module._coordinator_client = MagicMock() + module._coordinator_client.task_invoke.return_value = True + plan = _make_generated_plan(("left/arm", "right/arm"), [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]) + + assert module.execute_plan(plan) is True + + assert module._coordinator_client.task_invoke.call_count == 2 + left_call, right_call = module._coordinator_client.task_invoke.call_args_list + assert left_call.args[0:2] == ("left_task", "execute") + left_trajectory = left_call.args[2]["trajectory"] + assert left_trajectory.joint_names == ["left/j1", "left/j2", "left/j3"] + assert [point.positions for point in left_trajectory.points] == [ + [1.0, 2.0, 9.0], + [4.0, 5.0, 9.0], + ] + assert right_call.args[0:2] == ("right_task", "execute") + right_trajectory = right_call.args[2]["trajectory"] + assert right_trajectory.joint_names == ["right/j1", "right/j2"] + assert [point.positions for point in right_trajectory.points] == [[3.0, 8.0], [6.0, 8.0]] + + def test_execute_plan_holds_non_selected_joints_from_current_state(self): + config = _make_robot_config("left", ["j1", "j2", "j3"], "task") + module = _make_module_with_monitor(config) + generator = _trajectory_generator() + module._robots["left"] = ("robot_left", config, generator) + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("left", "arm", ["j2"])] + ) + module._world_monitor.get_current_joint_state.return_value = JointState( + name=["j1", "j2", "j3"], position=[10.0, 20.0, 30.0] + ) + module._coordinator_client = MagicMock() + module._coordinator_client.task_invoke.return_value = True + plan = GeneratedPlan( + group_ids=("left/arm",), + path=[ + JointState(name=["left/j2"], position=[2.0]), + JointState(name=["left/j2"], position=[3.0]), + ], + status=PlanningStatus.SUCCESS, + ) + + assert module.execute_plan(plan) is True + + trajectory = module._coordinator_client.task_invoke.call_args.args[2]["trajectory"] + assert trajectory.joint_names == ["left/j1", "left/j2", "left/j3"] + assert [point.positions for point in trajectory.points] == [ + [10.0, 2.0, 30.0], + [10.0, 3.0, 30.0], + ] + + def test_execute_plan_rejects_local_waypoint_names(self): + config = _make_robot_config("left", ["j1", "j2"], "task") + module = _make_module_with_monitor(config) + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("left", "arm", ["j1"])] + ) + module._world_monitor.get_current_joint_state.return_value = JointState( + name=["j1", "j2"], position=[10.0, 20.0] + ) + module._coordinator_client = MagicMock() + plan = GeneratedPlan( + group_ids=("left/arm",), + path=[JointState(name=["j1"], position=[1.0])], + status=PlanningStatus.SUCCESS, + ) + + assert module.execute_plan(plan) is False + module._coordinator_client.task_invoke.assert_not_called() + + def test_preview_plan_with_last_plan_animates_generated_plan(self): + config = _make_robot_config("left", ["j1", "j2"], "task") + module = _make_module_with_monitor(config) + module._world_monitor.planning_groups = _FakePlanningGroups( + [_make_global_group("left", "arm", ["j1"])] + ) + module._last_plan = GeneratedPlan( + group_ids=("left/arm",), + path=[ + JointState(name=["left/j1"], position=[1.0]), + JointState(name=["left/j1"], position=[2.0]), + ], + status=PlanningStatus.SUCCESS, + ) + + assert module.preview_plan(robot_name="left") is True + + module._world_monitor.animate_plan.assert_called_once_with(module._last_plan, 1.0) + + def test_has_and_clear_planned_path_use_last_plan(self): + module = _make_module() + module._last_plan = GeneratedPlan( + group_ids=("left/arm",), + path=[JointState(name=["left/j1"], position=[1.0])], + status=PlanningStatus.SUCCESS, + ) + assert module.has_planned_path() is True + assert module.clear_planned_path() is True + assert module.has_planned_path() is False + assert module._last_plan is None diff --git a/dimos/manipulation/visualization/test_factory.py b/dimos/manipulation/visualization/test_factory.py index 4d49f019cf..3f9e22d9d9 100644 --- a/dimos/manipulation/visualization/test_factory.py +++ b/dimos/manipulation/visualization/test_factory.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections.abc import Sequence from contextlib import AbstractContextManager, nullcontext from pathlib import Path from unittest.mock import MagicMock @@ -26,8 +27,9 @@ from dimos.manipulation.manipulation_module import ManipulationModuleConfig from dimos.manipulation.planning.spec.config import RobotModelConfig from dimos.manipulation.planning.spec.models import ( - JointPath, + GeneratedPlan, Obstacle, + PlanningGroupID, PlanningSceneInfo, WorldRobotID, ) @@ -52,13 +54,13 @@ def get_visualization_url(self) -> str | None: def publish_visualization(self, ctx: object | None = None) -> None: return None - def show_preview(self, robot_id: WorldRobotID) -> None: + def show_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: return None - def hide_preview(self, robot_id: WorldRobotID) -> None: + def hide_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: return None - def animate_path(self, robot_id: WorldRobotID, path: JointPath, duration: float = 3.0) -> None: + def animate_plan(self, plan: GeneratedPlan, duration: float = 3.0) -> None: return None def close(self) -> None: diff --git a/dimos/manipulation/visualization/types.py b/dimos/manipulation/visualization/types.py index 778c293499..8122643ecd 100644 --- a/dimos/manipulation/visualization/types.py +++ b/dimos/manipulation/visualization/types.py @@ -16,7 +16,8 @@ from typing import TypedDict -from dimos.manipulation.planning.spec.models import RobotName, WorldRobotID +from dimos.manipulation.planning.groups.models import PlanningGroup +from dimos.manipulation.planning.spec.models import PlanningGroupID, RobotName, WorldRobotID from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.sensor_msgs.JointState import JointState @@ -33,11 +34,24 @@ class TargetEvaluation(TypedDict, total=False): orientation_error: float -class RobotInfo(TypedDict): +class TargetSetEvaluation(TypedDict, total=False): + success: bool + status: str + message: str + collision_free: bool + group_ids: tuple[PlanningGroupID, ...] + target_joints: JointState | None + group_diagnostics: dict[PlanningGroupID, str] + group_poses: dict[PlanningGroupID, PoseStamped | Pose | None] + position_error: float + orientation_error: float + + +class RobotInfo(TypedDict, total=False): name: RobotName world_robot_id: WorldRobotID joint_names: list[str] - end_effector_link: str + end_effector_link: str | None base_link: str max_velocity: float max_acceleration: float @@ -46,3 +60,4 @@ class RobotInfo(TypedDict): home_joints: list[float] | None pre_grasp_offset: float init_joints: list[float] | None + planning_groups: list[PlanningGroup] diff --git a/dimos/manipulation/visualization/viser/adapter.py b/dimos/manipulation/visualization/viser/adapter.py index c3eb8e360b..48b84c6f55 100644 --- a/dimos/manipulation/visualization/viser/adapter.py +++ b/dimos/manipulation/visualization/viser/adapter.py @@ -15,8 +15,9 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast +from dimos.manipulation.planning.spec.enums import IKStatus from dimos.manipulation.visualization.types import RobotInfo, TargetEvaluation from dimos.msgs.sensor_msgs.JointState import JointState @@ -65,26 +66,7 @@ def get_robot_info(self, robot_name: RobotName) -> RobotInfo | None: info = self._module.get_robot_info(robot_name) if info is None: return None - return { - "name": str(info["name"]), - "world_robot_id": str(info["world_robot_id"]), - "joint_names": [str(name) for name in info["joint_names"]], - "end_effector_link": str(info["end_effector_link"]), - "base_link": str(info["base_link"]), - "max_velocity": float(info["max_velocity"]), - "max_acceleration": float(info["max_acceleration"]), - "has_joint_name_mapping": bool(info["has_joint_name_mapping"]), - "coordinator_task_name": None - if info["coordinator_task_name"] is None - else str(info["coordinator_task_name"]), - "home_joints": None - if info["home_joints"] is None - else [float(value) for value in info["home_joints"]], - "pre_grasp_offset": float(info["pre_grasp_offset"]), - "init_joints": None - if info["init_joints"] is None - else [float(value) for value in info["init_joints"]], - } + return info def get_init_joints(self, robot_name: RobotName) -> JointState | None: return copy_joint_state(self._module.get_init_joints(robot_name)) @@ -108,34 +90,112 @@ def get_ee_pose( return self._world_monitor.get_ee_pose(robot_id, copy_joint_state(joint_state)) def evaluate_joint_target(self, joints: JointState, robot_name: RobotName) -> TargetEvaluation: - """Evaluate a joint target through WorldMonitor helpers, not raw WorldSpec access.""" - result: TargetEvaluation = { - **self._module.evaluate_joint_target(copy_joint_state(joints), robot_name) + """Evaluate a legacy robot-scoped joint target through group-scoped APIs.""" + legacy_evaluate = getattr(self._module, "evaluate_joint_target", None) + if callable(legacy_evaluate): + result = cast("TargetEvaluation", legacy_evaluate(JointState(joints), robot_name)) + joint_state = result.get("joint_state") + result["joint_state"] = copy_joint_state( + joint_state if isinstance(joint_state, JointState) else None + ) + return result + default_group_id = getattr(self._module, "_default_group_id_for_robot", None) + joint_target_to_global = getattr(self._module, "_joint_target_to_global_names", None) + if not callable(default_group_id) or not callable(joint_target_to_global): + robot_id = self.robot_id_for_name(robot_name) + valid = ( + False + if robot_id is None + else bool(self._world_monitor.is_state_valid(robot_id, JointState(joints))) + ) + return { + "success": valid, + "status": "FEASIBLE" if valid else "COLLISION", + "message": "Target is valid" if valid else "Target is in collision", + "collision_free": valid, + "joint_state": JointState(joints), + } + group_id = self._module._default_group_id_for_robot(robot_name) + if group_id is None: + return {"success": False, "status": "INVALID", "message": "No default group"} + target = self._module._joint_target_to_global_names(group_id, JointState(joints)) + if target is None: + return {"success": False, "status": "INVALID", "message": "Invalid joint target"} + collision = self._module.check_collision(target) + collision_free = collision.collision_free is True + return { + "success": collision_free, + "status": collision.status, + "message": collision.message, + "collision_free": collision_free, + "joint_state": copy_joint_state(target), } - joint_state = result.get("joint_state") - result["joint_state"] = copy_joint_state( - joint_state if isinstance(joint_state, JointState) else None - ) - return result def evaluate_pose_target(self, pose: Pose, robot_name: RobotName) -> TargetEvaluation: - """Evaluate a Cartesian target through module/WorldMonitor helper boundaries.""" - result: TargetEvaluation = {**self._module.evaluate_pose_target(pose, robot_name)} - joint_state = result.get("joint_state") - result["joint_state"] = copy_joint_state( - joint_state if isinstance(joint_state, JointState) else None - ) - return result + """Evaluate a legacy robot-scoped pose target through group-scoped IK.""" + ik = self._module.inverse_kinematics_single(pose, robot_name=robot_name) + if not ik.is_success() or ik.joint_state is None: + return { + "success": False, + "status": ik.status.name, + "message": ik.message, + "collision_free": False, + "joint_state": None, + "position_error": ik.position_error, + "orientation_error": ik.orientation_error, + } + collision = self._module.check_collision(ik.joint_state) + collision_free = collision.collision_free is True + return { + "success": collision_free, + "status": collision.status if not collision_free else IKStatus.SUCCESS.name, + "message": collision.message, + "collision_free": collision_free, + "joint_state": copy_joint_state(ik.joint_state), + "position_error": ik.position_error, + "orientation_error": ik.orientation_error, + } def get_planned_path(self, robot_name: RobotName) -> JointPath | None: - path = self._module.get_planned_path(robot_name) - if path is None: + legacy_get_planned_path = getattr(self._module, "get_planned_path", None) + if callable(legacy_get_planned_path): + legacy_path = legacy_get_planned_path(robot_name) + if legacy_path is None: + return None + if not isinstance(legacy_path, list): + return None + copied = [copy_joint_state(point) for point in legacy_path] + return [point for point in copied if point is not None] + planned_paths = getattr(self._module, "_planned_paths", None) + if isinstance(planned_paths, dict): + path_obj = planned_paths.get(robot_name) + if isinstance(path_obj, list): + copied = [copy_joint_state(point) for point in path_obj] + return [point for point in copied if point is not None] + plan = getattr(self._module, "_last_plan", None) + config = self.get_robot_config(robot_name) + current = self.get_current_joint_state(robot_name) + if plan is None or config is None or current is None: return None - copied = [copy_joint_state(point) for point in path] - return [point for point in copied if point is not None] + path: JointPath = [] + current_by_name = dict(zip(current.name, current.position, strict=False)) + for waypoint in plan.path: + selected = dict(zip(waypoint.name, waypoint.position, strict=False)) + positions: list[float] = [] + for local_name in config.joint_names: + global_name = f"{robot_name}/{local_name}" + if global_name in selected: + positions.append(float(selected[global_name])) + elif local_name in current_by_name: + positions.append(float(current_by_name[local_name])) + else: + return None + path.append(JointState(name=list(config.joint_names), position=positions)) + return path def get_planned_trajectory_duration(self, robot_name: RobotName) -> float | None: - return self._module.get_planned_trajectory_duration(robot_name) + path = self.get_planned_path(robot_name) + return None if path is None else float(max(len(path) - 1, 0)) def get_module_state(self) -> str: return str(self._module.get_state()) @@ -152,11 +212,18 @@ def plan_to_pose(self, pose: Pose, robot_name: RobotName | None = None) -> bool: def plan_to_joints(self, joints: JointState, robot_name: RobotName | None = None) -> bool: return self._module.plan_to_joints(joints, robot_name) - def preview_path(self, robot_name: RobotName | None = None) -> bool: - return self._module.preview_path(robot_name=robot_name) + def preview_path(self, robot_name: RobotName | None = None) -> object: + preview_path = getattr(self._module, "preview_path", None) + if callable(preview_path): + return preview_path(robot_name=robot_name) + return self._module.preview_plan(robot_name=robot_name) def execute(self, robot_name: RobotName | None = None) -> bool: - return self._module.execute(robot_name) + execute_plan = getattr(self._module, "execute_plan", None) + if callable(execute_plan): + return bool(execute_plan()) + execute = getattr(self._module, "execute", None) + return bool(execute(robot_name)) if callable(execute) else False def cancel(self) -> bool: return self._module.cancel() diff --git a/dimos/manipulation/visualization/viser/visualizer.py b/dimos/manipulation/visualization/viser/visualizer.py index d12c98e53b..2a44008118 100644 --- a/dimos/manipulation/visualization/viser/visualizer.py +++ b/dimos/manipulation/visualization/viser/visualizer.py @@ -14,9 +14,11 @@ from __future__ import annotations +from collections.abc import Sequence from contextlib import suppress from typing import TYPE_CHECKING +from dimos.manipulation.planning.groups.identifiers import make_global_joint_name from dimos.manipulation.visualization.viser.adapter import InProcessViserAdapter from dimos.manipulation.visualization.viser.config import ViserVisualizationConfig from dimos.manipulation.visualization.viser.gui import ViserPanelGui @@ -27,6 +29,7 @@ ) from dimos.manipulation.visualization.viser.scene import ViserManipulationScene from dimos.manipulation.visualization.viser.theme import apply_dimos_theme +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger try: @@ -43,10 +46,11 @@ if TYPE_CHECKING: from dimos.manipulation.manipulation_module import ManipulationModule from dimos.manipulation.planning.monitor.world_monitor import WorldMonitor + from dimos.manipulation.planning.spec.config import RobotModelConfig from dimos.manipulation.planning.spec.models import ( - JointPath, + GeneratedPlan, + PlanningGroupID, PlanningSceneInfo, - WorldRobotID, ) logger = setup_logger() @@ -81,17 +85,22 @@ def _ensure_started(self) -> None: try: server = runtime.start() apply_dimos_theme(server) - adapter = InProcessViserAdapter( - world_monitor=self._world_monitor, - manipulation_module=self._manipulation_module, - ) scene = ViserManipulationScene( server, ViserUrdf, preview_fps=self.config.preview_fps, ) + adapter = InProcessViserAdapter( + world_monitor=self._world_monitor, + manipulation_module=self._manipulation_module, + ) gui = ( - ViserPanelGui(server, adapter, self.config, scene) + ViserPanelGui( + server, + adapter, + self.config, + scene, + ) if self.config.panel_enabled else None ) @@ -147,38 +156,124 @@ def publish_visualization(self, ctx: None = None) -> None: self._ensure_started() if self._adapter is None or self._scene is None: return - for _robot_name, robot_id, _config in self._adapter.robot_items(): - current = self._adapter.get_current_joint_state(_robot_name) + for robot_name, robot_id, _config in self._adapter.robot_items(): + current = self._adapter.get_current_joint_state(robot_name) self._scene.update_current_robot(str(robot_id), current) if self._gui is not None: self._gui.refresh() - def show_preview(self, robot_id: WorldRobotID) -> None: + def show_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: if not self._closed: self._ensure_started() if self._scene is None: return - self._scene.show_preview(str(robot_id)) + for robot_id in self._robot_ids_for_groups(group_ids): + self._scene.show_preview(str(robot_id)) - def hide_preview(self, robot_id: WorldRobotID) -> None: + def hide_preview(self, group_ids: Sequence[PlanningGroupID]) -> None: if not self._closed: self._ensure_started() if self._scene is None: return - self._scene.hide_preview(str(robot_id)) + for robot_id in self._robot_ids_for_groups(group_ids): + self._scene.hide_preview(str(robot_id)) def animate_path( self, - robot_id: WorldRobotID, - path: JointPath, + robot_id: str, + path: list[JointState], duration: float = 3.0, ) -> None: + """Compatibility wrapper for legacy robot-scoped Viser callers.""" + if self._closed: + return + self._ensure_started() + if self._scene is not None: + self._scene.animate_path(str(robot_id), path, duration) + + def animate_plan(self, plan: GeneratedPlan, duration: float = 3.0) -> None: if self._closed: return self._ensure_started() if self._scene is None: return - self._scene.animate_path(str(robot_id), list(path), duration) + for robot_name in self._robot_names_for_groups(plan.group_ids): + robot_id = self._manipulation_module.robot_id_for_name(robot_name) + config = self._manipulation_module.get_robot_config(robot_name) + current = self._manipulation_module.get_current_joint_state(robot_name) + if robot_id is None or config is None or current is None: + logger.warning( + "Cannot build group preview for robot '%s': missing id, config, or state", + robot_name, + ) + return + path = self._robot_path_for_plan(robot_name, config, current, plan) + if not path: + logger.warning("Cannot project generated plan for robot '%s'", robot_name) + return + self._scene.animate_path(str(robot_id), path, duration) + + def _robot_names_for_groups(self, group_ids: Sequence[PlanningGroupID]) -> list[str]: + selection = self._world_monitor.planning_groups.select(group_ids) + return list(selection.robot_names) + + def _robot_ids_for_groups(self, group_ids: Sequence[PlanningGroupID]) -> list[str]: + if isinstance(group_ids, str): + return [group_ids] + if not hasattr(self._world_monitor, "planning_groups"): + return [str(group_id) for group_id in group_ids] + selection = self._world_monitor.planning_groups.select(group_ids) + robot_ids: list[str] = [] + for robot_name in selection.robot_names: + robot_id = self._manipulation_module.robot_id_for_name(robot_name) + if robot_id is not None: + robot_ids.append(str(robot_id)) + return robot_ids + + def _robot_path_for_plan( + self, + robot_name: str, + config: RobotModelConfig, + current: JointState, + plan: GeneratedPlan, + ) -> list[JointState]: + current_by_name = self._current_positions_by_name(config, current) + if current_by_name is None: + return [] + path: list[JointState] = [] + for waypoint in plan.path: + if len(waypoint.name) != len(waypoint.position): + return [] + selected = dict(zip(waypoint.name, waypoint.position, strict=True)) + positions: list[float] = [] + for local_name in config.joint_names: + global_name = make_global_joint_name(robot_name, local_name) + if global_name in selected: + positions.append(float(selected[global_name])) + continue + if local_name not in current_by_name: + return [] + positions.append(current_by_name[local_name]) + path.append(JointState(name=list(config.joint_names), position=positions)) + return path + + @staticmethod + def _current_positions_by_name( + config: RobotModelConfig, current: JointState + ) -> dict[str, float] | None: + if current.name: + if len(current.name) != len(current.position): + return None + return { + str(name): float(position) + for name, position in zip(current.name, current.position, strict=True) + } + if len(current.position) != len(config.joint_names): + return None + return { + str(name): float(position) + for name, position in zip(config.joint_names, current.position, strict=True) + } def close(self) -> None: if self._closed: diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 5ba30bdd8c..f10d2a5d76 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -59,6 +59,7 @@ "drone-agentic": "dimos.robot.drone.blueprints.agentic.drone_agentic:drone_agentic", "drone-basic": "dimos.robot.drone.blueprints.basic.drone_basic:drone_basic", "dual-xarm6-planner": "dimos.robot.manipulators.xarm.blueprints.basic:dual_xarm6_planner", + "dual-xarm6-planner-coordinator": "dimos.robot.manipulators.xarm.blueprints.basic:dual_xarm6_planner_coordinator", "keyboard-teleop-a750": "dimos.robot.manipulators.a750.blueprints.teleop:keyboard_teleop_a750", "keyboard-teleop-openarm": "dimos.robot.manipulators.openarm.blueprints.teleop:keyboard_teleop_openarm", "keyboard-teleop-openarm-mock": "dimos.robot.manipulators.openarm.blueprints.teleop:keyboard_teleop_openarm_mock", diff --git a/dimos/robot/config.py b/dimos/robot/config.py new file mode 100644 index 0000000000..42da6d56ae --- /dev/null +++ b/dimos/robot/config.py @@ -0,0 +1,287 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified robot configuration. + +Single source of truth for a robot. The URDF/MJCF model file is the +ground truth — joint names, DOF, limits, and link hierarchy are parsed +automatically. Generates RobotModelConfig, HardwareComponent, and TaskConfig. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, Field, PrivateAttr + +from dimos.control.components import HardwareComponent, HardwareType +from dimos.control.coordinator import TaskConfig +from dimos.manipulation.planning.groups.discovery import discover_planning_group_definitions +from dimos.manipulation.planning.groups.identifiers import make_global_joint_names +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.robot.model_parser import ModelDescription, parse_model + + +class GripperConfig(BaseModel): + """Gripper configuration.""" + + type: str + joints: list[str] = Field(default_factory=list) + collision_exclusions: list[tuple[str, str]] = Field(default_factory=list) + open_position: float = 1.0 + close_position: float = 0.0 + + +class RobotConfig(BaseModel): + """Unified robot configuration — URDF/MJCF is the ground truth. + + Model parsing is lazy to avoid LFS downloads at import time. + """ + + # Required fields + name: str + model_path: Path | None = None + # Compatibility robot-scoped target frame; new planning uses group tip links. + end_effector_link: str | None = None + + # Physical dimensions (meters) + height_clearance: float | None = None # max height + width_clearance: float | None = None # max width + + # These offsets are applied so that odometry at 0,0,0 corresponds roughly with the floor + # Note: these cannot (easily) be calculated from the URDF because + # the URDF doesn't always have an initial robot pose/stance + # This is a quality of life offset, not exact + # The key names should match keys in the urdf + internal_odom_offsets: dict[str, Any] = Field(default_factory=dict) + + # Hardware connection + adapter_type: str = "mock" + address: str | None = None + adapter_kwargs: dict[str, Any] = Field(default_factory=dict) + auto_enable: bool = True + + # Optional overrides (derived from model if not set) + joint_names: list[str] | None = None + base_link: str | None = ( + None # Compatibility planning override; derived from model root when absent. + ) + home_joints: list[float] | None = None + + # Canonical planning placement. Robot models should describe intrinsic geometry; + # instance placement belongs here. + base_pose: list[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) + strip_model_world_joint: bool = False + + # Planning + max_velocity: float = 1.0 + max_acceleration: float = 2.0 + pre_grasp_offset: float = 0.10 + + # Gripper + gripper: GripperConfig | None = None + + # Model loading + package_paths: dict[str, Path] = Field(default_factory=dict) + xacro_args: dict[str, str] = Field(default_factory=dict) + auto_convert_meshes: bool = True + srdf_path: Path | None = None + + # TF publishing + tf_extra_links: list[str] = Field(default_factory=list) + + # Task defaults + task_type: str = "trajectory" + task_priority: int = 10 + + # Collision exclusion pairs (gripper-specific, cannot be parsed from model) + collision_exclusion_pairs: list[tuple[str, str]] = Field(default_factory=list) + + _parsed: ModelDescription | None = PrivateAttr(default=None) + + def _ensure_parsed(self) -> ModelDescription: + """Parse model lazily on first access.""" + if self._parsed is None: + if self.model_path is None: + raise ValueError( + f"RobotConfig '{self.name}' has no model_path — " + "joint/link info is unavailable. Set model_path to a URDF/MJCF." + ) + self._parsed = parse_model(self.model_path, self.package_paths, self.xacro_args) + if self.joint_names is None: + self.joint_names = self._parsed.actuated_joint_names + if self.base_link is None: + self.base_link = self._parsed.root_link + if self.home_joints is None: + self.home_joints = self._compute_default_home() + return self._parsed + + def _compute_default_home(self) -> list[float]: + assert self._parsed is not None + home = [] + for joint_name in self.local_joint_names: + joint = self._parsed.get_joint(joint_name) + if ( + joint is not None + and joint.lower_limit is not None + and joint.upper_limit is not None + ): + home.append((joint.lower_limit + joint.upper_limit) / 2.0) + else: + home.append(0.0) + return home + + # -- Derived properties --------------------------------------------------- + + @property + def model_description(self) -> ModelDescription: + return self._ensure_parsed() + + @property + def local_joint_names(self) -> list[str]: + self._ensure_parsed() + assert self.joint_names is not None + return self.joint_names + + @property + def global_joint_names(self) -> list[str]: + return make_global_joint_names(self.name, self.local_joint_names) + + @property + def resolved_base_link(self) -> str: + self._ensure_parsed() + assert self.base_link is not None + return self.base_link + + @property + def dof(self) -> int: + if self.joint_names is not None: + return len(self.joint_names) + return len(self.local_joint_names) + + @property + def coordinator_task_name(self) -> str: + return f"traj_{self.name}" + + # -- Converter methods ---------------------------------------------------- + + def to_robot_model_config(self) -> RobotModelConfig: + """Generate RobotModelConfig for ManipulationModule.""" + if self.model_path is None: + raise ValueError( + f"RobotConfig '{self.name}' has no model_path — " + "cannot generate RobotModelConfig for manipulation." + ) + bp = self.base_pose + base_pose = PoseStamped( + position=Vector3(x=bp[0], y=bp[1], z=bp[2]), + orientation=Quaternion(bp[3], bp[4], bp[5], bp[6]), + ) + + exclusions = list(self.collision_exclusion_pairs) + if self.gripper: + exclusions.extend(self.gripper.collision_exclusions) + + # Use direct fields when available to avoid triggering model parsing at import time + joint_names = self.joint_names if self.joint_names is not None else self.local_joint_names + base_link = self.base_link if self.base_link is not None else self.resolved_base_link + planning_groups = discover_planning_group_definitions( + robot_name=self.name, + model_path=self.model_path, + model=self.model_description, + controllable_joint_names=joint_names, + srdf_path=self.srdf_path, + ) + legacy_end_effector_link = self.end_effector_link or next( + (group.tip_link for group in planning_groups if group.tip_link is not None), + None, + ) + + return RobotModelConfig( + name=self.name, + model_path=self.model_path, + srdf_path=self.srdf_path, + base_pose=base_pose, + strip_model_world_joint=self.strip_model_world_joint, + joint_names=joint_names, + end_effector_link=legacy_end_effector_link, + base_link=base_link, + planning_groups=planning_groups, + package_paths=self.package_paths, + xacro_args=self.xacro_args, + collision_exclusion_pairs=exclusions, + auto_convert_meshes=self.auto_convert_meshes, + max_velocity=self.max_velocity, + max_acceleration=self.max_acceleration, + coordinator_task_name=self.coordinator_task_name, + gripper_hardware_id=self.name if self.gripper else None, + tf_extra_links=self.tf_extra_links, + home_joints=self.home_joints, + pre_grasp_offset=self.pre_grasp_offset, + ) + + def to_hardware_component(self) -> HardwareComponent: + """Generate HardwareComponent for ControlCoordinator.""" + gripper_joints: list[str] = [] + if self.gripper and self.gripper.joints: + gripper_joints = make_global_joint_names(self.name, self.gripper.joints) + + adapter_kwargs = dict(self.adapter_kwargs) + if self.home_joints is not None: + adapter_kwargs.setdefault("initial_positions", self.home_joints) + + return HardwareComponent( + hardware_id=self.name, + hardware_type=HardwareType.MANIPULATOR, + joints=self.global_joint_names, + adapter_type=self.adapter_type, + address=self.address, + auto_enable=self.auto_enable, + gripper_joints=gripper_joints, + adapter_kwargs=adapter_kwargs, + ) + + def to_task_config( + self, + task_type: str | None = None, + task_name: str | None = None, + priority: int | None = None, + auto_start: bool = False, + **task_kwargs: Any, + ) -> TaskConfig: + """Generate TaskConfig for ControlCoordinator. + + Args: + task_type: Override task type (default: self.task_type). + task_name: Override task name (default: self.coordinator_task_name). + priority: Override priority (default: self.task_priority). + auto_start: Whether the coordinator should start this task on startup. + **task_kwargs: Task-specific params (e.g., model_path, + ee_joint_id, hand, gripper_joint, gripper_open_pos, gripper_closed_pos). + """ + params = dict(task_kwargs.pop("params", {})) + params.update(task_kwargs) + + return TaskConfig( + name=task_name if task_name is not None else self.coordinator_task_name, + type=task_type if task_type is not None else self.task_type, + joint_names=self.global_joint_names, + priority=priority if priority is not None else self.task_priority, + auto_start=auto_start, + params=params, + ) diff --git a/dimos/robot/manipulators/openarm/blueprints/planner.py b/dimos/robot/manipulators/openarm/blueprints/planner.py index 467461b012..9a341ca59b 100644 --- a/dimos/robot/manipulators/openarm/blueprints/planner.py +++ b/dimos/robot/manipulators/openarm/blueprints/planner.py @@ -17,15 +17,35 @@ from __future__ import annotations from dimos.core.coordination.blueprints import autoconnect +from dimos.core.transport import LCMTransport +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.robot.manipulators.common.blueprints import coordinator, planner -from dimos.robot.manipulators.openarm.blueprints.basic import ( - left_hw, - mock_left, - mock_right, - openarm_task, - right_hw, +from dimos.robot.manipulators.openarm.blueprints.basic import openarm_task +from dimos.robot.manipulators.openarm.config import ( + LEFT_CAN, + OPENARM_ADAPTER_KWARGS, + RIGHT_CAN, + openarm_hardware, + openarm_model_config, +) + +_mock_planner_left = openarm_hardware(side="left", scoped_joints=True) +_mock_planner_right = openarm_hardware(side="right", scoped_joints=True) + +_planner_left_hw = openarm_hardware( + side="left", + address=LEFT_CAN, + adapter_type="openarm", + adapter_kwargs=OPENARM_ADAPTER_KWARGS, + scoped_joints=True, +) +_planner_right_hw = openarm_hardware( + side="right", + address=RIGHT_CAN, + adapter_type="openarm", + adapter_kwargs=OPENARM_ADAPTER_KWARGS, + scoped_joints=True, ) -from dimos.robot.manipulators.openarm.config import openarm_model_config openarm_mock_planner_coordinator = autoconnect( planner( @@ -35,12 +55,18 @@ ], ), coordinator( - hardware=[mock_left, mock_right], + hardware=[_mock_planner_left, _mock_planner_right], tasks=[ - openarm_task(mock_left), - openarm_task(mock_right), + openarm_task(_mock_planner_left), + openarm_task(_mock_planner_right), ], ), +).transports( + { + ("coordinator_joint_state", JointState): LCMTransport( + "/coordinator/joint_state", JointState + ), + } ) openarm_planner_coordinator = autoconnect( @@ -51,10 +77,16 @@ ], ), coordinator( - hardware=[left_hw, right_hw], + hardware=[_planner_left_hw, _planner_right_hw], tasks=[ - openarm_task(left_hw), - openarm_task(right_hw), + openarm_task(_planner_left_hw), + openarm_task(_planner_right_hw), ], ), +).transports( + { + ("coordinator_joint_state", JointState): LCMTransport( + "/coordinator/joint_state", JointState + ), + } ) diff --git a/dimos/robot/manipulators/openarm/config.py b/dimos/robot/manipulators/openarm/config.py index 5306d408bf..e4faee3b90 100644 --- a/dimos/robot/manipulators/openarm/config.py +++ b/dimos/robot/manipulators/openarm/config.py @@ -63,15 +63,21 @@ def openarm_hardware( adapter_type: str = "mock", address: str | None = None, adapter_kwargs: dict[str, Any] | None = None, + scoped_joints: bool = False, ) -> HardwareComponent: validate_side(side) + resolved_name = name or f"{side}_arm" + local_joints = openarm_joints(side) + joints = ( + [f"{resolved_name}/{joint}" for joint in local_joints] if scoped_joints else local_joints + ) kwargs = {"side": side} if adapter_kwargs: kwargs.update(adapter_kwargs) return HardwareComponent( - hardware_id=name or f"{side}_arm", + hardware_id=resolved_name, hardware_type=HardwareType.MANIPULATOR, - joints=openarm_joints(side), + joints=joints, adapter_type=adapter_type, address=address, adapter_kwargs=kwargs, diff --git a/dimos/robot/manipulators/xarm/blueprints/basic.py b/dimos/robot/manipulators/xarm/blueprints/basic.py index 4f1660342e..47d1a45b9b 100644 --- a/dimos/robot/manipulators/xarm/blueprints/basic.py +++ b/dimos/robot/manipulators/xarm/blueprints/basic.py @@ -18,7 +18,9 @@ from dimos.control.coordinator import ControlCoordinator, TaskConfig from dimos.core.coordination.blueprints import autoconnect +from dimos.core.transport import LCMTransport from dimos.manipulation.manipulation_module import ManipulationModule +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.robot.manipulators.common.blueprints import coordinator, planner, trajectory_task from dimos.robot.manipulators.common.sim import mujoco_if_sim from dimos.robot.manipulators.xarm.config import ( @@ -45,6 +47,29 @@ visualization={"backend": "meshcat"}, ) +_dual_xarm6_left_hw = xarm6_hardware("left_arm", mock_without_address=True) +_dual_xarm6_right_hw = xarm6_hardware("right_arm", mock_without_address=True) + +dual_xarm6_planner_coordinator = autoconnect( + planner( + robots=[ + make_xarm6_model_config(name="left_arm", y_offset=0.5), + make_xarm6_model_config(name="right_arm", y_offset=-0.5), + ], + visualization={"backend": "viser", "allow_plan_execute": True}, + ), + coordinator( + hardware=[_dual_xarm6_left_hw, _dual_xarm6_right_hw], + tasks=[trajectory_task(_dual_xarm6_left_hw), trajectory_task(_dual_xarm6_right_hw)], + ), +).transports( + { + ("coordinator_joint_state", JointState): LCMTransport( + "/coordinator/joint_state", JointState + ), + } +) + _xarm7_hw = xarm7_hardware("arm", gripper=True, mock_without_address=True) xarm7_planner_coordinator = autoconnect( diff --git a/dimos/robot/test_all_blueprints.py b/dimos/robot/test_all_blueprints.py index cdf72e9b6b..3497276e14 100644 --- a/dimos/robot/test_all_blueprints.py +++ b/dimos/robot/test_all_blueprints.py @@ -50,6 +50,7 @@ "coordinator-xarm6", "coordinator-xarm7", "dual-xarm6-planner", + "dual-xarm6-planner-coordinator", "teleop-hosted-go2", "teleop-hosted-xarm7", "teleop-quest-dual", diff --git a/docs/capabilities/manipulation/adding_a_custom_arm.md b/docs/capabilities/manipulation/adding_a_custom_arm.md index aaa075ee49..682a4f689f 100644 --- a/docs/capabilities/manipulation/adding_a_custom_arm.md +++ b/docs/capabilities/manipulation/adding_a_custom_arm.md @@ -483,6 +483,7 @@ Place your URDF/xacro files under LFS data so they can be resolved via `LfsPath` from dimos.utils.data import LfsPath from dimos.manipulation.manipulation_module import manipulation_module from dimos.manipulation.planning.spec import RobotModelConfig +from dimos.manipulation.planning.spec.models import PlanningGroupDefinition from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -506,7 +507,6 @@ def _make_base_pose(x=0.0, y=0.0, z=0.0) -> PoseStamped: def _make_yourarm_config( name: str = "arm", y_offset: float = 0.0, - joint_prefix: str = "", coordinator_task: str | None = None, ) -> RobotModelConfig: """Create YourArm robot config for planning. @@ -514,27 +514,32 @@ def _make_yourarm_config( Args: name: Robot name in the Drake planning world. y_offset: Y-axis offset for multi-arm setups. - joint_prefix: Prefix for joint name mapping to coordinator namespace. coordinator_task: Coordinator task name for trajectory execution via RPC. """ # These must match the joint names in your URDF joint_names = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"] - joint_mapping = {f"{joint_prefix}{j}": j for j in joint_names} if joint_prefix else {} return RobotModelConfig( name=name, model_path=_YOURARM_URDF_PATH, - base_pose=_make_base_pose(y=y_offset), joint_names=joint_names, - end_effector_link="link6", # Last link in your URDF's kinematic chain - base_link="base_link", # Root link of your URDF + planning_groups=[ + PlanningGroupDefinition( + name="manipulator", + joint_names=tuple(joint_names), + base_link="base_link", + tip_link="link6", + source="fallback", + ) + ], + base_pose=_make_base_pose(y=y_offset), # Compatibility; prefer model placement + base_link="base_link", # Compatibility robot-scoped base package_paths={"yourarm_description": _YOURARM_PACKAGE_PATH}, xacro_args={}, # Xacro arguments if using .xacro files collision_exclusion_pairs=[], # Pairs of links that can touch (e.g., gripper fingers) auto_convert_meshes=True, # Convert DAE/STL meshes for Drake max_velocity=1.0, # Max velocity scaling factor max_acceleration=2.0, # Max acceleration scaling factor - joint_name_mapping=joint_mapping, coordinator_task_name=coordinator_task, ) ``` @@ -546,7 +551,7 @@ Add this to your `dimos/robot/yourarm/blueprints.py` alongside the coordinator b ```python skip yourarm_planner = manipulation_module( - robots=[_make_yourarm_config("arm", joint_prefix="arm_", coordinator_task="traj_arm")], + robots=[_make_yourarm_config("arm", coordinator_task="traj_arm")], planning_timeout=10.0, visualization={"backend": "meshcat"}, ) @@ -560,14 +565,22 @@ yourarm_planner = manipulation_module( | Field | Description | |-------|-------------| | `model_path` | Path to `.urdf` or `.xacro` file | -| `joint_names` | Ordered list of controlled joints (must match URDF) | -| `end_effector_link` | Link to use as the end-effector for IK | -| `base_link` | Root link of the robot model | +| `joint_names` | Ordered controllable local model joint set (must match URDF); not itself a planning group | +| `planning_groups` / `srdf_path` | Explicit planning groups or SRDF source; fallback can generate `{robot_name}/manipulator` for an unambiguous single chain | | `package_paths` | Maps `package://` URIs to filesystem paths (for xacro) | -| `joint_name_mapping` | Maps coordinator names (e.g., `"arm_joint1"`) to URDF names (e.g., `"joint1"`) | | `coordinator_task_name` | Must match the `TaskConfig.name` in your coordinator blueprint | | `collision_exclusion_pairs` | List of `(link_a, link_b)` tuples for links that may legitimately touch (e.g., gripper fingers) | +Coordinator-facing joint states and trajectories use global joint names derived +mechanically as `{robot_name}/{local_joint_name}` (for example, `arm/joint1`). +Keep hardware-native name translation inside the hardware adapter; manipulation +planning config uses local model joint names. + +`base_link`, `base_pose`, and `end_effector_link` are compatibility fields used +by current placement and robot-scoped helper paths. New planning code should use +SRDF/planning-group chain base/tip links and encode robot placement in the model. See +[Planning Groups](/docs/capabilities/manipulation/planning_groups.md). + ## Step 5: Register Blueprints The blueprint registry in `dimos/robot/all_blueprints.py` is **auto-generated** by scanning the codebase for blueprint declarations. After adding your blueprints: diff --git a/docs/capabilities/manipulation/planning_groups.md b/docs/capabilities/manipulation/planning_groups.md new file mode 100644 index 0000000000..7b86182778 --- /dev/null +++ b/docs/capabilities/manipulation/planning_groups.md @@ -0,0 +1,190 @@ +# Manipulation Planning Groups + +Planning groups are named, selectable kinematic chains used by manipulation +planning. They separate the hardware robot identity from the part of the robot +being planned. + +## Concepts + +| Concept | Meaning | +|---------|---------| +| Planning group | A named serial chain of controllable robot joints. | +| Planning group ID | Stable API ID in the form `{robot_name}/{group_name}`. | +| Global joint name | Boundary-level joint name in the form `{robot_name}/{local_joint_name}`. | +| Local joint name | The joint name as it appears in the robot model. | +| Generated plan | Minimal planning artifact containing selected group IDs and one synchronized global-joint path. | +| Auxiliary group | A group selected for a pose request without receiving its own pose target. | + +Local URDF/SRDF joint names stay inside robot-scoped APIs, model parsing, and +backend internals. Group-scoped APIs, generated plans, preview, execution, and +coordinator boundaries use global joint names so two robots can safely have the +same local joint names. + +`PlanningGroup` descriptors returned by `list_planning_groups()` include both +namespaces: + +- `id`: public `{robot_name}/{group_name}` selector; +- `joint_names`: selected global joint names in group order; +- `local_joint_names`: selected local model joint names in group order; +- `base_link` / `tip_link`: model links for group kinematics. `tip_link=None` + means the group can participate as an auxiliary group but cannot receive a + pose target. + +## Planning group sources + +DimOS discovers planning groups in this order: + +1. Explicit `srdf_path` on `RobotConfig` / `RobotModelConfig`. +2. Conservative SRDF auto-discovery near the model path, with a visible warning. +3. Fallback generation of one `{robot_name}/manipulator` group if the configured + controllable joints form exactly one unambiguous serial chain. +4. Error if no SRDF exists and fallback cannot infer a single chain. + +Supported SRDF group forms: + +```xml + + + +``` + +```xml + + + + + +``` + +Unsupported SRDF forms are skipped with warnings: link groups, nested group +references, mixed group declarations, branching/non-serial groups, and SRDF +`` metadata. A chain group's `tip_link` is the pose target frame. +An ordered joint-list group may be pose-targeted only when DimOS can validate a +unique serial target frame. + +## Fallback behavior + +When no SRDF is available, fallback uses `RobotModelConfig.joint_names` as the +candidate controllable set. This field is the robot's ordered local model joint +set, not an implicit planning group. + +Fallback succeeds only when those joints form one unambiguous serial chain. It +allows prismatic joints in the middle of the chain and strips only terminal/tip +prismatic joints, which usually represent gripper fingers. The generated group +name is always `manipulator`. + +## Planning APIs + +Planning APIs select groups explicitly. Descriptors returned by +`ManipulationModule.list_planning_groups()` can be passed anywhere a group ID is +accepted; the module normalizes descriptors back to IDs and re-resolves current +world state. + +```python skip +# Discover groups. Each item is a PlanningGroup dataclass. +groups = manip.list_planning_groups() +arm = groups[0] + +# Joint-space planning for one group. Named targets may use global names... +manip.plan_to_joint_targets({ + arm.id: JointState( + name=["left_arm/joint1", "left_arm/joint2"], + position=[0.2, -0.1], + ) +}) + +# ...or local model names, but not a mix of both namespaces. +manip.plan_to_joint_targets({ + arm: JointState( + name=["joint1", "joint2"], + position=[0.2, -0.1], + ) +}) + +# Pose planning for an arm while a torso/waist group participates as free DOFs. +manip.plan_to_pose_targets( + {"robot/arm": target_pose}, + auxiliary_groups=["robot/torso"], +) + +plan = manip._last_plan +manip.preview_plan(plan) +manip.execute_plan(plan) +``` + +### Pose targets + +`plan_to_pose_targets()` accepts `Mapping[PlanningGroupID | PlanningGroup, +Pose]`. It wraps each `Pose` in the world frame before calling the group-scoped +IK path. Every key in `pose_targets` must refer to a pose-targetable group +(`tip_link` is not `None`). `auxiliary_groups` may include non-pose-targeted +groups whose joints should stay in the solve as free DOFs. + +`inverse_kinematics()` is the lower-level RPC. It accepts stamped poses keyed by +group ID plus optional auxiliary group IDs and returns an `IKResult` without +running collision filtering or planning. + +The compatibility wrapper `plan_to_pose(pose, robot_name=None)` still exists. It +selects the default pose-targetable group for the robot, then delegates to +`plan_to_pose_targets()`. + +### Joint targets + +`plan_to_joint_targets()` accepts `Mapping[PlanningGroupID | PlanningGroup, +JointState]`. + +For a group-scoped joint target: + +- an unnamed vector is interpreted in that group's joint order; +- named targets may use all-global names or all-local names; +- global and local names must not be mixed in one target; +- named targets must provide exactly the group's selected joints, with no + missing or extra joints. + +The compatibility wrapper `plan_to_joints(joints, robot_name=None)` still exists. +It selects the robot's default group, then delegates to +`plan_to_joint_targets()`. + +Robot-scoped state helpers such as `set_init_joints()` still use local model +joint names. When unnamed, those vectors are interpreted in full robot model +joint order. + +## Generated plans and execution + +A `GeneratedPlan` stores: + +- selected planning group IDs; +- a single synchronized path of `JointState` waypoints keyed by selected global + joint names; +- status, timing, path length, iteration count, and message metadata. + +Preview and execution project this path lazily. `preview_plan(plan=None, +duration=None, robot_name=None)` defaults to `_last_plan`; `robot_name` is only a +filter that rejects plans which do not affect the requested robot. Preview sends +the generated global-joint path to the world monitor for animation. + +`execute_plan(plan=None)` also defaults to `_last_plan`. It infers affected +robots from the selected groups, projects each waypoint back into each robot's +full local model joint order, fills unselected robot joints from current state, +then writes a coordinator `JointTrajectory` using global joint names. The +trajectory is dispatched to each robot's configured `coordinator_task_name`. +Controllers remain planning-group agnostic. + +Multi-task dispatch is not atomic in this change: if one trajectory task accepts +and a later task rejects, DimOS reports the rejection but does not roll back the +accepted task. + +## Compatibility planning config fields + +`RobotConfig.base_link`, `RobotConfig.base_pose`, +`RobotModelConfig.base_link`, `RobotModelConfig.base_pose`, and +`RobotModelConfig.end_effector_link` remain as compatibility fields for the +current Drake weld/placement behavior and robot-scoped compatibility helpers. +New planning logic should use model/SRDF structure and planning group base/tip +links instead. + +Robot placement should be encoded in URDF/xacro/MJCF. `joint_names` remains +supported and should describe the ordered controllable local model joint set, not +a planning group. `joint_name_mapping` can map external/coordinator joint names +back to local model joint names for adapters that publish scoped hardware names. +`coordinator_task_name` identifies the trajectory task used by `execute_plan()`. diff --git a/docs/capabilities/manipulation/readme.md b/docs/capabilities/manipulation/readme.md index e502947699..ca130272c8 100644 --- a/docs/capabilities/manipulation/readme.md +++ b/docs/capabilities/manipulation/readme.md @@ -176,7 +176,7 @@ visualization backend. | `keyboard-teleop-xarm7` | XArm7 7-DOF keyboard teleop with Drake viz | | `xarm6-planner-only` | XArm6 standalone planner (no coordinator) | | `xarm7-planner-coordinator` | XArm7 planner with coordinator integration | -| `dual-xarm6-planner` | Dual XArm6 planning | +| `dual-xarm6-planner-coordinator` | Dual XArm6 planning and execution with Viser | | `xarm-perception` | XArm7 + RealSense camera for perception | | `xarm-perception-agent` | XArm7 perception + LLM agent | | `xarm-perception-sim` | XArm7 simulation perception stack | @@ -194,6 +194,14 @@ visualization backend. [guide is here](/docs/capabilities/manipulation/adding_a_custom_arm.md) +## Planning Groups + +Manipulation planning uses explicit planning group IDs such as +`arm/manipulator` and global joint names such as `arm/joint1`. See +[Planning Groups](/docs/capabilities/manipulation/planning_groups.md) for SRDF +support, fallback generation, auxiliary groups, generated plans, and execution +projection. + ## Key Files | File | Description | diff --git a/pyproject.toml b/pyproject.toml index b5af8dc760..ebf899fcaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -591,6 +591,7 @@ ignore = [ "dimos/dashboard/dimos.rbl", "dimos/web/dimos_interface/themes.json", "dimos/manipulation/manipulation_module.py", + "dimos/manipulation/planning/world/drake_world.py", "dimos/manipulation/visualization/viser/test_viser_visualization.py", "dimos/navigation/nav_stack/modules/*/main.cpp", "dimos/navigation/nav_stack/common/*.hpp",