A deep learning research project for surgical team action recognition using Human Mesh Recovery (HMR) and 3D pose estimation. This project analyzes 3D human pose data from surgical videos to classify surgical team member actions and understand team dynamics during surgical procedures.
Team-Surg processes 3D joint positions extracted from surgical videos using SMPL-X human mesh recovery to classify actions into four categories:
- Tool usage (label 0)
- Walking (label 1)
- Observing (label 2)
- Instrument handling (label 3)
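As a small illustration, the four labels above could be kept in one place as an enum so that predictions are decoded consistently. This is only a sketch: the names `ActionLabel` and `label_name` are hypothetical and are not taken from the repository.

```python
from enum import IntEnum


class ActionLabel(IntEnum):
    """Integer class labels as listed above (the enum itself is hypothetical)."""
    TOOL_USAGE = 0
    WALKING = 1
    OBSERVING = 2
    INSTRUMENT_HANDLING = 3


def label_name(index: int) -> str:
    """Map a predicted class index to a readable name."""
    return ActionLabel(index).name.lower().replace("_", " ")


print(label_name(3))  # instrument handling
```

Passing an index outside 0-3 raises `ValueError`, which makes a mislabeled sample fail loudly instead of silently.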
The project explores multiple deep learning approaches including MLP-Mixer, 3D CNNs, and Graph Neural Networks (GNNs) to analyze temporal sequences of 3D joint positions.
- Multiple Model Architectures:
  - MLP-Mixer for temporal sequence classification
  - 3D CNN (C3D) for video-like joint data processing
  - Graph Neural Networks (GCN/GAT) for spatial-temporal graph analysis
- Comprehensive Empirical Metrics:
  - Group attention analysis
  - Collision detection
  - Team proximity and dispersion metrics
  - Tool engagement tracking
  - Attention switching events
- Experimental Framework:
  - PyTorch Lightning training pipeline
  - Multi-GPU distributed training support
  - Weights & Biases integration for experiment tracking
  - Automated ablation study scripts
team-surg/
├── lightning/ # Production training pipeline
│ ├── main.py # Training/testing entry point
│ ├── model.py # PyTorch Lightning modules
│ ├── net.py # MLP-Mixer architecture
│ ├── dataset.py # Data loaders
│ ├── metrics.py # Classification metrics
│ ├── joints.py # Joint name mappings
│ └── util.py # Utilities and callbacks
│
├── sandbox/ # Experimental GNN development
│ ├── main.py # GNN training entry point
│ ├── GNNDataLoader.py # PyG graph dataset
│ ├── net.py # GNN + Mixer models
│ ├── model.py # GNN Lightning modules
│ ├── dataset.py # Advanced preprocessing
│ ├── commands.py # Batch experiment runner
│ └── run_ablations.sh # Ablation study script
│
├── empirical/ # Empirical metrics computation
│ ├── pipeline.py # Main metrics orchestrator
│ ├── group_attn.py # Group attention detection
│ ├── attn.py # Individual attention tracking
│ ├── collide.py # Collision detection
│ ├── group_prox.py # Team proximity metrics
│ ├── group_dist.py # Distance change analysis
│ └── tool.py # Tool engagement detection
│
└── postprocess/ # Results analysis and visualization
    ├── post.py
    ├── post_attn.py
    └── post_group.py
- Python 3.12+
- CUDA-capable GPU (recommended)
- Clone the repository:
    git clone <repository-url>
    cd team-surg
- Create and activate a virtual environment:
    python -m venv .venv
    source .venv/bin/activate # On Windows: .venv\Scripts\activate
- Install dependencies:
    pip install -r requirements.txt

Key dependencies:
- PyTorch 2.7.0
- PyTorch Lightning 2.5.1
- PyTorch Geometric (for GNN models)
- NumPy 2.2.6
- pandas 2.2.3
- scikit-learn
- wandb (for experiment tracking)
- timm (vision models library)
The project uses preprocessed 3D joint position data:
- Input: 3D joint positions from SMPL-X model (127 joints)
- Temporal sequences: Variable length (5 to 150 frames at 30 fps, i.e. about 0.17 to 5 seconds)
- Labels: 4-class action labels (0: tool, 1: walk, 2: observe, 3: instrument)
- Train/Val/Test split: Video-based split (C videos for training, E videos for validation/testing)
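The video-based split can be sketched as below. This is not the repository's loader: it assumes each sample is a dict with a hypothetical `video_id` field whose first letter is `C` or `E`. How the E videos are further divided between validation and testing is not stated here, so the sketch keeps them in a single held-out group.

```python
def split_by_video(samples):
    """Group samples by the first letter of their (hypothetical) video ID.

    Splitting by video rather than by clip keeps frames from the same
    recording out of both the training and the held-out sets.
    """
    splits = {"train": [], "heldout": []}
    for sample in samples:
        prefix = sample["video_id"][:1].upper()
        if prefix == "C":
            splits["train"].append(sample)
        elif prefix == "E":
            splits["heldout"].append(sample)
        else:
            raise ValueError(f"unexpected video id: {sample['video_id']!r}")
    return splits


samples = [
    {"video_id": "C01", "label": 0},
    {"video_id": "E03", "label": 2},
    {"video_id": "C02", "label": 1},
]
splits = split_by_video(samples)
print({name: len(items) for name, items in splits.items()})  # {'train': 2, 'heldout': 1}
```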
Example preprocessed data files:
- lightning/action_dataset_joints_leg_sampled_5.pkl
- sandbox/action_dataset_joints_leg_sampled_150.pkl
Train the MLP-Mixer pipeline:

    python lightning/main.py train \
        --data_path action_dataset_joints_leg_sampled_5.pkl \
        --batch_size 256 \
        --max_epochs 100 \
        --lr 0.001

Train a GNN model:

    python sandbox/main.py train \
        --model_type gnn \
        --data_path action_dataset_joints_leg_sampled_150.pkl \
        --batch_size 64 \
        --max_epochs 100

Run the ablation studies:

    bash sandbox/run_ablations.sh

Or use the commands module:

    python sandbox/commands.py run_experiments

Evaluate a trained checkpoint:

    python lightning/main.py test \
        --checkpoint_path path/to/checkpoint.ckpt \
        --data_path action_dataset_joints_leg_sampled_5.pkl

Compute the empirical metrics:

    python empirical/pipeline.py pipeline \
        --data_path path/to/joint_data.pkl \
        --output_dir results/empirical/

The MLP-Mixer processes flattened temporal sequences of joint positions using alternating token-mixing and channel-mixing layers. It is suited to fixed-length sequences.
Location: lightning/net.py
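To make the token-mixing and channel-mixing idea concrete, here is a minimal single-block sketch in plain Python. It is not the code in lightning/net.py. It assumes each frame is one token whose channels are the flattened joint coordinates, it omits biases and learnable normalization parameters, and all function names are hypothetical.

```python
import math
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def gelu(m):
    return [[0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in row] for row in m]


def layer_norm(m, eps=1e-5):
    """Normalize each row (one token) over its channels."""
    out = []
    for row in m:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        out.append([(v - mean) / math.sqrt(var + eps) for v in row])
    return out


def mixer_block(x, w_tok1, w_tok2, w_ch1, w_ch2):
    """x has shape tokens x channels; the output has the same shape."""
    # Token mixing: transpose so the MLP mixes information across frames.
    xt = transpose(layer_norm(x))                    # channels x tokens
    tok = matmul(gelu(matmul(xt, w_tok1)), w_tok2)   # channels x tokens
    x = add(x, transpose(tok))                       # residual connection
    # Channel mixing: the MLP mixes coordinates within each frame.
    ch = matmul(gelu(matmul(layer_norm(x), w_ch1)), w_ch2)
    return add(x, ch)                                # residual connection


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
tokens, channels, hidden = 5, 6, 8  # e.g. 5 frames, 2 joints x 3 coordinates
x = random_matrix(tokens, channels, rng)
out = mixer_block(
    x,
    random_matrix(tokens, hidden, rng), random_matrix(hidden, tokens, rng),
    random_matrix(channels, hidden, rng), random_matrix(hidden, channels, rng),
)
print(len(out), len(out[0]))  # 5 6
```

Because the token-mixing weights have a fixed `tokens` dimension, a block like this only accepts sequences of one length, which is why the architecture suits fixed-length inputs.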
The 3D CNN (C3D) treats temporal joint data as 3D volumes and applies 3D convolutions to capture spatial-temporal patterns.
Location: Referenced in lightning/model.py
The GNN models construct spatial-temporal graphs in which joints are nodes and edges represent anatomical connections or temporal relationships. Both GCN and GAT architectures are supported.
Location: sandbox/net.py, sandbox/GNNDataLoader.py
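One way to build such a graph is sketched below, assuming one node per joint per frame. The bone list, the node numbering, and the function name `build_edges` are illustrative assumptions, not the scheme used in sandbox/GNNDataLoader.py.

```python
def build_edges(num_frames, num_joints, bones):
    """Return directed edge pairs for a spatial-temporal skeleton graph.

    Node id = frame * num_joints + joint. Every connection is added in both
    directions so that message passing is symmetric.
    """
    edges = []
    for t in range(num_frames):
        offset = t * num_joints
        # Spatial edges: anatomical connections within one frame.
        for a, b in bones:
            edges.append((offset + a, offset + b))
            edges.append((offset + b, offset + a))
        # Temporal edges: the same joint in consecutive frames.
        if t + 1 < num_frames:
            for j in range(num_joints):
                u, v = offset + j, offset + num_joints + j
                edges.append((u, v))
                edges.append((v, u))
    return edges


bones = [(0, 1), (1, 2)]  # hypothetical three-joint chain
edges = build_edges(num_frames=2, num_joints=3, bones=bones)
print(len(edges))  # 14
```

Transposing this pair list gives the `[2, num_edges]` layout that PyTorch Geometric uses for `edge_index`.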
The project supports extensive ablation studies across:
- Sequence lengths: 5, 10, 25, 50, 75, 100, 125, 150 frames
- Joint groups: pelvis, arm, head, thorax, spine, leg
- Individual joints: wrists, elbows, eyes, head, ears
- Representations: pose vectors vs. raw joint positions
Configuration in sandbox/dataset.py and automated via sandbox/run_ablations.sh.
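As a rough sketch of how such a sweep can be enumerated, the snippet below takes the Cartesian product of sequence lengths, joint groups, and representations. The dictionary keys and the representation names `pose` and `joints` are assumptions for illustration. The actual configuration lives in sandbox/dataset.py and sandbox/run_ablations.sh.

```python
from itertools import product

SEQUENCE_LENGTHS = [5, 10, 25, 50, 75, 100, 125, 150]
JOINT_GROUPS = ["pelvis", "arm", "head", "thorax", "spine", "leg"]
REPRESENTATIONS = ["pose", "joints"]  # hypothetical names for the two representations


def ablation_grid():
    """Yield one configuration dict per combination of ablation settings."""
    for seq_len, group, rep in product(SEQUENCE_LENGTHS, JOINT_GROUPS, REPRESENTATIONS):
        yield {"seq_len": seq_len, "joint_group": group, "representation": rep}


configs = list(ablation_grid())
print(len(configs))  # 96
print(configs[0])    # {'seq_len': 5, 'joint_group': 'pelvis', 'representation': 'pose'}
```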
The empirical module computes surgical team dynamics metrics:
- Group Attention: Detects when team members focus on the same point
- Individual Attention: Tracks attention switching events per person
- Collisions: Detects proximity/collision events between people
- Group Proximity: Measures team dispersion and drift
- Distance Changes: Computes pairwise distance variations
- Tool Engagement: Detects tool engagement based on hand positions
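The proximity-based metrics can be illustrated with a small per-frame sketch. It assumes one 3D reference point per person (for example the pelvis joint), coordinates in meters, and a hypothetical collision threshold of 0.5. The function names are illustrative and do not mirror empirical/collide.py or empirical/group_prox.py.

```python
import math
from itertools import combinations


def pairwise_distances(positions):
    """positions maps person id -> (x, y, z) for a single frame."""
    return {
        (a, b): math.dist(positions[a], positions[b])
        for a, b in combinations(sorted(positions), 2)
    }


def dispersion(positions):
    """Mean pairwise distance, used here as a simple team-dispersion measure."""
    distances = pairwise_distances(positions)
    return sum(distances.values()) / len(distances) if distances else 0.0


def collisions(positions, threshold=0.5):
    """Pairs of people closer than the (hypothetical) threshold."""
    return [pair for pair, d in pairwise_distances(positions).items() if d < threshold]


frame = {"p1": (0.0, 0.0, 0.0), "p2": (0.3, 0.0, 0.0), "p3": (3.0, 4.0, 0.0)}
print(collisions(frame))  # [('p1', 'p2')]
```

Applying `dispersion` to every frame and differencing consecutive values would give a simple drift signal over time.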

The training pipeline provides:
- Multi-GPU Support: Distributed Data Parallel (DDP) training
- Mixed Precision: bf16 for faster training
- Callbacks: Early stopping, model checkpointing, learning rate monitoring
- Logging: Weights & Biases integration for experiment tracking
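A PyTorch Lightning `Trainer` configured along these lines might look like the sketch below. The monitored metric name `val_loss`, the patience value, and the W&B project name are assumptions. The project's own setup is in lightning/main.py and lightning/util.py and may differ. The imports sit inside the function so the settings dict can be inspected without Lightning installed.

```python
# Trainer settings matching the features listed above (values are illustrative).
TRAINER_KWARGS = {
    "accelerator": "gpu",
    "devices": "auto",          # use every visible GPU
    "strategy": "ddp",          # Distributed Data Parallel
    "precision": "bf16-mixed",  # bf16 mixed precision
    "max_epochs": 100,
}


def build_trainer(project="team-surg", monitor="val_loss", **overrides):
    """Build a Trainer with early stopping, checkpointing, LR monitoring and W&B logging."""
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import (
        EarlyStopping,
        LearningRateMonitor,
        ModelCheckpoint,
    )
    from pytorch_lightning.loggers import WandbLogger

    callbacks = [
        EarlyStopping(monitor=monitor, patience=10),
        ModelCheckpoint(monitor=monitor, save_top_k=1),
        LearningRateMonitor(logging_interval="epoch"),
    ]
    kwargs = {**TRAINER_KWARGS, **overrides}
    return pl.Trainer(callbacks=callbacks, logger=WandbLogger(project=project), **kwargs)
```

For a quick CPU run, something like `build_trainer(accelerator="cpu", strategy="auto", precision="32-true")` overrides the GPU-specific settings.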
This is a research project. For questions or contributions, please open an issue or pull request.
MIT License - See LICENSE file for details.
If you use this code in your research, please cite:
@software{team_surg,
  title={Team-Surg: Surgical Team Action Recognition},
  author={[Your Name/Team]},
  year={2026},
  url={[repository-url]}
}

Acknowledgments:
- SMPL-X for human mesh recovery
- PyTorch Lightning for training framework
- PyTorch Geometric for GNN implementations