diff --git a/ros2/README.md b/ros2/README.md index 6316409..7dc008f 100644 --- a/ros2/README.md +++ b/ros2/README.md @@ -1,23 +1,61 @@ # ROS 2 integration -ROS 2 packages for BugBot. Status: **in progress.** +ROS 2 packages for BugBot. Bridges the holonomic X-omni drive into ROS 2 and runs the +corrective policy trained in [`../simulation/`](../simulation) both in sim and on the real +robot. -## Goal +> Status: authored, not yet built on a live ROS 2 install. The drive mix is unit-tested; +> `colcon build` + runtime testing are pending. Target: ROS 2 Humble/Jazzy on Ubuntu / WSL 2. -Bridge BugBot's holonomic drive and its sensor/pose data into ROS 2, and run the trained -corrective policy from [`../simulation/`](../simulation) both in simulation and on the real -robot. +## Package: `bugbot_ros` (ament_python) + +``` +ros2/bugbot_ros/ +├── package.xml setup.py setup.cfg +├── bugbot_ros/ +│ ├── mixing.py X-omni mix (matches firmware + simulation) -- pure NumPy +│ ├── deploy.py shared core: observation builder + PolicyController (sim-to-real) +│ ├── twist_bridge_node.py /cmd_vel -> 4 wheel commands (direct teleop path) +│ └── policy_node.py trained policy -> 4 wheel commands (learned corrective control) +├── launch/bugbot.launch.py +└── test/test_mixing_parity.py +``` + +### Nodes and topics + +| Node | Subscribes | Publishes | +| --- | --- | --- | +| `twist_bridge` | `/cmd_vel` (`geometry_msgs/Twist`) | `/bugbot/wheel_cmd` (`std_msgs/Float32MultiArray`, BR FR BL FL) | +| `policy_node` | `/cmd_vel`, `/imu/data` (`sensor_msgs/Imu`), `/bugbot/vel` (`geometry_msgs/Twist`) | `/bugbot/wheel_cmd` | + +`/cmd_vel` is a full `Twist` because BugBot is holonomic: `linear.x` = forward, `linear.y` += left (mapped to BugBot's right-positive lateral internally), `angular.z` = yaw. + +## Build and run (Ubuntu / WSL 2) + +```bash +mkdir -p ~/bugbot_ws/src && cd ~/bugbot_ws/src +ln -s /path/to/bugbot/ros2/bugbot_ros . 
+cd ~/bugbot_ws +colcon build --packages-select bugbot_ros +source install/setup.bash + +# Direct teleop path (no policy): +ros2 run bugbot_ros twist_bridge +# ... and drive it, e.g.: +ros2 topic pub /cmd_vel geometry_msgs/msg/Twist "{linear: {x: 0.2, y: 0.1}, angular: {z: 0.5}}" -## Planned +# Learned corrective controller (zero policy until you pass one): +ros2 launch bugbot_ros bugbot.launch.py policy_path:=/path/to/policy.pt +``` -- A node that subscribes to `/cmd_vel` (full `geometry_msgs/Twist`, because lateral motion - is real for a holonomic robot) and maps the commanded twist to the four wheel commands - using the same X-omni mix as the firmware and simulation. -- Sensor and odometry topics (IMU, range, pose/odom) published from the robot or the sim. -- A policy node that loads a trained policy and publishes wheel commands, with a shared - deploy core reused between the sim bridge and the real-robot bridge (UDP/serial). -- The Isaac ROS 2 bridge for closed-loop validation in simulation. +The policy is the one trained in `../simulation/`, exported to TorchScript. With no +`policy_path` the controller emits a zero action (robot holds still), which is a safe +bring-up default. -## Platform +## Sim-to-real -ROS 2 (Humble/Jazzy) on Ubuntu or WSL 2. Package type: `ament_python`. +`deploy.py` holds the observation layout and normalisation constants, kept identical to the +simulation env, so a sim-trained policy consumes the same observation on the real robot. +`/bugbot/wheel_cmd` is the common output; wiring it to the firmware's command channel +(WebSocket / UDP) is the remaining integration step. 
diff --git a/ros2/bugbot_ros/bugbot_ros/__init__.py b/ros2/bugbot_ros/bugbot_ros/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ros2/bugbot_ros/bugbot_ros/deploy.py b/ros2/bugbot_ros/bugbot_ros/deploy.py
new file mode 100644
index 0000000..c825f1a
--- /dev/null
+++ b/ros2/bugbot_ros/bugbot_ros/deploy.py
@@ -0,0 +1,103 @@
+"""Shared sim-to-real core for BugBot: observation assembly, normalisation, and
+policy inference. Reused by the ROS nodes so the deployment logic is tested once
+and kept identical to the simulation.
+
+The constants and observation layout below MUST match the simulation env
+(``simulation/bugbot_tasks/bugbot_env_cfg.py`` / ``bugbot_env.py``): a policy
+trained there consumes exactly this observation vector.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+import numpy as np
+
+from .mixing import normalise_wheels, twist_to_wheels
+
+# --- normalisation (keep in step with BugBotEnvCfg) ---
+MAX_LIN_SPEED = 0.5  # m/s at |twist| = 1
+MAX_YAW_RATE = 3.0  # rad/s at |rot| = 1
+N_WHEELS = 4
+OBS_DIM = 10  # cmd(3) + meas_lin(2) + meas_yaw(1) + prev_action(4)
+
+
+@dataclass
+class RobotSensors:
+    """Latest sensor readings in SI units, body frame."""
+    lin_vel_xy: np.ndarray = field(default_factory=lambda: np.zeros(2))  # m/s (vx, vy)
+    yaw_rate: float = 0.0  # rad/s
+
+
+class ObservationBuilder:
+    """Assemble the 10-dim observation the policy was trained on.
+
+    Layout (all normalised): [cmd_vx, cmd_vy, cmd_wz, meas_vx, meas_vy, meas_wz,
+    prev_action(4)].
+    """
+
+    def __init__(self):
+        self.prev_action = np.zeros(N_WHEELS, dtype=np.float32)
+
+    def build(self, cmd_twist, sensors: RobotSensors) -> np.ndarray:
+        """Return the normalised float32 observation vector (length ``OBS_DIM``)."""
+        cmd = np.asarray(cmd_twist, dtype=np.float32)  # (vx, vy, wz) in SI
+        obs = np.empty(OBS_DIM, dtype=np.float32)
+        obs[0] = cmd[0] / MAX_LIN_SPEED
+        obs[1] = cmd[1] / MAX_LIN_SPEED
+        obs[2] = cmd[2] / MAX_YAW_RATE
+        obs[3] = sensors.lin_vel_xy[0] / MAX_LIN_SPEED
+        obs[4] = sensors.lin_vel_xy[1] / MAX_LIN_SPEED
+        obs[5] = sensors.yaw_rate / MAX_YAW_RATE
+        obs[6:10] = self.prev_action
+        return obs
+
+    def set_prev_action(self, action) -> None:
+        """Store the last action; it fills obs[6:10] on the next ``build``."""
+        self.prev_action = np.asarray(action, dtype=np.float32).reshape(N_WHEELS)
+
+
+class PolicyController:
+    """Loads a TorchScript-exported policy and maps observations to wheel commands.
+
+    Export the trained rsl_rl policy to TorchScript first (rsl_rl's exporter, or
+    the sim's play path). If ``policy_path`` is None, acts as a zero policy
+    (robot holds still), which is a safe default for bring-up.
+    """
+
+    def __init__(self, policy_path: Optional[str] = None, device: str = "cpu"):
+        self.builder = ObservationBuilder()
+        self._policy = None
+        self._device = device
+        if policy_path:
+            import torch  # imported lazily so the twist bridge needs no torch
+            self._torch = torch
+            self._policy = torch.jit.load(policy_path, map_location=device).eval()
+
+    def act(self, cmd_twist, sensors: RobotSensors) -> np.ndarray:
+        """Return four wheel commands in [-1, 1] (BR, FR, BL, FL).
+
+        Non-finite policy outputs are replaced by 0 (stop) before clipping, so a
+        NaN/Inf is never returned or stored as ``prev_action``.
+        """
+        obs = self.builder.build(cmd_twist, sensors)
+        if self._policy is None:
+            action = np.zeros(N_WHEELS, dtype=np.float32)
+        else:
+            with self._torch.inference_mode():
+                t = self._torch.from_numpy(obs).float().unsqueeze(0).to(self._device)
+                action = self._policy(t).squeeze(0).cpu().numpy()
+            # np.clip passes NaN through; zero it here or it is returned to the
+            # caller and fed back through prev_action into the next observation.
+            action = np.nan_to_num(action, nan=0.0, posinf=0.0, neginf=0.0)
+        action = np.clip(action, -1.0, 1.0)
+        self.builder.set_prev_action(action)
+        return action
+
+
+def twist_to_wheel_cmds(vx: float, vy_right: float, wz: float) -> np.ndarray:
+    """Direct (non-learned) mapping: body twist -> normalised wheel commands,
+    capped like the firmware. ``vy_right`` is lateral with right positive."""
+    wheels = twist_to_wheels(vx / MAX_LIN_SPEED, vy_right / MAX_LIN_SPEED, wz / MAX_YAW_RATE)
+    return normalise_wheels(wheels, 1.0)
diff --git a/ros2/bugbot_ros/bugbot_ros/mixing.py b/ros2/bugbot_ros/bugbot_ros/mixing.py
new file mode 100644
index 0000000..2149cb3
--- /dev/null
+++ b/ros2/bugbot_ros/bugbot_ros/mixing.py
@@ -0,0 +1,46 @@
+"""BugBot X-omni drive mixing (ROS-side copy).
+
+Authoritative copy of the drive mix, kept identical to the firmware
+(``firmware/main/lib/drivers/MotionLib.cpp``) and the simulation
+(``simulation/bugbot_tasks/mixing.py``). Duplicated here so the ROS package
+builds standalone with no cross-package import; if you change one, change all.
+
+Motor order (BR, FR, BL, FL). Body twist = (longitudinal, lateral, rot):
+longitudinal = forward +, lateral = right +, rot = CCW / left +.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+
+MIX = np.array(
+    [
+        [+1.0, +1.0, +1.0],  # BR = +lon + lat + rot
+        [-1.0, +1.0, -1.0],  # FR = -lon + lat - rot
+        [+1.0, -1.0, -1.0],  # BL = +lon - lat - rot
+        [-1.0, -1.0, +1.0],  # FL = -lon - lat + rot
+    ],
+    dtype=np.float64,
+)
+UNMIX = np.linalg.pinv(MIX)
+WHEEL_ORDER = ("BR", "FR", "BL", "FL")
+
+
+def twist_to_wheels(longitudinal: float, lateral: float, rot: float) -> np.ndarray:
+    """Body twist -> four wheel commands (BR, FR, BL, FL)."""
+    return MIX @ np.array([longitudinal, lateral, rot], dtype=np.float64)
+
+
+def wheels_to_twist(wheels) -> np.ndarray:
+    """Four wheel commands -> body twist (longitudinal, lateral, rot)."""
+    return UNMIX @ np.asarray(wheels, dtype=np.float64)
+
+
+def normalise_wheels(wheels, max_abs: float = 1.0):
+    """Scale wheel commands so the peak magnitude fits within ``max_abs`` (matches
+    the firmware ``drive()`` cap). Commands already in range are unchanged."""
+    wheels = np.asarray(wheels, dtype=np.float64)
+    peak = float(np.max(np.abs(wheels)))
+    if peak <= max_abs or peak == 0.0:
+        return wheels
+    return wheels * (max_abs / peak)
diff --git a/ros2/bugbot_ros/bugbot_ros/policy_node.py b/ros2/bugbot_ros/bugbot_ros/policy_node.py
new file mode 100644
index 0000000..0ab8d8f
--- /dev/null
+++ b/ros2/bugbot_ros/bugbot_ros/policy_node.py
@@ -0,0 +1,92 @@
+"""Learned corrective controller for BugBot.
+
+Runs a policy trained in the Isaac Lab simulation (exported to TorchScript) that
+maps [command twist + measured motion + previous action] to four wheel commands,
+applying the micro-corrections that counter the chaotic X-omni drive. Publishes
+``/bugbot/wheel_cmd`` (BR, FR, BL, FL) at a fixed rate.
+
+Inputs:
+  /cmd_vel      geometry_msgs/Twist   desired body twist (linear.x, linear.y, angular.z)
+  /imu/data     sensor_msgs/Imu       measured yaw rate (angular_velocity.z)
+  /bugbot/vel   geometry_msgs/Twist   measured body linear velocity (vx, vy)  [optional]
+
+Parameters:
+  policy_path (str)    TorchScript policy; empty = zero policy (holds still), safe for bring-up
+  control_hz  (float)  control loop rate (default 50, minimum 1)
+
+    ros2 run bugbot_ros policy_node --ros-args -p policy_path:=/path/to/policy.pt
+
+Frame note: ROS y is left+, BugBot lateral is right+, so lateral = -linear.y.
+"""
+
+import numpy as np
+import rclpy
+from geometry_msgs.msg import Twist
+from rclpy.executors import ExternalShutdownException
+from rclpy.node import Node
+from sensor_msgs.msg import Imu
+from std_msgs.msg import Float32MultiArray
+
+from .deploy import PolicyController, RobotSensors
+
+
+class PolicyNode(Node):
+    """Runs ``PolicyController`` on a timer and publishes its four wheel commands."""
+
+    def __init__(self):
+        super().__init__("bugbot_policy")
+        self.declare_parameter("policy_path", "")
+        self.declare_parameter("control_hz", 50.0)
+        policy_path = self.get_parameter("policy_path").get_parameter_value().string_value
+        requested_hz = self.get_parameter("control_hz").get_parameter_value().double_value
+        # Floor the loop rate at 1 Hz (this comparison also rejects NaN) and report
+        # the rate actually used, so the log never disagrees with the timer period.
+        control_hz = requested_hz if requested_hz >= 1.0 else 1.0
+        if control_hz != requested_hz:
+            self.get_logger().warning(f"control_hz={requested_hz} is not a rate >= 1 Hz; using {control_hz:.0f} Hz")
+
+        self._controller = PolicyController(policy_path or None)
+        self._cmd = np.zeros(3, dtype=np.float32)  # (vx, vy_right, wz) SI
+        self._sensors = RobotSensors()
+
+        self._pub = self.create_publisher(Float32MultiArray, "/bugbot/wheel_cmd", 10)
+        self.create_subscription(Twist, "/cmd_vel", self._on_cmd_vel, 10)
+        self.create_subscription(Imu, "/imu/data", self._on_imu, 10)
+        self.create_subscription(Twist, "/bugbot/vel", self._on_vel, 10)
+        self.create_timer(1.0 / control_hz, self._on_tick)
+
+        mode = "zero policy (bring-up)" if not policy_path else policy_path
+        self.get_logger().info(f"bugbot_policy up @ {control_hz:.0f} Hz, policy = {mode}")
+
+    def _on_cmd_vel(self, msg: Twist):
+        self._cmd[:] = (msg.linear.x, -msg.linear.y, msg.angular.z)
+
+    def _on_imu(self, msg: Imu):
+        self._sensors.yaw_rate = float(msg.angular_velocity.z)
+
+    def _on_vel(self, msg: Twist):
+        self._sensors.lin_vel_xy = np.array([msg.linear.x, -msg.linear.y], dtype=np.float64)
+
+    def _on_tick(self):
+        wheels = self._controller.act(self._cmd, self._sensors)
+        out = Float32MultiArray()
+        out.data = [float(w) for w in wheels]
+        self._pub.publish(out)
+
+
+def main(args=None):
+    rclpy.init(args=args)
+    node = PolicyNode()
+    try:
+        rclpy.spin(node)
+    except (KeyboardInterrupt, ExternalShutdownException):
+        pass
+    finally:
+        node.destroy_node()
+        # The signal handler may already have shut the context down; a plain
+        # rclpy.shutdown() raises in that case, try_shutdown() does not.
+        rclpy.try_shutdown()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ros2/bugbot_ros/bugbot_ros/twist_bridge_node.py b/ros2/bugbot_ros/bugbot_ros/twist_bridge_node.py
new file mode 100644
index 0000000..7693c1d
--- /dev/null
+++ b/ros2/bugbot_ros/bugbot_ros/twist_bridge_node.py
@@ -0,0 +1,54 @@
+"""Direct /cmd_vel -> wheel-command bridge (no policy).
+
+Subscribes to a standard ``geometry_msgs/Twist`` on ``/cmd_vel`` and publishes
+the four normalised wheel commands (BR, FR, BL, FL) on
+``/bugbot/wheel_cmd`` (``std_msgs/Float32MultiArray``) using the firmware X-omni
+mix. This is the teleop / open-loop path; the learned corrective controller is
+``policy_node``.
+
+Frame note: ROS uses y-left-positive (REP-103); BugBot's lateral is right-
+positive, so lateral = -linear.y.
+
+    ros2 run bugbot_ros twist_bridge
+"""
+
+import rclpy
+from geometry_msgs.msg import Twist
+from rclpy.node import Node
+from std_msgs.msg import Float32MultiArray
+
+from .deploy import twist_to_wheel_cmds
+
+
+class TwistBridge(Node):
+    def __init__(self):
+        super().__init__("bugbot_twist_bridge")
+        self._pub = self.create_publisher(Float32MultiArray, "/bugbot/wheel_cmd", 10)
+        self._sub = self.create_subscription(Twist, "/cmd_vel", self._on_cmd_vel, 10)
+        self.get_logger().info("bugbot_twist_bridge up: /cmd_vel -> /bugbot/wheel_cmd (BR,FR,BL,FL)")
+
+    def _on_cmd_vel(self, msg: Twist):
+        wheels = twist_to_wheel_cmds(
+            vx=msg.linear.x,
+            vy_right=-msg.linear.y,  # ROS y is left+, BugBot lateral is right+
+            wz=msg.angular.z,
+        )
+        out = Float32MultiArray()
+        out.data = [float(w) for w in wheels]
+        self._pub.publish(out)
+
+
+def main(args=None):
+    rclpy.init(args=args)
+    node = TwistBridge()
+    try:
+        rclpy.spin(node)
+    except KeyboardInterrupt:
+        pass
+    finally:
+        node.destroy_node()
+        rclpy.try_shutdown()  # context may already be shut down by the signal handler
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ros2/bugbot_ros/launch/bugbot.launch.py b/ros2/bugbot_ros/launch/bugbot.launch.py
new file mode 100644
index 0000000..d288765
--- /dev/null
+++ b/ros2/bugbot_ros/launch/bugbot.launch.py
@@ -0,0 +1,39 @@
+"""Launch the BugBot learned controller.
+
+    ros2 launch bugbot_ros bugbot.launch.py policy_path:=/path/to/policy.pt
+
+With no policy_path it runs a zero policy (robot holds still) -- safe for bring-up.
+Run the direct teleop bridge separately with: ros2 run bugbot_ros twist_bridge
+"""
+
+from launch import LaunchDescription
+from launch.actions import DeclareLaunchArgument
+from launch.substitutions import LaunchConfiguration
+from launch_ros.actions import Node
+from launch_ros.parameter_descriptions import ParameterValue
+
+
+def generate_launch_description():
+    """Declare ``policy_path`` / ``control_hz`` and start ``policy_node`` with them."""
+    policy_path = LaunchConfiguration("policy_path")
+    control_hz = LaunchConfiguration("control_hz")
+    # Launch arguments arrive as text and are YAML-typed by default, so
+    # ``control_hz:=100`` would become an integer and can be rejected by the node's
+    # double parameter. Pin both to the types the node declares.
+    parameters = {
+        "policy_path": ParameterValue(policy_path, value_type=str),
+        "control_hz": ParameterValue(control_hz, value_type=float),
+    }
+    return LaunchDescription(
+        [
+            DeclareLaunchArgument("policy_path", default_value="", description="TorchScript policy (empty = zero policy)"),
+            DeclareLaunchArgument("control_hz", default_value="50.0", description="Control loop rate (Hz)"),
+            Node(
+                package="bugbot_ros",
+                executable="policy_node",
+                name="bugbot_policy",
+                output="screen",
+                parameters=[parameters],
+            ),
+        ]
+    )
diff --git a/ros2/bugbot_ros/package.xml b/ros2/bugbot_ros/package.xml
new file mode 100644
index 0000000..a51fc54
--- /dev/null
+++ b/ros2/bugbot_ros/package.xml
@@ -0,0 +1,24 @@
+
+
+ bugbot_ros
+ 0.1.0
+ ROS 2 bridge and learned controller for the BugBot X-omni holonomic robot.
+ BugBotLab
+ MIT
+
+ rclpy
+ std_msgs
+ geometry_msgs
+ sensor_msgs
+ python3-numpy
+
+ ament_copyright
+ ament_flake8
+ ament_pep257
+ python3-pytest
+
+
+ ament_python
+
+
diff --git a/ros2/bugbot_ros/resource/bugbot_ros b/ros2/bugbot_ros/resource/bugbot_ros
new file mode 100644
index 0000000..e69de29
diff --git a/ros2/bugbot_ros/setup.cfg b/ros2/bugbot_ros/setup.cfg
new file mode 100644
index 0000000..813573a
--- /dev/null
+++ b/ros2/bugbot_ros/setup.cfg
@@ -0,0 +1,4 @@
+[develop]
+script_dir=$base/lib/bugbot_ros
+[install]
+install_scripts=$base/lib/bugbot_ros
diff --git a/ros2/bugbot_ros/setup.py b/ros2/bugbot_ros/setup.py
new file mode 100644
index 0000000..16c7c17
--- /dev/null
+++ b/ros2/bugbot_ros/setup.py
@@ -0,0 +1,30 @@
+import os
+from glob import glob
+
+from setuptools import find_packages, setup
+
+package_name = "bugbot_ros"
+
+setup(
+    name=package_name,
+    version="0.1.0",
+    packages=find_packages(exclude=["test"]),
+    data_files=[
+        ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
+        ("share/" + package_name, ["package.xml"]),
+        (os.path.join("share", package_name, "launch"), glob("launch/*.launch.py")),
+    ],
+    install_requires=["setuptools"],
+    zip_safe=True,
+    maintainer="BugBotLab",
+    maintainer_email="jerome.a.graves@gmail.com",
+    description="ROS 2 bridge and learned controller for the BugBot X-omni holonomic robot.",
+    license="MIT",
+    tests_require=["pytest"],
+    entry_points={
+        "console_scripts": [
+            "twist_bridge = bugbot_ros.twist_bridge_node:main",
+            "policy_node = bugbot_ros.policy_node:main",
+        ],
+    },
+)
diff --git a/ros2/bugbot_ros/test/test_mixing_parity.py b/ros2/bugbot_ros/test/test_mixing_parity.py
new file mode 100644
index 0000000..5ed8347
--- /dev/null
+++ b/ros2/bugbot_ros/test/test_mixing_parity.py
@@ -0,0 +1,40 @@
+"""Verify the ROS-side X-omni mix matches the firmware basis vectors.
+Pure NumPy, runnable without ROS:
+
+    python ros2/bugbot_ros/test/test_mixing_parity.py
+"""
+
+import os
+import sys
+
+import numpy as np
+
+# Import mixing.py directly (avoid needing the installed ROS package).
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "bugbot_ros"))
+from mixing import normalise_wheels, twist_to_wheels, wheels_to_twist  # noqa: E402
+
+
+def test_basis_matches_firmware():
+    np.testing.assert_allclose(twist_to_wheels(1, 0, 0), [+1, -1, +1, -1])  # forward
+    np.testing.assert_allclose(twist_to_wheels(0, 1, 0), [+1, +1, -1, -1])  # strafe right
+    np.testing.assert_allclose(twist_to_wheels(0, 0, 1), [+1, -1, -1, +1])  # turn left
+
+
+def test_roundtrip():
+    for tw in [(0.3, -0.2, 0.5), (-1.0, 0.0, 0.0)]:
+        np.testing.assert_allclose(wheels_to_twist(twist_to_wheels(*tw)), tw, atol=1e-9)
+
+
+def test_normalise_caps_peak():
+    # Over-range commands are scaled so the peak equals the cap, ratios preserved.
+    np.testing.assert_allclose(normalise_wheels([2.0, -1.0, 0.5, 0.0]), [1.0, -0.5, 0.25, 0.0])
+    # In-range and all-zero commands pass through unchanged.
+    np.testing.assert_allclose(normalise_wheels([0.5, -0.25, 0.0, 1.0]), [0.5, -0.25, 0.0, 1.0])
+    np.testing.assert_allclose(normalise_wheels([0.0, 0.0, 0.0, 0.0]), [0.0, 0.0, 0.0, 0.0])
+
+
+if __name__ == "__main__":
+    test_basis_matches_firmware()
+    test_roundtrip()
+    test_normalise_caps_peak()
+    print("ros mix parity OK")