Add support for Apple Silicon #153
Pull request overview
This PR adds Apple Silicon (MPS) support by making inference entry points and model autocast paths more device-aware while preserving CUDA/CPU execution.
Changes:
- Adds CUDA/MPS/CPU device selection and per-device precision choices in examples and Gradio demo.
- Replaces hardcoded CUDA autocast usage in model and pipeline code.
- Adds MPS cache clearing and updates ignore patterns.
Reviewed changes
Copilot reviewed 8 out of 9 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| example.py | Adds MPS fallback setup, device auto-selection, and device-aware autocast dtype. |
| example_mm.py | Applies device-aware setup to multimodal inference. |
| example_vo.py | Applies device-aware setup to VO pipeline entry point. |
| demo_gradio.py | Allows Gradio inference/model loading on CUDA, MPS, or CPU. |
| pi3/models/pi3.py | Makes disabled autocast block use the tensor device type. |
| pi3/models/pi3x.py | Makes multimodal/head disabled autocast blocks use tensor device types. |
| pi3/models/layers/camera_head.py | Removes hardcoded CUDA device type in camera head autocast block. |
| pi3/pipe/pi3x_vo.py | Uses device-aware autocast and CUDA/MPS cache clearing. |
| .gitignore | Updates ignored generated/local files. |
    import torch
    import argparse
    import os

    # Set MPS fallback before importing other modules that might use torch
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

    model_kwargs['with_prior'] = True

    with torch.amp.autocast('cuda', dtype=dtype):
    with torch.amp.autocast(chunk_imgs.device.type, dtype=dtype, enabled=chunk_imgs.device.type != 'cpu'):
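
In the fallback snippet quoted in this review, `PYTORCH_ENABLE_MPS_FALLBACK` is assigned after `import torch`, although its comment says it should be set before torch is used. A minimal sketch of the more conservative ordering follows. It assumes that PyTorch reads the variable when it is first loaded, so assigning it later may have no effect. The `try`/`except` guard is not part of the PR and is only there so the sketch runs where PyTorch is not installed.

```python
import os

# Assumption: PyTorch reads this flag when it is first loaded,
# so it is assigned before the first `import torch`.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

try:
    import torch  # imported only after the environment is prepared
except ImportError:
    # Guard added for this sketch only; the PR's scripts import torch directly.
    torch = None
```

Exporting the variable in the shell before launching the script would have the same effect without depending on import order.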
Thanks for the PR. The direction looks useful, but I don’t think it is ready to merge yet.
Also, after that, please confirm that CUDA and CPU still work in addition to MPS.
Pull Request: Enable MPS Support and Resolve Hardcoded CUDA Dependencies
Summary
This PR introduces support for Apple Silicon GPUs (MPS) and makes the codebase device-agnostic. It resolves `AssertionError: Torch not compiled with CUDA enabled` and `NotImplementedError` for specific operators on macOS, while maintaining full compatibility with CUDA and CPU backends.
Changes
1. Modernized Device Selection
- Entry point scripts (`example.py`, `example_mm.py`, `example_vo.py`, `demo_gradio.py`) now prioritize hardware in the following order: `cuda` > `mps` > `cpu`.
2. MPS Operator Fallback
- Sets `PYTORCH_ENABLE_MPS_FALLBACK=1` in all entry point scripts. This ensures that operators not yet natively implemented in MPS (e.g., `_upsample_bicubic2d_aa`) automatically fall back to the CPU instead of crashing.
3. Device-Agnostic Autocast & Precision
- Replaces hardcoded `device_type='cuda'` in `torch.amp.autocast` calls within `Pi3`, `Pi3X`, and `camera_head.py`. These now dynamically use the input tensor's device type.
- Per-device `dtype` selection:
  - CUDA: `bfloat16` (if compute capability >= 8) or `float16`.
  - MPS: `float16` (standard for Metal).
  - CPU: `float32`.
4. Memory Management
- Updates the `Pi3XVO` pipeline and `Pi3X` model to use `torch.mps.empty_cache()` when running on Apple Silicon, preventing memory fragmentation during long video processing.
5. Bug Fixes
- Fixes a `NameError` in `pi3/models/pi3x.py` where `imgs` was referenced out of scope; replaced with `hidden.device.type`.
- Updates `demo_gradio.py` to allow the model to load and run on non-CUDA systems.
Technical Details
- `example.py`, `example_mm.py`, `example_vo.py`: Updated device/dtype logic and added MPS fallback.
- `demo_gradio.py`: Modernized device initialization and inference precision.
- `pi3/models/pi3.py`, `pi3/models/pi3x.py`: Patched `autocast` and `empty_cache`.
- `pi3/models/layers/camera_head.py`: Fixed hardcoded `device_type`.
- `pi3/pipe/pi3x_vo.py`: Added device-agnostic autocast and cache clearing.
Verification Results
Verified on M1 Pro (16GB RAM):
- `example_mm.py` runs successfully with `--interval 50`.
- `example.py` runs successfully.
- `example_vo.py` runs successfully.
- `demo_gradio.py` initializes and loads the model on `mps`.
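
The device order and per-device precision described in this PR can be sketched as plain functions, which makes the selection rules easy to check without any GPU. This is an illustration under stated assumptions, not the PR's actual code: the helper names `pick_device`, `pick_dtype_name`, and `autocast_enabled` are hypothetical, and the optional wiring at the bottom assumes a PyTorch build whose autocast accepts the selected device type.

```python
from typing import Optional


def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred backend name in the order cuda > mps > cpu."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def pick_dtype_name(device: str, cuda_major: Optional[int] = None) -> str:
    """Return the dtype name the PR description lists for each backend."""
    if device == "cuda":
        # Per the PR: bfloat16 if compute capability >= 8, else float16.
        if cuda_major is not None and cuda_major >= 8:
            return "bfloat16"
        return "float16"
    if device == "mps":
        return "float16"
    return "float32"


def autocast_enabled(device: str) -> bool:
    """Mirror the reviewed snippet: autocast is switched off on CPU."""
    return device != "cpu"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None  # the helpers above still work without PyTorch

    if torch is not None:
        device = pick_device(
            torch.cuda.is_available(),
            torch.backends.mps.is_available(),
        )
        major = torch.cuda.get_device_capability()[0] if device == "cuda" else None
        dtype = getattr(torch, pick_dtype_name(device, major))
        x = torch.ones(2, 2, device=device)
        # Use the tensor's own device type instead of a hardcoded 'cuda'.
        with torch.amp.autocast(x.device.type, dtype=dtype,
                                enabled=autocast_enabled(device)):
            y = x @ x
        print(device, dtype, y.dtype)
```

Keeping the selection rules separate from the PyTorch calls means the CUDA, MPS, and CPU branches can each be exercised on any machine, which may help when confirming that all three backends still behave as intended.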