From 5a2dc09948fab064ff992e7c509fd289d5ac0331 Mon Sep 17 00:00:00 2001 From: Michael Johnson Date: Tue, 16 Jun 2026 11:42:50 +0100 Subject: [PATCH] feat(costmap): generic grid projection primitives and typed layer access MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the substrate needed to feed an arbitrary, caller-owned Grid2d (terrain class, learned traversability, instance ids, ...) into the u8 master costmap as a first-class layer, and to read it back for control — without putting any semantic/learned types in the crate itself. - project_into + MergePolicy (src/costmap/project.rs): the keystone generic projection of Grid2d -> master via a T -> Option closure. merge_overwrite/merge_max/merge_max_keep_unknown now delegate to it (one codepath; signatures unchanged). - Typed layer access (src/costmap/layered.rs): Layer: Any, LayerId, and add_layer returning a LayerId plus layer/layer_mut downcast accessors, so a layer can own its grid and the caller ingests/queries through &mut LayeredCostmap (no Arc/RwLock). Also makes every layer's internal grid inspectable. - ProjectionLayer (src/layers/projection.rs): an owned-source layer (nav2 CostmapLayer generalized over T + a projection closure). For rolling windows it re-centres the source in update_bounds so producers that fill during update_costs see an already-centred grid. - cost_from_unit/cost_from_range (src/types/cost.rs) and Grid2d::layout_matches (src/grid/grid2d.rs). - examples/learned_traversability.rs: end-to-end rolling-window demo (mock model -> ProjectionLayer -> inflation -> read-back control decision). SimLidarLayer in local_costmap_lidar.rs refactored to compose ProjectionLayer, validating the abstraction against the existing pattern. Tests cover projection policies, layout_matches, cost conversion, typed access, update_bounds re-centring/extent, the rolling end-to-end path, and projection seeding inflation. Co-Authored-By: Claude Opus 4.8 --- Cargo.toml | 5 + examples/learned_traversability.rs | 190 +++++++++++++++ examples/local_costmap_lidar.rs | 45 ++-- src/costmap/layered.rs | 39 ++- src/costmap/merge.rs | 57 ++--- src/costmap/mod.rs | 4 +- src/costmap/project.rs | 238 ++++++++++++++++++ src/grid/grid2d.rs | 89 +++++++ src/layers/mod.rs | 2 + src/layers/projection.rs | 371 +++++++++++++++++++++++++++++ src/lib.rs | 6 +- src/types/cost.rs | 77 ++++++ src/types/mod.rs | 2 + 13 files changed, 1059 insertions(+), 66 deletions(-) create mode 100644 examples/learned_traversability.rs create mode 100644 src/costmap/project.rs create mode 100644 src/layers/projection.rs create mode 100644 src/types/cost.rs diff --git a/Cargo.toml b/Cargo.toml index 03f7588..1af33b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,11 @@ name = "occupancy_raycast" path = "examples/occupancy_raycast.rs" required-features = ["rerun"] +[[example]] +name = "learned_traversability" +path = "examples/learned_traversability.rs" +required-features = ["rerun"] + [[bench]] name = "load_warehouse" harness = false diff --git a/examples/learned_traversability.rs b/examples/learned_traversability.rs new file mode 100644 index 0000000..2008ffd --- /dev/null +++ b/examples/learned_traversability.rs @@ -0,0 +1,190 @@ +//! # Local Costmap with Learned Traversability +//! +//! This example demonstrates a rolling-window local costmap where a mock "learned traversability" +//! model fills a semantic grid each frame. The model predicts terrain traversability (0.0 = free, +//! 1.0 = impassable), which is projected into the master costmap and then inflated for safe planning. +//! +//! The pipeline is: **semantic traversability layer** (mock learned model output) → **projection** +//! (conversion to cost space) → **inflation layer** → single master grid. +//! +//! ## Robotics Context +//! This pattern demonstrates how learned/perception-based traversability models can be integrated +//! into a costmap pipeline: +//! - **Semantic segmentation**: A neural network predicts traversability for each cell (e.g., 0.9 +//! for rocky terrain, 0.0 for smooth grass). +//! - **Projection**: Convert semantic values to costmap costs via a configurable closure. +//! - **Rolling window**: Keep the semantic grid centered on the robot as it moves. +//! - **Inflation**: Expand obstacles to account for robot size and safety margins. +//! +//! The robot in this example drives across a high-cost terrain band and the control decision is +//! printed when high-traversability costs are detected underneath the robot. + +use std::error::Error; +use std::time::Duration; + +use costmap::rerun_viz::{log_costmap, log_point3d}; +use costmap::types::{COST_FREE, Pose2}; +use costmap::{Grid2d, MapInfo, MergePolicy, ProjectionLayer, cost_from_unit}; +use costmap::{InflationConfig, LayeredCostmap, WavefrontInflationLayer}; +use glam::{UVec2, Vec2, Vec3}; + +// Simulation parameters +const DELAY_MS: u64 = 100; +const WAYPOINT_SPEED_MPS: f32 = 1.1; +const WAYPOINTS: &[(f32, f32)] = &[(2.0, 3.0), (7.0, 3.0), (7.0, 5.0), (2.0, 5.0)]; + +// Local costmap configuration +/// Size of the robot-centered rolling window costmap (square, in cells) +const LOCAL_SIZE_CELLS: u32 = 120; +const RESOLUTION: f32 = 0.1; + +// Visualization Z-heights for layering elements in Rerun +const Z_LOCAL: f32 = 0.12; +const Z_ROBOT: f32 = 0.35; + +fn main() -> Result<(), Box> { + let resolution = RESOLUTION; + let local_info = MapInfo::square(LOCAL_SIZE_CELLS, resolution); + let rec = rerun::RecordingStreamBuilder::new("costmap_learned_traversability").spawn()?; + + let mut layered = LayeredCostmap::new(local_info, COST_FREE, true); + let id = layered.add_layer(Box::new(ProjectionLayer::from_grid( + Grid2d::::new_with_value(local_info, 0.0), + MergePolicy::Max, + true, // rolling window + |t| Some(cost_from_unit(*t)), + ))); + layered.add_layer(Box::new(WavefrontInflationLayer::new(InflationConfig { + inflation_radius_m: 0.4, + inscribed_radius_m: 0.1, + cost_scaling_factor: 3.0, + ..Default::default() + }))); + + let waypoints: Vec = WAYPOINTS.iter().map(|(x, y)| Vec2::new(*x, *y)).collect(); + let (segment_lengths, total_length) = build_segments(&waypoints); + let dt = DELAY_MS as f32 / 1000.0; + + let mut frame_idx: i64 = 0; + loop { + rec.set_time_sequence("frame", frame_idx); + let distance = (frame_idx as f32) * WAYPOINT_SPEED_MPS * dt; + let (robot_pos, heading_dir) = + sample_path(&waypoints, &segment_lengths, total_length, distance); + let heading = heading_dir.y.atan2(heading_dir.x); + let robot = Pose2::new(robot_pos, heading); + + // Mock "learned traversability model": fill the source world-anchored. The + // rolling-window ProjectionLayer re-centres the grid on the robot during + // update_map, and update_origin preserves data by world position, so we just + // write each cell's cost from its current world location. + { + let proj = layered.layer_mut::>(id).unwrap(); + let src = proj.source_mut(); + let (w, h) = (src.width(), src.height()); + for y in 0..h { + for x in 0..w { + let cell = UVec2::new(x, y); + let world = src.map_to_world(cell); + // Impassable terrain band in world coords (x in [4.0, 4.5]). + // Traversability 1.0 projects to COST_LETHAL, which seeds inflation. + let t = if world.x > 4.0 && world.x < 4.5 { + 1.0 + } else { + 0.0 + }; + let _ = src.set(cell, t); + } + } + } + + layered.update_map(robot); + + // Control decision: read traversability back under the robot. + { + let proj = layered.layer::>(id).unwrap(); + let src = proj.source(); + if let Some(cell) = src.world_to_map(robot_pos) { + let t = src.get(cell).copied().unwrap_or(0.0); + if t > 0.5 { + println!( + "[frame {frame_idx}] high-cost terrain (t={t:.2}) under robot -> SLOW DOWN" + ); + } + } + } + + log_costmap(&rec, "world/local_costmap", layered.master(), Z_LOCAL)?; + log_point3d( + &rec, + "world/robot", + Vec3::new(robot_pos.x, robot_pos.y, Z_ROBOT), + Some(rerun::Color::from_rgb(0, 200, 255)), + Some(5.0), + )?; + + if DELAY_MS > 0 { + std::thread::sleep(Duration::from_millis(DELAY_MS)); + } + frame_idx = frame_idx.wrapping_add(1); + } +} + +// Helper functions for path following - not core library APIs + +fn build_segments(waypoints: &[Vec2]) -> (Vec, f32) { + let mut lengths = Vec::with_capacity(waypoints.len()); + let mut total = 0.0; + + if waypoints.len() < 2 { + return (lengths, total); + } + + for idx in 0..waypoints.len() { + let a = waypoints[idx]; + let b = waypoints[(idx + 1) % waypoints.len()]; + let len = (b - a).length(); + lengths.push(len); + total += len; + } + + (lengths, total) +} + +fn sample_path( + waypoints: &[Vec2], + segment_lengths: &[f32], + total_length: f32, + distance: f32, +) -> (Vec2, Vec2) { + if waypoints.is_empty() { + return (Vec2::ZERO, Vec2::X); + } + + if waypoints.len() == 1 || total_length <= 0.0 { + return (waypoints[0], Vec2::X); + } + + let mut remaining = distance.rem_euclid(total_length); + + for (idx, seg_len) in segment_lengths.iter().enumerate() { + if *seg_len <= 0.0 { + continue; + } + if remaining <= *seg_len { + let a = waypoints[idx]; + let b = waypoints[(idx + 1) % waypoints.len()]; + let dir = (b - a).normalize_or_zero(); + let pos = a + dir * remaining; + let heading = if dir.length_squared() == 0.0 { + Vec2::X + } else { + dir + }; + return (pos, heading); + } + remaining -= *seg_len; + } + + (waypoints[0], Vec2::X) +} diff --git a/examples/local_costmap_lidar.rs b/examples/local_costmap_lidar.rs index 74626dd..76a3e69 100644 --- a/examples/local_costmap_lidar.rs +++ b/examples/local_costmap_lidar.rs @@ -28,7 +28,7 @@ use costmap::rerun_viz::{log_costmap, log_occupancy_grid, log_point3d}; use costmap::types::{Bounds, COST_FREE, COST_LETHAL, COST_UNKNOWN, CellRegion, Pose2}; use costmap::{Costmap, raycast::RayHit2D}; use costmap::{Grid2d, MapInfo, OccupancyGrid, RosMapLoader, WavefrontInflationLayer}; -use costmap::{InflationConfig, costmap::merge_overwrite}; +use costmap::{InflationConfig, MergePolicy, ProjectionLayer}; use costmap::{Layer, LayeredCostmap}; use glam::{Vec2, Vec3}; @@ -53,15 +53,21 @@ const Z_LOCAL: f32 = 0.12; const Z_ROBOT: f32 = 0.35; /// Example-only layer: simulates lidar by raycasting on a global occupancy grid. -/// Keeps an internal obstacle grid (like Nav2's layer costmap_) so observations -/// persist across frames; each update we shift it, draw new rays, then write to master. +/// +/// Composes a [`ProjectionLayer`] for the storage + rolling-window + merge half +/// (the Nav2 `CostmapLayer` role): observations persist across frames in its owned +/// `source` grid, and projecting it into the master is the library's job. This layer +/// only adds the sensor-specific part — drawing rays into that grid each update. /// /// This would normally be done by listening to a laser scan topic which would /// update the obstacle layer. struct SimLidarLayer { global_grid: Arc, - /// Internal costmap that persists between updates (Nav2-style layer costmap_). - obstacle_grid: Costmap, + /// Owns the persistent obstacle grid and projects it into the master. With + /// `rolling_window: true` the projection layer re-centres the grid on the robot in + /// its `update_bounds`, so by the time we draw rays in `update_costs` the grid is + /// already centred. + proj: ProjectionLayer, last_robot: Pose2, max_range_m: f32, n_beams: usize, @@ -69,7 +75,7 @@ struct SimLidarLayer { impl Layer for SimLidarLayer { fn reset(&mut self) { - self.obstacle_grid.clear(); + self.proj.reset(); } fn is_clearable(&self) -> bool { @@ -78,13 +84,15 @@ impl Layer for SimLidarLayer { fn update_bounds(&mut self, robot: Pose2, bounds: &mut Bounds) { self.last_robot = robot; - bounds.expand_to_include(robot.position); - bounds.expand_by(self.max_range_m); + // Delegate to the projection layer: re-centres the obstacle grid (rolling + // window) and expands the bounds to its extent. + self.proj.update_bounds(robot, bounds); } fn update_costs(&mut self, master: &mut Costmap, region: CellRegion) { - // 1) Update internal grid origin (rolling window) and draw new rays into it. - self.obstacle_grid.update_center(self.last_robot.position); + // The obstacle grid is already centred on the robot (done in update_bounds). + // Draw new sensor rays into it, then project it into the master. + let obstacle_grid = self.proj.source_mut(); let beam_step = TAU / self.n_beams as f32; for beam_idx in 0..self.n_beams { let angle = self.last_robot.yaw + beam_step * beam_idx as f32; @@ -94,18 +102,16 @@ impl Layer for SimLidarLayer { .raycast_dda(self.last_robot.position, dir, self.max_range_m); let t = RayHit2D::distance_or(hit, self.max_range_m); let endpoint = hit.map(|_| COST_LETHAL); - self.obstacle_grid - .clear_ray(self.last_robot.position, dir, t, COST_FREE, endpoint); + obstacle_grid.clear_ray(self.last_robot.position, dir, t, COST_FREE, endpoint); } - // 2) Write layer to master (do not copy unknown). - merge_overwrite(master, &self.obstacle_grid, region); + self.proj.update_costs(master, region); } } fn main() -> Result<(), Box> { // Step 1: Load the global map (static environment representation) let grid = RosMapLoader::load_from_yaml(DEFAULT_YAML_PATH)?; - let info = grid.info().clone(); + let info = *grid.info(); let global_grid = Arc::new(grid); // Step 2: Set up visualization (optional - Rerun is not required to use the library) @@ -117,10 +123,15 @@ fn main() -> Result<(), Box> { // Step 3: Create layered costmap (rolling window, sensor layer + inflation layer) let local_info = MapInfo::square(LOCAL_SIZE_CELLS, info.resolution); - let mut layered = LayeredCostmap::new(local_info.clone(), COST_FREE, true); + let mut layered = LayeredCostmap::new(local_info, COST_FREE, true); layered.add_layer(Box::new(SimLidarLayer { global_grid: Arc::clone(&global_grid), - obstacle_grid: Grid2d::::new_with_value(local_info.clone(), COST_UNKNOWN), + proj: ProjectionLayer::from_grid( + Grid2d::::new_with_value(local_info, COST_UNKNOWN), + MergePolicy::Overwrite, + true, // rolling window: re-centres the obstacle grid on the robot each update + |c| (*c != COST_UNKNOWN).then_some(*c), + ), last_robot: Pose2::default(), max_range_m: MAX_RANGE_M, n_beams: N_BEAMS, diff --git a/src/costmap/layered.rs b/src/costmap/layered.rs index 1cef137..f00b678 100644 --- a/src/costmap/layered.rs +++ b/src/costmap/layered.rs @@ -7,6 +7,8 @@ //! The update loop aggregates bounds from all layers, resets the master region, //! then calls each layer's `update_costs` in order. +use std::any::Any; + use glam::{UVec2, Vec2}; use crate::{ @@ -14,13 +16,24 @@ use crate::{ types::{Bounds, CellRegion, Footprint, MapInfo, Pose2}, }; +/// Identifier returned by [`LayeredCostmap::add_layer`], used to fetch a layer +/// back by its concrete type with [`LayeredCostmap::layer`] / +/// [`LayeredCostmap::layer_mut`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct LayerId(usize); + /// A layer is a component which acts on the base master grid. It may contain /// its own grid but doesn't need to. /// /// Each layer is processed in order. /// /// Each update is limited by a set of bounds passed from higher layers. -pub trait Layer { +/// +/// The `Any` supertrait enables [`LayeredCostmap::layer`] / +/// [`LayeredCostmap::layer_mut`] to downcast a stored layer back to its concrete +/// type so callers can ingest into or query a layer's internal grid after it has +/// been added to the stack. +pub trait Layer: Any { /// Reset the layer to its initial state. fn reset(&mut self); @@ -70,8 +83,30 @@ impl LayeredCostmap { } /// Add a layer. Order matters: layers are updated in insertion order. - pub fn add_layer(&mut self, layer: Box) { + /// + /// Returns a [`LayerId`] that can be passed to [`Self::layer`] / + /// [`Self::layer_mut`] to access the layer by its concrete type later. + pub fn add_layer(&mut self, layer: Box) -> LayerId { + let id = LayerId(self.layers.len()); self.layers.push(layer); + id + } + + /// Borrow a previously added layer as its concrete type `T`. + /// + /// Returns `None` if the id is unknown or the layer is not a `T`. + pub fn layer(&self, id: LayerId) -> Option<&T> { + let layer: &dyn Layer = self.layers.get(id.0)?.as_ref(); + (layer as &dyn Any).downcast_ref::() + } + + /// Mutably borrow a previously added layer as its concrete type `T`. + /// + /// Returns `None` if the id is unknown or the layer is not a `T`. This is the + /// path used to ingest data into a layer's internal grid between updates. + pub fn layer_mut(&mut self, id: LayerId) -> Option<&mut T> { + let layer: &mut dyn Layer = self.layers.get_mut(id.0)?.as_mut(); + (layer as &mut dyn Any).downcast_mut::() } /// Immutable reference to the master grid. diff --git a/src/costmap/merge.rs b/src/costmap/merge.rs index 10ae38c..925d369 100644 --- a/src/costmap/merge.rs +++ b/src/costmap/merge.rs @@ -3,60 +3,31 @@ //! **Assumption:** `master` and `source` share the same dimensions and alignment so that //! cell `(x, y)` in `region` is valid in both grids. -use glam::UVec2; - use crate::types::{COST_UNKNOWN, CellRegion}; -use super::Costmap; +use super::{Costmap, MergePolicy, project_into}; /// Copies source into master only where source is not unknown. pub fn merge_overwrite(master: &mut Costmap, source: &Costmap, region: CellRegion) { - for y in region.min.y..region.max.y { - for x in region.min.x..region.max.x { - let cell = UVec2::new(x, y); - if let Some(&cost) = source.get(cell) - && cost != COST_UNKNOWN - { - let _ = master.set(cell, cost); - } - } - } + project_into(master, source, region, MergePolicy::Overwrite, |cost| { + (*cost != COST_UNKNOWN).then_some(*cost) + }); } /// Merges source into master by taking the maximum cost; never writes unknown from the layer. pub fn merge_max(master: &mut Costmap, source: &Costmap, region: CellRegion) { - for y in region.min.y..region.max.y { - for x in region.min.x..region.max.x { - let cell = UVec2::new(x, y); - let Some(&src_cost) = source.get(cell) else { - continue; - }; - if src_cost == COST_UNKNOWN { - continue; - } - let old = master.get(cell).copied().unwrap_or(COST_UNKNOWN); - if old == COST_UNKNOWN || old < src_cost { - let _ = master.set(cell, src_cost); - } - } - } + project_into(master, source, region, MergePolicy::Max, |cost| { + (*cost != COST_UNKNOWN).then_some(*cost) + }); } /// Like [`merge_max`] but does not overwrite master cells that are unknown. pub fn merge_max_keep_unknown(master: &mut Costmap, source: &Costmap, region: CellRegion) { - for y in region.min.y..region.max.y { - for x in region.min.x..region.max.x { - let cell = UVec2::new(x, y); - let Some(&src_cost) = source.get(cell) else { - continue; - }; - if src_cost == COST_UNKNOWN { - continue; - } - let old = master.get(cell).copied().unwrap_or(COST_UNKNOWN); - if old != COST_UNKNOWN && old < src_cost { - let _ = master.set(cell, src_cost); - } - } - } + project_into( + master, + source, + region, + MergePolicy::MaxKeepUnknown, + |cost| (*cost != COST_UNKNOWN).then_some(*cost), + ); } diff --git a/src/costmap/mod.rs b/src/costmap/mod.rs index ab2448e..e86e581 100644 --- a/src/costmap/mod.rs +++ b/src/costmap/mod.rs @@ -1,8 +1,10 @@ pub mod layered; pub mod merge; +pub mod project; -pub use layered::{Layer, LayeredCostmap}; +pub use layered::{Layer, LayerId, LayeredCostmap}; pub use merge::{merge_max, merge_max_keep_unknown, merge_overwrite}; +pub use project::{MergePolicy, project_into}; use crate::Grid2d; diff --git a/src/costmap/project.rs b/src/costmap/project.rs new file mode 100644 index 0000000..24e4f9b --- /dev/null +++ b/src/costmap/project.rs @@ -0,0 +1,238 @@ +//! Generic projection of semantic grids into the costmap. +//! +//! Allows any `Grid2d` with a projection closure `T → Option` to be merged +//! into the master costmap using a specified merge policy. + +use glam::UVec2; + +use crate::types::{COST_UNKNOWN, CellRegion}; + +use super::Costmap; + +/// Merge policy for projecting a source grid into the master costmap. +/// +/// These policies mirror the behavior of Nav2's cost-combining layers, +/// respecting the special semantics of COST_UNKNOWN as "no information." +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MergePolicy { + /// Overwrite: set master to the projected cost (do not copy COST_UNKNOWN). + Overwrite, + /// Max: keep the maximum of master and source cost; never writes COST_UNKNOWN. + Max, + /// MaxKeepUnknown: like Max, but does not overwrite unknown cells in master. + MaxKeepUnknown, +} + +/// Project a source grid into the master costmap within a region. +/// +/// For each cell in `region`, applies the projection function `f` to the source value. +/// If `f` returns `Some(cost)`, the cost is merged into the master according to the policy. +/// If `f` returns `None`, the master cell is left untouched. +/// +/// # Panics +/// In debug builds, panics if the master and source grids do not have matching layouts +/// (width, height, resolution, origin). +pub fn project_into( + master: &mut Costmap, + source: &crate::Grid2d, + region: CellRegion, + policy: MergePolicy, + f: impl Fn(&T) -> Option, +) { + debug_assert!( + master.layout_matches(source), + "master and source grids must have matching layouts" + ); + + for y in region.min.y..region.max.y { + for x in region.min.x..region.max.x { + let cell = UVec2::new(x, y); + + // Read from source and apply projection. + let Some(src_val) = source.get(cell) else { + continue; + }; + let Some(src_cost) = f(src_val) else { + continue; + }; + + // Read current master cell (defaults to COST_UNKNOWN if not set). + let old = master.get(cell).copied().unwrap_or(COST_UNKNOWN); + + // Apply the policy. + let new_cost = match policy { + MergePolicy::Overwrite => Some(src_cost), + MergePolicy::Max => { + if old == COST_UNKNOWN || old < src_cost { + Some(src_cost) + } else { + None + } + } + MergePolicy::MaxKeepUnknown => { + if old != COST_UNKNOWN && old < src_cost { + Some(src_cost) + } else { + None + } + } + }; + + if let Some(cost) = new_cost { + let _ = master.set(cell, cost); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Grid2d; + use crate::types::{COST_FREE, MapInfo}; + + fn default_info() -> MapInfo { + MapInfo { + width: 10, + height: 10, + resolution: 0.1, + ..Default::default() + } + } + + fn default_region() -> CellRegion { + CellRegion { + min: UVec2::new(0, 0), + max: UVec2::new(10, 10), + } + } + + #[test] + fn test_project_into_none_leaves_master_untouched() { + let mut master = Costmap::new_with_value(default_info(), COST_FREE); + let source = Grid2d::::new_with_value(default_info(), 200); + + // Set a specific value to verify it's not changed + master.set(UVec2::new(5, 5), 100).unwrap(); + + // Projection that always returns None + project_into( + &mut master, + &source, + default_region(), + MergePolicy::Overwrite, + |_| None, + ); + + assert_eq!(master.get(UVec2::new(5, 5)), Some(&100)); + assert_eq!(master.get(UVec2::new(0, 0)), Some(&COST_FREE)); + } + + #[test] + fn test_project_into_overwrite() { + let mut master = Costmap::new_with_value(default_info(), COST_FREE); + let source = Grid2d::::new_with_value(default_info(), 100); + + let region = CellRegion { + min: UVec2::new(2, 2), + max: UVec2::new(5, 5), + }; + + project_into(&mut master, &source, region, MergePolicy::Overwrite, |v| { + Some(*v) + }); + + // Inside region + assert_eq!(master.get(UVec2::new(2, 2)), Some(&100)); + assert_eq!(master.get(UVec2::new(4, 4)), Some(&100)); + // Outside region + assert_eq!(master.get(UVec2::new(0, 0)), Some(&COST_FREE)); + assert_eq!(master.get(UVec2::new(9, 9)), Some(&COST_FREE)); + } + + #[test] + fn test_project_into_max() { + let mut master = Costmap::new_with_value(default_info(), COST_FREE); + master.set(UVec2::new(3, 3), 50).unwrap(); + master.set(UVec2::new(4, 4), 150).unwrap(); + + let source = Grid2d::::new_with_value(default_info(), 100); + + let region = CellRegion { + min: UVec2::new(2, 2), + max: UVec2::new(5, 5), + }; + + project_into(&mut master, &source, region, MergePolicy::Max, |v| Some(*v)); + + // 50 < 100 → update to 100 + assert_eq!(master.get(UVec2::new(3, 3)), Some(&100)); + // 150 > 100 → keep 150 + assert_eq!(master.get(UVec2::new(4, 4)), Some(&150)); + // COST_FREE (0) < 100 → update to 100 + assert_eq!(master.get(UVec2::new(2, 2)), Some(&100)); + } + + #[test] + fn test_project_into_max_keep_unknown() { + let mut master = Costmap::new_with_value(default_info(), COST_UNKNOWN); + master.set(UVec2::new(3, 3), 50).unwrap(); + master.set(UVec2::new(4, 4), 150).unwrap(); + + let source = Grid2d::::new_with_value(default_info(), 100); + + let region = CellRegion { + min: UVec2::new(2, 2), + max: UVec2::new(5, 5), + }; + + project_into( + &mut master, + &source, + region, + MergePolicy::MaxKeepUnknown, + |v| Some(*v), + ); + + // 50 < 100 → update to 100 + assert_eq!(master.get(UVec2::new(3, 3)), Some(&100)); + // 150 > 100 → keep 150 + assert_eq!(master.get(UVec2::new(4, 4)), Some(&150)); + // COST_UNKNOWN → keep COST_UNKNOWN (do not update) + assert_eq!(master.get(UVec2::new(2, 2)), Some(&COST_UNKNOWN)); + } + + #[test] + fn test_project_into_generic_type() { + let mut master = Costmap::new_with_value(default_info(), COST_FREE); + let source = Grid2d::::new_with_value(default_info(), 0.5); + + let region = default_region(); + + // Project f32 to u8 cost + project_into(&mut master, &source, region, MergePolicy::Overwrite, |v| { + Some((*v * 254.0).round() as u8) + }); + + assert_eq!(master.get(UVec2::new(5, 5)), Some(&127)); // 0.5 * 254 ≈ 127 + } + + #[test] + fn test_project_into_respects_unknown_in_source() { + let mut master = Costmap::new_with_value(default_info(), COST_FREE); + let mut source = Grid2d::::new_with_value(default_info(), 100); + source.set(UVec2::new(5, 5), COST_UNKNOWN).unwrap(); + + let region = default_region(); + + // Projection that treats COST_UNKNOWN as None + project_into(&mut master, &source, region, MergePolicy::Overwrite, |v| { + if *v == COST_UNKNOWN { None } else { Some(*v) } + }); + + // COST_UNKNOWN cell should not be updated + assert_eq!(master.get(UVec2::new(5, 5)), Some(&COST_FREE)); + // Other cells should be updated + assert_eq!(master.get(UVec2::new(3, 3)), Some(&100)); + } +} diff --git a/src/grid/grid2d.rs b/src/grid/grid2d.rs index 58c0eac..b98f5ca 100644 --- a/src/grid/grid2d.rs +++ b/src/grid/grid2d.rs @@ -109,6 +109,13 @@ impl Grid2d { self.info.height } + /// Check if this grid has the same layout (width, height, resolution, origin) as another grid. + /// + /// This is useful for validating that two grids can be safely merged or compared. + pub fn layout_matches(&self, other: &Grid2d) -> bool { + self.info() == other.info() + } + /// Returns a reference to the cell at `pos` without bounds checking. /// /// # Safety @@ -501,4 +508,86 @@ mod tests { assert_eq!(grid.get(UVec2::new(0, 1)), Some(&10)); assert_eq!(grid.get(UVec2::new(1, 1)), Some(&11)); } + + #[test] + fn test_layout_matches_equal() { + let grid_u8 = Grid2d::::init( + MapInfo { + width: 10, + height: 20, + resolution: 0.5, + origin: Vec2::new(1.0, 2.0), + }, + vec![0; 200], + ) + .unwrap(); + + let grid_f32 = Grid2d::::init( + MapInfo { + width: 10, + height: 20, + resolution: 0.5, + origin: Vec2::new(1.0, 2.0), + }, + vec![0.0; 200], + ) + .unwrap(); + + assert!(grid_u8.layout_matches(&grid_f32)); + assert!(grid_f32.layout_matches(&grid_u8)); + } + + #[test] + fn test_layout_matches_different_width() { + let grid_u8 = Grid2d::::init( + MapInfo { + width: 10, + height: 20, + resolution: 0.5, + origin: Vec2::ZERO, + }, + vec![0; 200], + ) + .unwrap(); + + let grid_f32 = Grid2d::::init( + MapInfo { + width: 15, + height: 20, + resolution: 0.5, + origin: Vec2::ZERO, + }, + vec![0.0; 300], + ) + .unwrap(); + + assert!(!grid_u8.layout_matches(&grid_f32)); + } + + #[test] + fn test_layout_matches_different_resolution() { + let grid_u8 = Grid2d::::init( + MapInfo { + width: 10, + height: 20, + resolution: 0.5, + origin: Vec2::ZERO, + }, + vec![0; 200], + ) + .unwrap(); + + let grid_f32 = Grid2d::::init( + MapInfo { + width: 10, + height: 20, + resolution: 0.1, + origin: Vec2::ZERO, + }, + vec![0.0; 200], + ) + .unwrap(); + + assert!(!grid_u8.layout_matches(&grid_f32)); + } } diff --git a/src/layers/mod.rs b/src/layers/mod.rs index 7d970dc..c45e403 100644 --- a/src/layers/mod.rs +++ b/src/layers/mod.rs @@ -1,3 +1,5 @@ pub mod inflation; +pub mod projection; pub use inflation::{InflationConfig, WavefrontInflationLayer}; +pub use projection::ProjectionLayer; diff --git a/src/layers/projection.rs b/src/layers/projection.rs new file mode 100644 index 0000000..3106058 --- /dev/null +++ b/src/layers/projection.rs @@ -0,0 +1,371 @@ +//! A layer that owns a generic semantic grid and projects it into the master. +//! +//! `ProjectionLayer` is the first-class way to integrate a custom `Grid2d` +//! into a [`LayeredCostmap`](crate::LayeredCostmap) for either global or rolling-window +//! costmaps. The caller can ingest data via [`Self::source_mut`], read it back via +//! [`Self::source`], and the layer handles both the geometric transformation (rolling +//! window centering) and cost merging. +//! +//! This layer pattern is inspired by Nav2's `CostmapLayer`, but generalized to +//! support any cell type `T` with a semantic → cost projection closure. + +use glam::Vec2; + +use crate::{ + Grid2d, + costmap::{Costmap, Layer, MergePolicy, project_into}, + types::{Bounds, CellRegion, Pose2}, +}; + +/// A layer that owns a semantic grid and projects it into the master costmap. +/// +/// Supports both global (fixed-frame) and rolling-window (robot-centered) operation. +/// When `rolling_window` is set, the source grid is re-centred on the robot during +/// [`update_bounds`](Layer::update_bounds) — before any layer's `update_costs` runs — +/// so producers that write into the source during `update_costs` see a centred grid. +#[allow(clippy::type_complexity)] +pub struct ProjectionLayer { + source: Grid2d, + policy: MergePolicy, + project: Box Option + Send + Sync>, + rolling_window: bool, + clearable: bool, +} + +impl ProjectionLayer { + /// Create a new projection layer from a grid and projection closure. + /// + /// # Arguments + /// + /// * `source` — the semantic grid owned by this layer. + /// * `policy` — the merge policy (Overwrite, Max, MaxKeepUnknown). + /// * `rolling_window` — if true, the source is centered on the robot each update. + /// * `project` — closure that converts `&T` → `Option` cost. `None` leaves the + /// master cell untouched; `Some(cost)` applies the policy. + pub fn from_grid( + source: Grid2d, + policy: MergePolicy, + rolling_window: bool, + project: impl Fn(&T) -> Option + Send + Sync + 'static, + ) -> Self { + Self { + source, + policy, + project: Box::new(project), + rolling_window, + clearable: true, + } + } + + /// Immutable reference to the source grid (query path). + pub fn source(&self) -> &Grid2d { + &self.source + } + + /// Mutable reference to the source grid (ingestion path). + /// + /// Use this to write semantic data into the grid between updates. Typically + /// called after `layer_mut::>(id)` and before `update_map`. + pub fn source_mut(&mut self) -> &mut Grid2d { + &mut self.source + } +} + +impl Layer for ProjectionLayer { + fn reset(&mut self) { + self.source.clear(); + } + + fn is_clearable(&self) -> bool { + self.clearable + } + + fn update_bounds(&mut self, robot: Pose2, bounds: &mut Bounds) { + // For a rolling window, re-centre the source on the robot here, before any + // layer's update_costs runs. This lets producers that write into the source + // during update_costs (e.g. the lidar example) stamp into an already-centred grid. + if self.rolling_window { + self.source.update_center(robot.position); + } + + // Expand bounds to the source grid's (now current) world extent. + let info = self.source.info(); + bounds.expand_to_include(info.origin); + bounds.expand_to_include(info.origin + Vec2::new(info.world_width(), info.world_height())); + } + + fn update_costs(&mut self, master: &mut Costmap, region: CellRegion) { + project_into(master, &self.source, region, self.policy, |cell| { + (self.project)(cell) + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{COST_FREE, COST_LETHAL, MapInfo}; + use crate::{InflationConfig, LayeredCostmap, WavefrontInflationLayer}; + + fn default_info() -> MapInfo { + MapInfo::square(5, 1.0) + } + + #[test] + fn test_projection_layer_via_layered_costmap() { + // Create a layered costmap with a projection layer that projects f32 to u8. + let mut layered = LayeredCostmap::new(default_info(), COST_FREE, false); + + let source = Grid2d::::new_with_value(default_info(), 0.0); + let id = layered.add_layer(Box::new(ProjectionLayer::from_grid( + source, + MergePolicy::Overwrite, + false, + |v| { + // Map 0.0 → COST_FREE, 1.0 → COST_LETHAL + Some((v * 254.0).round() as u8) + }, + ))); + + // Write a value into the source + { + let proj_layer = layered + .layer_mut::>(id) + .expect("layer should exist and be the right type"); + proj_layer + .source_mut() + .set(glam::UVec2::new(2, 2), 1.0) + .unwrap(); + } + + // Update the master + layered.update_map(Pose2::default()); + + // Check that the master reflects the projected cost + assert_eq!( + layered.master().get(glam::UVec2::new(2, 2)), + Some(&COST_LETHAL) + ); + // Other cells should remain free + assert_eq!( + layered.master().get(glam::UVec2::new(0, 0)), + Some(&COST_FREE) + ); + } + + #[test] + fn test_projection_layer_wrong_type_returns_none() { + let mut layered = LayeredCostmap::new(default_info(), COST_FREE, false); + + let source = Grid2d::::new_with_value(default_info(), 0.0); + let id = layered.add_layer(Box::new(ProjectionLayer::from_grid( + source, + MergePolicy::Overwrite, + false, + |v| Some((*v * 254.0) as u8), + ))); + + // Try to fetch as the wrong type + let wrong_type = layered.layer::(id); + assert!(wrong_type.is_none()); + + // Correct type should work + let correct_type = layered.layer::>(id); + assert!(correct_type.is_some()); + } + + #[test] + fn test_projection_layer_with_inflation() { + // Test that projection layer can be stacked with inflation. + let mut layered = LayeredCostmap::new(default_info(), COST_FREE, false); + + let source = Grid2d::::new_with_value(default_info(), 0.0); + let proj_id = layered.add_layer(Box::new(ProjectionLayer::from_grid( + source, + MergePolicy::Overwrite, + false, + |v| Some((*v * 254.0) as u8), + ))); + + // Add inflation after projection. Radius 1.5 m (> 1 cell) so the cell adjacent + // to a lethal seed gets an inflated halo. + layered.add_layer(Box::new(WavefrontInflationLayer::new(InflationConfig { + inflation_radius_m: 1.5, + inscribed_radius_m: 0.0, + cost_scaling_factor: 1.0, + ..Default::default() + }))); + + // Project a lethal cell (1.0 * 254 == COST_LETHAL) so it seeds inflation. + { + let proj_layer = layered.layer_mut::>(proj_id).unwrap(); + proj_layer + .source_mut() + .set(glam::UVec2::new(2, 2), 1.0) + .unwrap(); + } + + layered.update_map(Pose2::default()); + + // The projected lethal cost is present... + assert_eq!( + layered.master().get(glam::UVec2::new(2, 2)), + Some(&COST_LETHAL) + ); + // ...and it seeds inflation: the adjacent cell gets a halo in (FREE, LETHAL). + let neighbour = layered + .master() + .get(glam::UVec2::new(3, 2)) + .copied() + .unwrap(); + assert!( + neighbour > COST_FREE && neighbour < COST_LETHAL, + "expected inflated halo at (3,2), got {neighbour}" + ); + } + + #[test] + fn test_projection_layer_reset() { + let mut layered = LayeredCostmap::new(default_info(), COST_FREE, false); + + let source = Grid2d::::new_with_value(default_info(), 0.0); + let id = layered.add_layer(Box::new(ProjectionLayer::from_grid( + source, + MergePolicy::Overwrite, + false, + |v| Some((*v * 254.0) as u8), + ))); + + // Write data + { + let proj_layer = layered.layer_mut::>(id).unwrap(); + proj_layer + .source_mut() + .set(glam::UVec2::new(1, 1), 0.5) + .unwrap(); + } + + // Reset via layer interface + { + let proj_layer = layered.layer_mut::>(id).unwrap(); + proj_layer.reset(); + } + + // Verify data is cleared to fill_value + { + let proj_layer = layered.layer::>(id).unwrap(); + assert_eq!(proj_layer.source().get(glam::UVec2::new(1, 1)), Some(&0.0)); + } + } + + #[test] + fn test_projection_layer_is_clearable() { + let layer = ProjectionLayer::from_grid( + Grid2d::::new_with_value(default_info(), 0.0), + MergePolicy::Overwrite, + false, + |_| Some(100), + ); + assert!(layer.is_clearable()); + } + + #[test] + fn update_bounds_rolling_recenters_and_covers_source_extent() { + // 5x5 @ 1.0 m → a 5 m square, half-extent 2.5 m. + let mut layer = ProjectionLayer::from_grid( + Grid2d::::new_with_value(default_info(), 0.0), + MergePolicy::Overwrite, + true, + |_| Some(0), + ); + + let mut bounds = Bounds::empty(); + let robot = Pose2 { + position: Vec2::new(5.0, 5.0), + yaw: 0.0, + }; + layer.update_bounds(robot, &mut bounds); + + // Rolling: the source is re-centred on the robot during update_bounds. The + // origin snaps to whole cells (update_origin shifts by integer cell offsets), + // so centring on (5,5) with 1 m cells gives origin (2.0, 2.0), not (2.5, 2.5). + assert_eq!(layer.source().info().origin, Vec2::new(2.0, 2.0)); + // Bounds cover exactly the (re-centred) source world extent — i.e. where the + // layer can actually write, not some larger nominal range. + assert_eq!(bounds.min, Vec2::new(2.0, 2.0)); + assert_eq!(bounds.max, Vec2::new(7.0, 7.0)); + } + + #[test] + fn update_bounds_non_rolling_uses_fixed_extent() { + let mut layer = ProjectionLayer::from_grid( + Grid2d::::new_with_value(default_info(), 0.0), + MergePolicy::Overwrite, + false, + |_| Some(0), + ); + + let mut bounds = Bounds::empty(); + layer.update_bounds( + Pose2 { + position: Vec2::new(5.0, 5.0), + yaw: 0.0, + }, + &mut bounds, + ); + + // Not rolling: the source stays put and bounds are its fixed world extent. + assert_eq!(layer.source().info().origin, Vec2::ZERO); + assert_eq!(bounds.min, Vec2::ZERO); + assert_eq!(bounds.max, Vec2::new(5.0, 5.0)); + } + + #[test] + fn rolling_window_projects_to_correct_world_cell_as_window_moves() { + // The motivating path: a rolling master + rolling projection layer (same + // MapInfo) driven through update_map. A cost fixed in *world* space must land + // at the right master cell as the window moves, and master/source must stay + // layout-aligned (project_into's debug_assert fires here, in debug test builds). + let mut layered = LayeredCostmap::new(default_info(), COST_FREE, true); + let id = layered.add_layer(Box::new(ProjectionLayer::from_grid( + Grid2d::::new_with_value(default_info(), 0.0), + MergePolicy::Max, + true, + |t| (*t > 0.5).then_some(COST_LETHAL), + ))); + + // Robot centred so the window origin stays at (0,0). Mark world (3.5, 3.5). + layered + .layer_mut::>(id) + .unwrap() + .source_mut() + .set(glam::UVec2::new(3, 3), 1.0) + .unwrap(); + layered.update_map(Pose2 { + position: Vec2::new(2.5, 2.5), + yaw: 0.0, + }); + // World (3.5,3.5) → cell (3,3) with origin (0,0). + assert_eq!( + layered.master().get(glam::UVec2::new(3, 3)), + Some(&COST_LETHAL) + ); + + // Move the window +1 m in x → origin snaps to (1,0). No re-marking: the source + // persists data by world position across the shift. + layered.update_map(Pose2 { + position: Vec2::new(3.5, 2.5), + yaw: 0.0, + }); + // World (3.5,3.5) now maps to cell (2,3) with origin (1,0). + assert_eq!( + layered.master().get(glam::UVec2::new(2, 3)), + Some(&COST_LETHAL) + ); + // The old cell no longer holds the marked world location. + assert_eq!( + layered.master().get(glam::UVec2::new(3, 3)), + Some(&COST_FREE) + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 20e92b9..4a2869d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,8 +9,8 @@ pub mod types; #[cfg(feature = "rerun")] pub mod rerun_viz; -pub use costmap::{Costmap, Layer, LayeredCostmap}; +pub use costmap::{Costmap, Layer, LayerId, LayeredCostmap, MergePolicy, project_into}; pub use grid::Grid2d; -pub use layers::{InflationConfig, WavefrontInflationLayer}; +pub use layers::{InflationConfig, ProjectionLayer, WavefrontInflationLayer}; pub use loaders::ros2::{RosMapLoader, RosMapMetadata}; -pub use types::{MapInfo, OccupancyGrid, VoxelError}; +pub use types::{MapInfo, OccupancyGrid, VoxelError, cost_from_range, cost_from_unit}; diff --git a/src/types/cost.rs b/src/types/cost.rs new file mode 100644 index 0000000..cff41a2 --- /dev/null +++ b/src/types/cost.rs @@ -0,0 +1,77 @@ +//! Cost scale conversion utilities for mapping physical/semantic values to u8 costs. + +use super::constants::{COST_FREE, COST_LETHAL}; + +/// Convert a value in [0, 1] to a costmap cost in [COST_FREE, COST_LETHAL]. +/// +/// Values outside [0, 1] are clamped: values ≤ 0 → COST_FREE, values ≥ 1 → COST_LETHAL. +/// Values in between are linearly interpolated and rounded to the nearest u8. +pub fn cost_from_unit(x: f32) -> u8 { + let clamped = x.clamp(0.0, 1.0); + let cost_range = (COST_LETHAL as f32) - (COST_FREE as f32); + let scaled = (COST_FREE as f32) + clamped * cost_range; + scaled.round() as u8 +} + +/// Convert a value in range [lo, hi] to a costmap cost in [COST_FREE, COST_LETHAL]. +/// +/// The value `x` is normalized into [0, 1] based on the range [lo, hi], +/// then converted to a cost using [`cost_from_unit`]. +/// If `hi <= lo`, returns COST_FREE (no valid range). +pub fn cost_from_range(x: f32, lo: f32, hi: f32) -> u8 { + if hi <= lo { + return COST_FREE; + } + let normalized = (x - lo) / (hi - lo); + cost_from_unit(normalized) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cost_from_unit_boundaries() { + assert_eq!(cost_from_unit(0.0), COST_FREE); + assert_eq!(cost_from_unit(1.0), COST_LETHAL); + } + + #[test] + fn test_cost_from_unit_clamping() { + // Out of range values should be clamped + assert_eq!(cost_from_unit(-1.0), COST_FREE); + assert_eq!(cost_from_unit(-100.0), COST_FREE); + assert_eq!(cost_from_unit(2.0), COST_LETHAL); + assert_eq!(cost_from_unit(100.0), COST_LETHAL); + } + + #[test] + fn test_cost_from_unit_midpoint() { + let mid = cost_from_unit(0.5); + // Should be approximately in the middle of the range + let expected = ((COST_FREE as f32 + COST_LETHAL as f32) / 2.0).round() as u8; + // Allow ±1 for rounding differences + assert!((mid as i16 - expected as i16).abs() <= 1); + } + + #[test] + fn test_cost_from_range_valid() { + // Range [0, 10] with value 5 should map to 0.5 unit → ~127 + let cost = cost_from_range(5.0, 0.0, 10.0); + let expected_unit = cost_from_unit(0.5); + assert_eq!(cost, expected_unit); + } + + #[test] + fn test_cost_from_range_invalid() { + // hi <= lo should return COST_FREE + assert_eq!(cost_from_range(5.0, 10.0, 10.0), COST_FREE); + assert_eq!(cost_from_range(5.0, 10.0, 5.0), COST_FREE); + } + + #[test] + fn test_cost_from_range_bounds() { + assert_eq!(cost_from_range(0.0, 0.0, 10.0), COST_FREE); + assert_eq!(cost_from_range(10.0, 0.0, 10.0), COST_LETHAL); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index a281edd..2ae972e 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,10 +1,12 @@ pub mod constants; +pub mod cost; pub mod error; pub mod geometry; pub mod info; pub mod occupancy_grid; pub use constants::*; +pub use cost::{cost_from_range, cost_from_unit}; pub use error::VoxelError; pub use geometry::{Bounds, CellRegion, Footprint, Pose2}; pub use info::MapInfo;