A user-friendly Python package for 6-DoF grasp detection from RGB-D images. It is a standalone implementation of the Local/Patch stage of RNGNet and provides three inference modes: user-provided (e.g. randomly sampled) centers, heatmap prediction, and mouse-click interaction.
- Three Inference Modes:
  - `centers`: inference from user-provided 2D centers
  - `heatmap`: automatic center detection using 2D heatmap prediction
  - `click`: interactive mouse-click selection
- Multi-Camera Support: pre-trained models for RealSense and Kinect cameras
- Easy-to-Use API: a simple Python interface with minimal dependencies
- Fully Standalone: `RNGNet.py` contains the complete implementation and needs nothing from the reference code in `RegionNormalizedGrasp/`
```
RNGNet_compile/
├── RNGNet.py               # Main standalone implementation (~47KB)
├── RegionNormalizedGrasp/  # Original reference implementation
├── example.py              # Minimal usage examples
├── test.py                 # Test and benchmark script
├── requirements.txt        # Python dependencies
├── realsense.pth           # Pre-trained RealSense model (~40MB)
├── kinect.pth              # Pre-trained Kinect model (~40MB)
├── README.md               # This file
└── LICENSE                 # MIT License
```
```bash
cd RNGNet_compile
pip install -r requirements.txt
```

Requirements:

- Python 3.8+
- PyTorch 1.11+ with CUDA
- Open3D (for visualization)
- NumPy, SciPy, OpenCV, Pillow
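To install the dependencies manually (PyTorch 1.11 wheels built for CUDA 11.3):

```bash
pip install torch==1.11.0 torchvision==0.12.0 --index-url https://download.pytorch.org/whl/cu113
pip install open3d opencv-python pillow numpy scipy matplotlib
```

For faster inference, compile `RNGNet.py` into a C extension with Cython: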
```bash
pip install cython
python setup.py build_ext --inplace
```

This produces `RNGNet.cpython-*.so` in the current directory. When both files are present, Python prefers the `.so` over the `.py` file on import. To verify which module is loaded:
```bash
python -c "import RNGNet; print(RNGNet.__file__)"
# → .../RNGNet.cpython-38-x86_64-linux-gnu.so (compiled)
# → .../RNGNet.py (source)
```

Note: After editing `RNGNet.py`, remove the old `.so` and recompile; otherwise the stale compiled module is still imported and your changes have no effect:
```bash
rm RNGNet*.so
python setup.py build_ext --inplace
```

```python
import numpy as np
from PIL import Image
from RNGNet import RngNet

# Load data (convert("RGB") drops an alpha channel if the PNG has one)
rgb = np.array(Image.open("0_rgb.png").convert("RGB")) / 255.0  # (H, W, 3), [0, 1] IMPORTANT!
depth = np.array(Image.open("0_depth.png"))  # (H, W), in mm

# Initialize detector
detector = RngNet(camera="realsense")  # or "kinect"

# Randomly sample 100 centers as [u, v] pixel coordinates (u along W, v along H)
H, W = depth.shape
centers_2d = np.random.rand(100, 2) * [W, H]
centers_2d = centers_2d.astype(np.float32)

# Run inference
raw_gg, _ = detector.infer_from_rgbd_centers(rgb, depth, centers_2d, return_2d=True)

# Optional post-processing: score filter and top-k
# (pass nms_translation_thresh / nms_rotation_thresh to also enable NMS)
pred_gg = detector.postprocess(raw_gg, score_thresh=0.5, max_grasps=50)
print(f"Detected {len(pred_gg)} grasps")
if len(pred_gg) > 0:
    print(f"Best score: {pred_gg.scores.max():.3f}")
```

```python
# Initialize with use_anchornet=True (required for heatmap mode)
detector = RngNet(camera="realsense", use_anchornet=True)

# Run heatmap-based detection (reuses rgb and depth from above)
raw_gg = detector.infer_from_rgbd_heatmap(rgb, depth)
pred_gg = detector.postprocess(raw_gg, score_thresh=0.5, max_grasps=50)
print(f"Heatmap detected {len(pred_gg)} grasps")
if len(pred_gg) > 0:  # .max() fails on an empty result
    print(f"Best score: {pred_gg.scores.max():.3f}")
```

```python
detector = RngNet(camera="realsense")

# Open an interactive window and click on objects;
# a 3x3 grid of centers is created around each click
raw_gg, centers = detector.infer_from_rgbd_click(rgb, depth)
pred_gg = detector.postprocess(raw_gg, score_thresh=0.5, max_grasps=50)
```

The `test.py` script runs all three modes from the command line, with optional visualization, post-processing, and benchmarking:
```bash
# Basic test with 512 random centers
python test.py

# Custom number of centers
python test.py --num-centers 256

# With 2D visualization
python test.py --vis --save result.png

# With 3D visualization
python test.py --vis3d
```

```bash
# Heatmap-based detection
python test.py --mode heatmap

# With visualization
python test.py --mode heatmap --vis --save result.png
```

```bash
# Interactive mouse click
python test.py --mode click
```

```bash
# Run performance benchmark (heatmap mode, 100 runs, same output format as demo.py)
python test.py --benchmark --mode heatmap --num-runs 100

# Or benchmark centers mode
python test.py --benchmark --mode centers --num-runs 100

# Example output:
# avg time == 60.69 ms
```

```bash
python test.py --help
# Key options:
#   --rgb RGB                # RGB image path
#   --depth DEPTH            # Depth image path
#   --mode {centers,heatmap,click}
#   --num-centers N          # Number of centers (default: 512)
#   --vis                    # 2D visualization
#   --vis3d                  # 3D Open3D visualization
#   --save SAVE              # Save visualization path
#   --top-k TOP_K            # Show top K grasps
#   --benchmark              # Run performance test
#   --num-runs N             # Benchmark runs
# Post-processing options:
#   --score-thresh THRESH    # Score threshold (e.g., 0.5)
#   --nms                    # Enable NMS
#   --nms-trans METERS       # NMS translation threshold (default: 0.03)
#   --nms-rot DEGREES        # NMS rotation threshold (default: 30)
#   --max-grasps N           # Keep only top N grasps
#   --collision-detect       # Enable collision detection
```

```python
RngNet(
    checkpoint_path=None,   # Auto-select if None
    camera="realsense",     # "realsense" or "kinect"
    intrinsics=None,        # Custom 3x3 intrinsics matrix
    use_anchornet=False,    # Required for heatmap mode
    args=None,              # Custom arguments
)
```

All inference methods return raw predictions. Post-processing (score filtering, NMS, collision detection) is applied separately via `postprocess()`.
`infer_from_rgbd_centers(rgb, depth, centers_2d, return_2d=False)`: inference from RGB-D with user-provided 2D centers.

Parameters:

- `rgb`: (H, W, 3) numpy array, range [0, 1]
- `depth`: (H, W) numpy array, in mm
- `centers_2d`: (N, 2) numpy array; each row is [u, v] in input image space (u along the width, v along the height)
- `return_2d`: if True, also return the 2D centers used

Returns: `GraspGroup`, the raw predicted 6-DoF grasps. If `return_2d=True`, returns `(GraspGroup, centers_2d)`.
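Uniform random centers can land on depth holes (pixels with value 0), where a patch likely has no usable geometry. A minimal sketch, assuming you want to restrict sampling to pixels with a depth reading; `sample_valid_centers` is a hypothetical helper, not part of `RNGNet.py`, and it only prepares the `centers_2d` array that `infer_from_rgbd_centers` takes:

```python
import numpy as np


def sample_valid_centers(depth: np.ndarray, num: int, seed: int = 0) -> np.ndarray:
    """Hypothetical helper (not in RNGNet.py): sample up to `num` [u, v]
    centers on pixels that have a depth reading (depth > 0)."""
    rng = np.random.default_rng(seed)
    rows, cols = np.nonzero(depth > 0)  # v = row index, u = column index
    if rows.size == 0:
        return np.empty((0, 2), dtype=np.float32)
    idx = rng.choice(rows.size, size=min(num, rows.size), replace=False)
    # Same [u, v] layout and float32 dtype as centers_2d in the quick-start example
    return np.stack([cols[idx], rows[idx]], axis=1).astype(np.float32)


# Toy depth map: a 100x100 patch at 800 mm, zeros (holes) elsewhere
depth = np.zeros((360, 640), dtype=np.uint16)
depth[100:200, 300:400] = 800
centers_2d = sample_valid_centers(depth, 50)
print(centers_2d.shape)
```

Sampling without replacement avoids duplicate centers; with a real depth image you would pass the result straight to `detector.infer_from_rgbd_centers(rgb, depth, centers_2d)`.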
`infer_from_rgbd_heatmap(rgb, depth, return_2d=False)`: inference using 2D heatmap prediction for automatic center detection. Requires `use_anchornet=True`.

Parameters:

- `rgb`, `depth`: same as above
- `return_2d`: if True, also return the detected 2D centers

Returns: `GraspGroup`, the raw predicted 6-DoF grasps. If `return_2d=True`, returns `(GraspGroup, centers_2d)`.
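To give an idea of how a heatmap can be turned into grasp centers, here is an illustrative sketch that picks the strongest 3x3 local maxima. This is a generic peak-picking technique chosen for illustration; how `AnchorGraspNet` in `RNGNet.py` actually decodes its heatmap is not described here and may differ. `heatmap_peaks`, `k`, and `min_score` are hypothetical names:

```python
import numpy as np


def heatmap_peaks(heatmap, k: int, min_score: float = 0.0) -> np.ndarray:
    """Illustrative peak picking (not RNGNet's AnchorNet decoding): return up
    to k [u, v] centers at 3x3 local maxima whose value exceeds min_score."""
    hm = np.asarray(heatmap, dtype=np.float64)
    h, w = hm.shape
    # Pad with -inf so border pixels can still be local maxima
    padded = np.pad(hm, 1, mode="constant", constant_values=-np.inf)
    # Max over each pixel's 3x3 neighborhood, built from 9 shifted views
    neigh = np.max([padded[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)], axis=0)
    rows, cols = np.nonzero((hm >= neigh) & (hm > min_score))
    order = np.argsort(-hm[rows, cols])[:k]  # strongest peaks first
    return np.stack([cols[order], rows[order]], axis=1).astype(np.float32)


hm = np.zeros((36, 64))
hm[10, 20] = 0.9
hm[11, 21] = 0.5  # diagonal neighbor of the 0.9 peak, so it is suppressed
hm[25, 50] = 0.6
centers_2d = heatmap_peaks(hm, k=5, min_score=0.1)
print(centers_2d)
```

The local-maximum test keeps one center per blob instead of many adjacent pixels, and the output uses the same [u, v] layout as `centers_2d`.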
`infer_from_rgbd_click(rgb, depth)`: interactive inference with mouse clicks.

Parameters:

- `rgb`, `depth`: same as above

Returns: `(GraspGroup, centers_2d)`; the selected centers include a 3x3 grid around each click.
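The 3x3 grid around each click can be pictured with the following sketch. The grid spacing is not documented, so `spacing` is an assumed parameter, and dropping points that fall outside the image is a choice made for this sketch; `expand_clicks` is a hypothetical helper, not the code in `RNGNet.py`:

```python
import numpy as np


def expand_clicks(clicks, spacing: float, width: int, height: int) -> np.ndarray:
    """Hypothetical sketch: expand each clicked [u, v] into a 3x3 grid of centers.

    `spacing` (pixels between grid points) is an assumed parameter; the value
    RNGNet.py uses is not documented. Out-of-image points are dropped here.
    """
    offsets = np.array([(dx, dy) for dy in (-1, 0, 1) for dx in (-1, 0, 1)],
                       dtype=np.float32) * spacing               # (9, 2)
    pts = np.asarray(clicks, dtype=np.float32).reshape(-1, 1, 2)  # (K, 1, 2)
    grid = (pts + offsets).reshape(-1, 2)                         # (9K, 2), click-major order
    inside = (grid[:, 0] >= 0) & (grid[:, 0] < width) & (grid[:, 1] >= 0) & (grid[:, 1] < height)
    return grid[inside]


# One click mid-image, one near the top-left corner (partly outside the image)
centers_2d = expand_clicks([[320, 180], [2, 2]], spacing=10, width=640, height=360)
print(centers_2d.shape)
```

The first click yields all 9 grid points; the corner click keeps only the 4 that land inside the image, so 13 centers are returned.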
`postprocess(pred_gg, score_thresh=None, nms_translation_thresh=None, nms_rotation_thresh=None, max_grasps=None, scene_points=None)`

Applies optional score filtering, NMS, collision detection, and top-k truncation to raw predictions.

Parameters:

- `pred_gg`: `GraspGroup` from an inference method
- `score_thresh`: if provided, keep grasps with `score > score_thresh`
- `nms_translation_thresh`: NMS distance threshold in meters (e.g. `0.03`)
- `nms_rotation_thresh`: NMS angle threshold in radians (e.g. `30.0 / 180.0 * np.pi`)
- `max_grasps`: if provided, keep only the top `max_grasps` by score
- `scene_points`: optional (M, 3) or (M, 6) point cloud for collision detection

Returns: `GraspGroup`, the post-processed grasps.
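The two NMS thresholds interact as in generic greedy pose NMS: grasps are visited from highest to lowest score, and a grasp is dropped if it is close to an already-kept grasp in both translation and rotation. The sketch below shows that technique in plain NumPy; it is an assumption about the behavior, not `GraspGroup`'s exact code, and `grasp_nms` and `rot_z` are hypothetical names. The rotation distance is the geodesic angle between two rotation matrices:

```python
import numpy as np


def grasp_nms(translations, rotations, scores, trans_thresh=0.03, rot_thresh=np.deg2rad(30.0)):
    """Generic greedy pose NMS (a sketch, not GraspGroup's exact code).

    translations: (N, 3) in meters; rotations: (N, 3, 3); scores: (N,).
    A grasp is dropped if it lies within trans_thresh meters AND rot_thresh
    radians of a higher-scoring grasp that was already kept.
    """
    keep = []
    for i in np.argsort(-np.asarray(scores)):  # highest score first
        duplicate = False
        for j in keep:
            dist = np.linalg.norm(translations[i] - translations[j])
            # Geodesic angle between two rotations: arccos((trace(Ri^T Rj) - 1) / 2)
            cos_a = (np.trace(rotations[i].T @ rotations[j]) - 1.0) / 2.0
            angle = np.arccos(np.clip(cos_a, -1.0, 1.0))
            if dist < trans_thresh and angle < rot_thresh:
                duplicate = True
                break
        if not duplicate:
            keep.append(int(i))
    return np.array(keep, dtype=int)


def rot_z(deg):
    """Rotation about the z axis by `deg` degrees."""
    t = np.deg2rad(deg)
    return np.array([[np.cos(t), -np.sin(t), 0.0],
                     [np.sin(t), np.cos(t), 0.0],
                     [0.0, 0.0, 1.0]])


translations = np.array([[0.0, 0.0, 0.5], [0.01, 0.0, 0.5], [0.01, 0.0, 0.5], [0.2, 0.0, 0.5]])
rotations = np.stack([rot_z(0), rot_z(10), rot_z(90), rot_z(0)])
scores = np.array([0.9, 0.8, 0.7, 0.6])
kept = grasp_nms(translations, rotations, scores)
print(kept)
```

Here grasp 1 (1 cm and 10° from grasp 0) is suppressed, grasp 2 survives because it is rotated 90°, and grasp 3 survives because it is 20 cm away, so `kept` holds indices 0, 2 and 3.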
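```python
import numpy as np
import open3d as o3d

# pred_gg: GraspGroup returned by postprocess()
# scene_points: (M, 3) or (M, 6) scene point cloud (only xyz is used here)
# Get gripper geometries manually
grasp_geos = pred_gg.to_open3d_geometry_list(scale=1.0)

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(scene_points[:, :3].astype(np.float64))
o3d.visualization.draw_geometries([pcd] + grasp_geos)
```

Input images are automatically resized to 640×360 (the default network input resolution) before processing, matching the original implementation.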
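`scene_points` is not defined in the visualization snippet above, and `postprocess(scene_points=...)` needs the same kind of array. A minimal sketch that back-projects the depth image with a pinhole model, assuming the grasps and the cloud share the camera frame in meters (the NMS thresholds are in meters), OpenCV axes (x right, y down, z forward), and depth in millimeters as the inputs require. `depth_to_points` is a hypothetical helper, and the intrinsics must match the resolution of the depth image you pass in:

```python
import numpy as np


def depth_to_points(depth_mm, fx, fy, cx, cy, rgb=None):
    """Back-project a depth image in mm to camera-frame points in meters.

    Returns (M, 3), or (M, 6) with colors appended when `rgb` is (H, W, 3).
    Pixels with depth 0 (no reading) are skipped. Intrinsics must match the
    resolution of `depth_mm`.
    """
    h, w = depth_mm.shape
    u, v = np.meshgrid(np.arange(w), np.arange(h))  # u: column index, v: row index
    z = depth_mm.astype(np.float64) / 1000.0          # mm -> m
    valid = z > 0
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy
    points = np.stack([x[valid], y[valid], z[valid]], axis=1)
    if rgb is not None:
        points = np.concatenate([points, rgb[valid]], axis=1)
    return points


# Toy 1280x720 depth map with a single pixel at 1000 mm near the principal point
depth = np.zeros((720, 1280), dtype=np.uint16)
depth[349, 651] = 1000
scene_points = depth_to_points(depth, fx=927.17, fy=927.37, cx=651.32, cy=349.62)
colored = depth_to_points(depth, 927.17, 927.37, 651.32, 349.62, rgb=np.full((720, 1280, 3), 0.5))
print(scene_points)
```

Passing the `[0, 1]` RGB array gives the (M, 6) variant; only the first three columns are used for the Open3D point cloud.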
Run the benchmark to measure inference time on your hardware:

```bash
python test.py --benchmark --mode heatmap --num-runs 100
# avg time == XX ms
```

`RNGNet.py` is a fully self-contained implementation containing:
- Model architectures: `AnchorGraspNet`, `PatchMultiGraspNet`, `HGGDResNet`
- Evaluation functions: `detect_6d_grasp_multi`, `get_rotation`, `get_thetas_widths`
- Data structures: `Grasp`, `GraspGroup` (with NMS and visualization)
- Inference class: `RngNet` with three modes (centers/heatmap/click)

No imports from `RegionNormalizedGrasp/` are required.
- RGB Normalization: RGB must be in range [0, 1] (divide by 255)
- Depth Unit: Depth values must be in millimeters (mm)
- Image Dimensions: Uses transposed dimension system (W x H) consistent with original training
- Rotation Order: Theta (in-plane) → Gamma (tilt) → Beta (roll)
- Score Filtering: The typical setting is a 0.5 score threshold followed by translation-based NMS (0.03 m); `postprocess()` applies neither unless the corresponding arguments are passed
| Camera | Model File | fx (px) | fy (px) | cx (px) | cy (px) |
|---|---|---|---|---|---|
| RealSense | `realsense.pth` | 927.17 | 927.37 | 651.32 | 349.62 |
| Kinect | `kinect.pth` | 631.55 | 631.21 | 638.43 | 366.50 |

Default resolution: 1280x720, which the intrinsics above refer to (automatically scaled if your input resolution differs).
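The table's values can be assembled into the 3x3 matrix form that `RngNet(intrinsics=...)` accepts. The sketch below also rescales them to another resolution with the common proportional rule (focal lengths and principal point scale with image size, ignoring half-pixel offsets). That this is what "automatically scaled" means, and whether `intrinsics=` expects full-resolution values, are assumptions to check in `RNGNet.py`; `CAMERAS` and `intrinsics_matrix` are hypothetical names:

```python
import numpy as np

# (fx, fy, cx, cy) from the camera table, defined at 1280x720
CAMERAS = {
    "realsense": (927.17, 927.37, 651.32, 349.62),
    "kinect": (631.55, 631.21, 638.43, 366.50),
}


def intrinsics_matrix(camera, width=1280, height=720, base=(1280, 720)):
    """3x3 pinhole K for `camera`, scaled from the `base` resolution to width x height.

    Assumes plain proportional scaling (a resize without cropping).
    """
    fx, fy, cx, cy = CAMERAS[camera]
    sx, sy = width / base[0], height / base[1]
    return np.array([[fx * sx, 0.0, cx * sx],
                     [0.0, fy * sy, cy * sy],
                     [0.0, 0.0, 1.0]])


K_full = intrinsics_matrix("realsense")            # values as in the table
K_half = intrinsics_matrix("realsense", 640, 360)  # e.g. at the 640x360 network resolution
print(K_half)
```

Halving both image dimensions halves fx, fy, cx and cy, which is why a resize alone does not require recalibrating the camera.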
Reduce the number of centers:

```python
pred_gg = detector.infer_from_rgbd_centers(rgb, depth, centers_2d[:50])
```

Check:

- Depth values are in mm (not meters)
- RGB is in the [0, 1] range (divide by 255); this is especially important for heatmap mode
- Centers are within image bounds
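The checks above can be automated before calling the detector. A minimal sketch; `check_inputs` is a hypothetical helper, and the "median below 10 means meters" rule is a heuristic assumed for tabletop scenes, not an RNGNet rule:

```python
import numpy as np


def check_inputs(rgb, depth, centers_2d=None):
    """Hypothetical pre-flight check; returns a list of likely problems (empty if none).

    The thresholds are heuristics chosen for this sketch, not RNGNet rules.
    """
    problems = []
    if rgb.ndim != 3 or rgb.shape[2] != 3:
        problems.append(f"rgb should be (H, W, 3), got {rgb.shape}")
    if rgb.max() > 1.0:
        problems.append("rgb looks like 0-255 values; divide by 255")
    valid = depth[depth > 0]
    # Depth in mm for a tabletop scene is typically in the hundreds or thousands;
    # a median below 10 suggests the depth is in meters.
    if valid.size and np.median(valid) < 10:
        problems.append("depth looks like meters; multiply by 1000")
    if centers_2d is not None:
        h, w = depth.shape
        u, v = centers_2d[:, 0], centers_2d[:, 1]
        outside = (u < 0) | (u >= w) | (v < 0) | (v >= h)
        if outside.any():
            problems.append(f"{int(outside.sum())} center(s) outside the image bounds")
    return problems


rng = np.random.default_rng(0)
rgb = rng.random((360, 640, 3))     # already in [0, 1)
depth_m = np.full((360, 640), 0.8)  # mistake: depth in meters
centers_2d = np.array([[100.0, 50.0], [700.0, 50.0]], dtype=np.float32)  # u=700 >= W
issues = check_inputs(rgb, depth_m, centers_2d)
print(issues)
```

This example reports two problems (depth in meters, one out-of-bounds center); after `depth_m * 1000` and dropping the second center, the list is empty.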
Ensure the checkpoint exists:

```bash
ls *.pth  # Should show realsense.pth or kinect.pth
```

Or specify the full path:

```python
detector = RngNet(checkpoint_path="/path/to/model.pth")
```

```bibtex
@inproceedings{chen2024regionaware,
  title={Region-aware Grasp Framework with Normalized Grasp Space for Efficient 6-DoF Grasping},
  author={Siang Chen and Pengwei Xie and Wei Tang and Dingchang Hu and Yixiang Dai and Guijin Wang},
  booktitle={8th Annual Conference on Robot Learning},
  year={2024}
}
```

This project is licensed under the MIT License; see the LICENSE file for details.
- Original RNGNet paper: CoRL 2024
- Based on HGGD codebase
- Uses GraspNet-1Billion dataset