marlotea/team-surg

 
 


Team-Surg: Surgical Team Action Recognition

A deep learning research project for surgical team action recognition using Human Mesh Recovery (HMR) and 3D pose estimation. This project analyzes 3D human pose data from surgical videos to classify surgical team member actions and understand team dynamics during surgical procedures.

Overview

Team-Surg processes 3D joint positions extracted from surgical videos using SMPL-X human mesh recovery to classify actions into four categories:

  • Tool usage (label 0)
  • Walking (label 1)
  • Observing (label 2)
  • Instrument handling (label 3)

The project explores multiple deep learning approaches including MLP-Mixer, 3D CNNs, and Graph Neural Networks (GNNs) to analyze temporal sequences of 3D joint positions.
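The four-class label scheme above can be written down as a small enumeration. This is a sketch, not code from the repository: the `ActionLabel` class, its member names, and the `label_name` helper are hypothetical, and only the integer values come from the list above.

```python
from enum import IntEnum


class ActionLabel(IntEnum):
    """Hypothetical enum mirroring the four documented class labels."""

    TOOL = 0        # tool usage
    WALK = 1        # walking
    OBSERVE = 2     # observing
    INSTRUMENT = 3  # instrument handling


def label_name(index: int) -> str:
    """Map a predicted class index to a lowercase, human-readable name."""
    return ActionLabel(index).name.lower()
```

A helper like this keeps the integer-to-name mapping in one place when printing predictions or building confusion-matrix axes.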

Key Features

  • Multiple Model Architectures:
    • MLP-Mixer for temporal sequence classification
    • 3D CNN (C3D) for video-like joint data processing
    • Graph Neural Networks (GCN/GAT) for spatial-temporal graph analysis
  • Comprehensive Empirical Metrics:
    • Group attention analysis
    • Collision detection
    • Team proximity and dispersion metrics
    • Tool engagement tracking
    • Attention switching events
  • Experimental Framework:
    • PyTorch Lightning training pipeline
    • Multi-GPU distributed training support
    • Weights & Biases integration for experiment tracking
    • Automated ablation study scripts

Project Structure

team-surg/
├── lightning/              # Production training pipeline
│   ├── main.py            # Training/testing entry point
│   ├── model.py           # PyTorch Lightning modules
│   ├── net.py             # MLP-Mixer architecture
│   ├── dataset.py         # Data loaders
│   ├── metrics.py         # Classification metrics
│   ├── joints.py          # Joint name mappings
│   └── util.py            # Utilities and callbacks
│
├── sandbox/               # Experimental GNN development
│   ├── main.py            # GNN training entry point
│   ├── GNNDataLoader.py   # PyG graph dataset
│   ├── net.py             # GNN + Mixer models
│   ├── model.py           # GNN Lightning modules
│   ├── dataset.py         # Advanced preprocessing
│   ├── commands.py        # Batch experiment runner
│   └── run_ablations.sh   # Ablation study script
│
├── empirical/             # Empirical metrics computation
│   ├── pipeline.py        # Main metrics orchestrator
│   ├── group_attn.py      # Group attention detection
│   ├── attn.py            # Individual attention tracking
│   ├── collide.py         # Collision detection
│   ├── group_prox.py      # Team proximity metrics
│   ├── group_dist.py      # Distance change analysis
│   └── tool.py            # Tool engagement detection
│
└── postprocess/           # Results analysis and visualization
    ├── post.py
    ├── post_attn.py
    └── post_group.py

Installation

Prerequisites

  • Python 3.12+
  • CUDA-capable GPU (recommended)

Setup

  1. Clone the repository:
git clone <repository-url>
cd team-surg
  2. Create and activate a virtual environment:
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate
  3. Install dependencies:
pip install -r requirements.txt

Key Dependencies

  • PyTorch 2.7.0
  • PyTorch Lightning 2.5.1
  • PyTorch Geometric (for GNN models)
  • NumPy 2.2.6
  • pandas 2.2.3
  • scikit-learn
  • wandb (for experiment tracking)
  • timm (vision models library)

Data Format

The project uses preprocessed 3D joint position data:

  • Input: 3D joint positions from SMPL-X model (127 joints)
  • Temporal sequences: Variable length (5-150 frames; at 30 fps, the longest 150-frame sequence spans 5 seconds)
  • Labels: 4-class action labels (0: tool, 1: walk, 2: observe, 3: instrument)
  • Train/Val/Test split: Video-based split (C videos for training, E videos for validation/testing)

Example preprocessed data files:

  • lightning/action_dataset_joints_leg_sampled_5.pkl
  • sandbox/action_dataset_joints_leg_sampled_150.pkl
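The video-based split and the frame-rate arithmetic described in this section can be sketched as follows. The record layout (a `(video_id, payload)` pair) and the assumption that video IDs begin with the letters "C" or "E" are hypothetical; the README states only that C videos are used for training and E videos for validation/testing, and it does not say how the E videos are divided between validation and testing, so they are kept in one `eval` group here.

```python
from collections import defaultdict

FPS = 30  # frame rate stated in the data format description


def clip_seconds(num_frames: int, fps: int = FPS) -> float:
    """Duration of a clip in seconds at the given frame rate."""
    return num_frames / fps


def split_by_video(samples, train_prefix="C", eval_prefix="E"):
    """Group samples by video-ID prefix so no video spans two splits.

    `samples` is an iterable of (video_id, payload) pairs. This layout is
    an assumption made for illustration, not the repository's format.
    """
    splits = defaultdict(list)
    for video_id, payload in samples:
        if video_id.startswith(train_prefix):
            splits["train"].append((video_id, payload))
        elif video_id.startswith(eval_prefix):
            splits["eval"].append((video_id, payload))
        else:
            raise ValueError(f"unexpected video id: {video_id!r}")
    return dict(splits)
```

Splitting on the video rather than on individual clips avoids leaking frames from the same recording into both training and evaluation data.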

Usage

Training Models

MLP-Mixer (Lightning module):

python lightning/main.py train \
  --data_path action_dataset_joints_leg_sampled_5.pkl \
  --batch_size 256 \
  --max_epochs 100 \
  --lr 0.001

GNN Models (Sandbox):

python sandbox/main.py train \
  --model_type gnn \
  --data_path action_dataset_joints_leg_sampled_150.pkl \
  --batch_size 64 \
  --max_epochs 100

Running Ablation Studies:

bash sandbox/run_ablations.sh

Or use the commands module:

python sandbox/commands.py run_experiments

Testing Models

python lightning/main.py test \
  --checkpoint_path path/to/checkpoint.ckpt \
  --data_path action_dataset_joints_leg_sampled_5.pkl

Running Empirical Analysis

python empirical/pipeline.py pipeline \
  --data_path path/to/joint_data.pkl \
  --output_dir results/empirical/

Model Architectures

1. MLP-Mixer

Processes flattened temporal sequences of joint positions using alternating token-mixing and channel-mixing layers. Suitable for fixed-length sequences.

Location: lightning/net.py
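To make the alternating token-mixing and channel-mixing steps concrete, here is a dependency-free sketch of a single Mixer block operating on a tokens x channels matrix stored as nested lists. It is an illustration only, not the code in lightning/net.py (which is presumably written in PyTorch): all function names are hypothetical, and biases and the learned layer-norm scale and shift are omitted for brevity.

```python
import math


def gelu(v):
    """Exact GELU activation."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def layer_norm(row, eps=1e-5):
    """Normalize one vector to zero mean and unit variance (no learned affine)."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return [(v - mean) / math.sqrt(var + eps) for v in row]


def linear(row, weight):
    """Multiply a length-n vector by an n x m weight matrix."""
    return [sum(r * w for r, w in zip(row, col)) for col in zip(*weight)]


def mlp(row, w1, w2):
    """Two-layer MLP with a GELU in between (biases omitted)."""
    return linear([gelu(v) for v in linear(row, w1)], w2)


def transpose(matrix):
    return [list(col) for col in zip(*matrix)]


def mixer_block(x, tok_w1, tok_w2, ch_w1, ch_w2):
    """Apply one Mixer block to x, a tokens x channels matrix.

    tok_w1: tokens x hidden, tok_w2: hidden x tokens
    ch_w1: channels x hidden, ch_w2: hidden x channels
    """
    normed = [layer_norm(row) for row in x]
    # Token mixing: the MLP runs along the token axis, once per channel.
    mixed = transpose([mlp(col, tok_w1, tok_w2) for col in transpose(normed)])
    y = [[a + b for a, b in zip(row, m)] for row, m in zip(x, mixed)]
    # Channel mixing: the MLP runs along the channel axis, once per token.
    return [
        [a + b for a, b in zip(row, mlp(layer_norm(row), ch_w1, ch_w2))]
        for row in y
    ]
```

Because the token-mixing weights have one row per token, the number of tokens is baked into the parameters, which is why the architecture suits fixed-length sequences.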

2. 3D CNN (C3D)

Treats temporal joint data as 3D volumes, applying 3D convolutions to capture spatial-temporal patterns.

Location: Referenced in lightning/model.py

3. Graph Neural Networks

Constructs spatial-temporal graphs where joints are nodes and edges represent anatomical connections or temporal relationships. Supports both GCN and GAT architectures.

Location: sandbox/net.py, sandbox/GNNDataLoader.py
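One plausible way to build such a spatial-temporal graph is sketched below. The node numbering (`frame * num_joints + joint`), the `build_st_edges` name, and the `skeleton` argument are assumptions made for illustration; the actual construction in sandbox/GNNDataLoader.py may differ.

```python
def build_st_edges(num_frames, num_joints, skeleton):
    """Return a directed edge list for a spatial-temporal joint graph.

    Node id = frame * num_joints + joint. `skeleton` is a list of
    (joint_a, joint_b) index pairs for anatomical connections
    (a hypothetical input, not the repository's joint table).
    Every connection is added in both directions.
    """
    edges = []
    for t in range(num_frames):
        offset = t * num_joints
        # Spatial edges: anatomical connections within a single frame.
        for a, b in skeleton:
            edges.append((offset + a, offset + b))
            edges.append((offset + b, offset + a))
        # Temporal edges: the same joint in consecutive frames.
        if t + 1 < num_frames:
            nxt = offset + num_joints
            for j in range(num_joints):
                edges.append((offset + j, nxt + j))
                edges.append((nxt + j, offset + j))
    return edges
```

Transposing the resulting pair list gives the two-row source/target layout that PyTorch Geometric expects for `edge_index`.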

Ablation Studies

The project supports extensive ablation studies across:

  • Sequence lengths: 5, 10, 25, 50, 75, 100, 125, 150 frames
  • Joint groups: pelvis, arm, head, thorax, spine, leg
  • Individual joints: wrists, elbows, eyes, head, ears
  • Representations: pose vectors vs. raw joint positions

Ablations are configured in sandbox/dataset.py and automated via sandbox/run_ablations.sh.
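The sequence-length and joint-group axes listed above form a grid that can be enumerated programmatically. The sketch below is not taken from sandbox/commands.py or run_ablations.sh; the `ablation_grid` function is hypothetical, and the filename pattern is only inferred from the two example data files named in the Data Format section.

```python
from itertools import product

SEQUENCE_LENGTHS = [5, 10, 25, 50, 75, 100, 125, 150]
JOINT_GROUPS = ["pelvis", "arm", "head", "thorax", "spine", "leg"]


def ablation_grid(lengths=SEQUENCE_LENGTHS, groups=JOINT_GROUPS):
    """Yield one config dict per (joint group, sequence length) pair."""
    for group, length in product(groups, lengths):
        yield {
            "joint_group": group,
            "seq_len": length,
            # Assumed naming pattern, inferred from the example filenames.
            "data_path": f"action_dataset_joints_{group}_sampled_{length}.pkl",
        }
```

With the values listed above this produces 6 x 8 = 48 configurations, each of which could be passed to a training command as its `--data_path`.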

Empirical Metrics

The empirical module computes surgical team dynamics metrics:

  • Group Attention: Detects when team members focus on the same point
  • Individual Attention: Tracks attention switching events per person
  • Collisions: Detects proximity/collision events between people
  • Group Proximity: Measures team dispersion and drift
  • Distance Changes: Computes pairwise distance variations
  • Tool Engagement: Detects tool engagement based on hand positions
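As a rough illustration of the proximity and collision metrics, the sketch below reduces each person to a single 3D point (for example a pelvis position) and computes team dispersion and close pairs for one frame. The function names, the one-point-per-person simplification, and the distance threshold are all assumptions; the definitions used in empirical/group_prox.py and empirical/collide.py may differ.

```python
from itertools import combinations
from math import dist


def centroid(points):
    """Mean position of a list of (x, y, z) points."""
    n = len(points)
    return tuple(sum(p[i] for p in points) / n for i in range(3))


def dispersion(points):
    """Mean distance from each person to the team centroid."""
    c = centroid(points)
    return sum(dist(p, c) for p in points) / len(points)


def close_pairs(points, threshold):
    """Index pairs (i, j), i < j, closer together than `threshold`.

    The threshold is a hypothetical tuning parameter in the same units
    as the joint coordinates.
    """
    return [
        (i, j)
        for (i, p), (j, q) in combinations(enumerate(points), 2)
        if dist(p, q) < threshold
    ]
```

Evaluating `dispersion` frame by frame and differencing the results would give a simple drift signal, while `close_pairs` flags candidate collision events for closer inspection.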

Training Infrastructure

  • Multi-GPU Support: Distributed Data Parallel (DDP) training
  • Mixed Precision: bf16 for faster training
  • Callbacks: Early stopping, model checkpointing, learning rate monitoring
  • Logging: Weights & Biases integration for experiment tracking

Contributing

This is a research project. For questions or contributions, please open an issue or pull request.

License

MIT License - See LICENSE file for details.

Citation

If you use this code in your research, please cite:

@software{team_surg,
  title={Team-Surg: Surgical Team Action Recognition},
  author={[Your Name/Team]},
  year={2026},
  url={[repository-url]}
}

Acknowledgments

  • SMPL-X for human mesh recovery
  • PyTorch Lightning for training framework
  • PyTorch Geometric for GNN implementations

About

Starter code for HMR-Surg projects
