From 45f9d959c82a4f90df37e838fb4e46bcf1aa77e8 Mon Sep 17 00:00:00 2001 From: Zardasht Kaya Date: Fri, 15 May 2026 05:19:20 +0300 Subject: [PATCH] Add support for Apple Silicon --- .gitignore | 6 +++++- demo_gradio.py | 32 +++++++++++++++++++++++++------- example.py | 28 ++++++++++++++++++++++++---- example_mm.py | 25 +++++++++++++++++++++---- example_vo.py | 26 ++++++++++++++++++++++---- pi3/models/layers/camera_head.py | 2 +- pi3/models/pi3.py | 2 +- pi3/models/pi3x.py | 4 ++-- pi3/pipe/pi3x_vo.py | 8 ++++++-- 9 files changed, 107 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index 809f21f3..59df09bd 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,8 @@ ckpts /utils/utils_ceph.py resize_demo_imgs.py -img_dir_to_video.py \ No newline at end of file +img_dir_to_video.py +.DS_Store +examples/.DS_Store +examples/room/.DS_Store +model.safetensors diff --git a/demo_gradio.py b/demo_gradio.py index 71dfd870..95d1585a 100644 --- a/demo_gradio.py +++ b/demo_gradio.py @@ -9,6 +9,10 @@ import glob import gc import time + +# Set MPS fallback +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + # import spaces # only for web demo from pi3.utils.geometry import se3_inverse, homogenize_points, depth_edge @@ -274,9 +278,12 @@ def run_model(target_dir, model) -> dict: print(f"Processing images from {target_dir}") # Device check - device = "cuda" if torch.cuda.is_available() else "cpu" - if not torch.cuda.is_available(): - raise ValueError("CUDA is not available. Check your environment.") + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" # Move model to device model = model.to(device) @@ -295,9 +302,15 @@ def run_model(target_dir, model) -> dict: # 3.
Infer print("Running model inference...") - dtype = torch.bfloat16 + if device == 'cuda': + dtype = torch.bfloat16 if torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16 + elif device == 'mps': + dtype = torch.float16 + else: + dtype = torch.float32 + with torch.no_grad(): - with torch.amp.autocast('cuda', dtype=dtype): + with torch.amp.autocast(device, dtype=dtype, enabled=device != 'cpu'): predictions = model(imgs[None]) # Add batch dimension predictions['images'] = imgs[None].permute(0, 1, 3, 4, 2) predictions['conf'] = torch.sigmoid(predictions['conf']) @@ -562,9 +575,14 @@ def update_visualization( if __name__ == '__main__': - device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" - print("Initializing and loading Pi3 model...") + print(f"Initializing and loading Pi3 model on {device}...") model = Pi3.from_pretrained("yyfz233/Pi3") # model = Pi3() diff --git a/example.py b/example.py index 513f7092..2443b6be 100644 --- a/example.py +++ b/example.py @@ -1,5 +1,10 @@ +import os + +# Enable the MPS CPU fallback; PyTorch reads this variable when torch is first imported, so set it before `import torch` +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + import torch import argparse from pi3.utils.basic import load_images_as_tensor, write_ply from pi3.utils.geometry import depth_edge from pi3.models.pi3 import Pi3 @@ -16,8 +21,16 @@ help="Interval to sample image. Default: 1 for images dir, 10 for video") parser.add_argument("--ckpt", type=str, default=None, help="Path to the model checkpoint file. Default: None") - parser.add_argument("--device", type=str, default='cuda', - help="Device to run inference on ('cuda' or 'cpu').
Default: 'cuda'") + + if torch.cuda.is_available(): + default_device = 'cuda' + elif torch.backends.mps.is_available(): + default_device = 'mps' + else: + default_device = 'cpu' + + parser.add_argument("--device", type=str, default=default_device, + help=f"Device to run inference on ('cuda', 'mps' or 'cpu'). Default: '{default_device}'") args = parser.parse_args() if args.interval < 0: @@ -27,6 +40,7 @@ # 1. Prepare model print(f"Loading model...") device = torch.device(args.device) + print(f"Using device: {device}") if args.ckpt is not None: model = Pi3().to(device).eval() if args.ckpt.endswith('.safetensors'): @@ -46,9 +60,15 @@ # 3. Infer print("Running model inference...") - dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + if device.type == 'cuda': + dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + elif device.type == 'mps': + dtype = torch.float16 + else: + dtype = torch.float32 + with torch.no_grad(): - with torch.amp.autocast('cuda', dtype=dtype): + with torch.amp.autocast(device.type, dtype=dtype, enabled=device.type != 'cpu'): res = model(imgs[None]) # Add batch dimension # 4. process mask diff --git a/example_mm.py b/example_mm.py index 71492ae5..6db2cc60 100644 --- a/example_mm.py +++ b/example_mm.py @@ -2,6 +2,10 @@ import argparse import numpy as np import os + +# Set MPS fallback before importing other modules that might use torch +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + from pi3.utils.basic import load_multimodal_data, write_ply from pi3.utils.geometry import depth_edge, recover_intrinsic_from_rays_d from pi3.models.pi3x import Pi3X @@ -23,8 +27,15 @@ help="Interval to sample image. Default: 1 for images dir, 10 for video") parser.add_argument("--ckpt", type=str, default=None, help="Path to the model checkpoint file. Default: None") - parser.add_argument("--device", type=str, default='cuda', - help="Device to run inference on ('cuda' or 'cpu'). 
Default: 'cuda'") + if torch.cuda.is_available(): + default_device = 'cuda' + elif torch.backends.mps.is_available(): + default_device = 'mps' + else: + default_device = 'cpu' + + parser.add_argument("--device", type=str, default=default_device, + help=f"Device to run inference on ('cuda', 'mps' or 'cpu'). Default: '{default_device}'") args = parser.parse_args() if args.interval < 0: @@ -33,6 +44,7 @@ # 1. Prepare input data device = torch.device(args.device) + print(f"Using device: {device}") # Load optional conditions from .npz poses = None @@ -105,10 +117,15 @@ # 3. Infer print("Running model inference...") - dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + if device.type == 'cuda': + dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + elif device.type == 'mps': + dtype = torch.float16 + else: + dtype = torch.float32 with torch.no_grad(): - with torch.amp.autocast('cuda', dtype=dtype): + with torch.amp.autocast(device.type, dtype=dtype, enabled=device.type != 'cpu'): res = model( imgs=imgs, **conditions diff --git a/example_vo.py b/example_vo.py index 2e9f9c20..0f45201b 100644 --- a/example_vo.py +++ b/example_vo.py @@ -2,6 +2,10 @@ import argparse import numpy as np import os + +# Set MPS fallback before importing other modules that might use torch +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + from pi3.utils.basic import load_multimodal_data, write_ply from pi3.models.pi3x import Pi3X from pi3.pipe.pi3x_vo import Pi3XVO @@ -18,8 +22,16 @@ help="Interval to sample image. Default: 1 for images dir, 10 for video") parser.add_argument("--ckpt", type=str, default=None, help="Path to the model checkpoint file. Default: None") - parser.add_argument("--device", type=str, default='cuda', - help="Device to run inference on ('cuda' or 'cpu'). 
Default: 'cuda'") + + if torch.cuda.is_available(): + default_device = 'cuda' + elif torch.backends.mps.is_available(): + default_device = 'mps' + else: + default_device = 'cpu' + + parser.add_argument("--device", type=str, default=default_device, + help=f"Device to run inference on ('cuda', 'mps' or 'cpu'). Default: '{default_device}'") args = parser.parse_args() if args.interval < 0: @@ -29,6 +41,7 @@ # 1. Prepare model print(f"Loading model...") device = torch.device(args.device) + print(f"Using device: {device}") if args.ckpt is not None: model = Pi3X().to(device).eval() if args.ckpt.endswith('.safetensors'): @@ -50,7 +63,12 @@ # 3. Infer print("Running model inference...") - dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + if device.type == 'cuda': + dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + elif device.type == 'mps': + dtype = torch.float16 + else: + dtype = torch.float32 with torch.no_grad(): res = pipe( @@ -67,4 +85,4 @@ os.makedirs(os.path.dirname(args.save_path), exist_ok=True) write_ply(res['points'][0][masks].cpu(), imgs[0].permute(0, 2, 3, 1)[masks], args.save_path) - print("Done.") + print("Done.") \ No newline at end of file diff --git a/pi3/models/layers/camera_head.py b/pi3/models/layers/camera_head.py index 7d844f7b..720f9744 100644 --- a/pi3/models/layers/camera_head.py +++ b/pi3/models/layers/camera_head.py @@ -56,7 +56,7 @@ def forward(self, feat, patch_h, patch_w): feat = feat.view(feat.size(0), -1) feat = self.more_mlps(feat) # [B, D_] - with torch.amp.autocast(device_type='cuda', enabled=False): + with torch.amp.autocast(device_type=feat.device.type, enabled=False): out_t = self.fc_t(feat.float()) # [B,3] out_r = self.fc_rot(feat.float()) # [B,9] pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device) diff --git a/pi3/models/pi3.py b/pi3/models/pi3.py index 917c6ccb..96add4a4 100644 --- a/pi3/models/pi3.py +++ b/pi3/models/pi3.py @@ -189,7 +189,7 
@@ def forward(self, imgs): conf_hidden = self.conf_decoder(hidden, xpos=pos) camera_hidden = self.camera_decoder(hidden, xpos=pos) - with torch.amp.autocast(device_type='cuda', enabled=False): + with torch.amp.autocast(device_type=point_hidden.device.type, enabled=False): # local points point_hidden = point_hidden.float() ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1) diff --git a/pi3/models/pi3x.py b/pi3/models/pi3x.py index 82f8a11a..8ebfd5d0 100644 --- a/pi3/models/pi3x.py +++ b/pi3/models/pi3x.py @@ -291,7 +291,7 @@ def encode( hidden = self.encoder(imgs, is_training=True)["x_norm_patchtokens"] if self.use_multimodal: - with torch.amp.autocast(device_type='cuda', enabled=False): + with torch.amp.autocast(device_type=hidden.device.type, enabled=False): if with_prior is True: p_depth = p_ray = p_pose = 1.0 else: @@ -410,7 +410,7 @@ def forward_head(self, hidden, pos, B, N, H, W, patch_h, patch_w): # decode conf ret_conf = self.conf_decoder(hidden, xpos=pos) - with torch.amp.autocast(device_type='cuda', enabled=False): + with torch.amp.autocast(device_type=hidden.device.type, enabled=False): point_feat = ret_point[:, self.patch_start_idx:].float() xy, z = self._chunked_conv_head(self.point_head, point_feat, patch_h, patch_w) del point_feat diff --git a/pi3/pipe/pi3x_vo.py b/pi3/pipe/pi3x_vo.py index 983fe00a..88c98fd4 100644 --- a/pi3/pipe/pi3x_vo.py +++ b/pi3/pipe/pi3x_vo.py @@ -75,7 +75,7 @@ def __call__(self, imgs, chunk_size=16, overlap=6, conf_thre=0.05, inject_condit model_kwargs['mask_add_ray'] = mask_ray model_kwargs['with_prior'] = True - with torch.amp.autocast('cuda', dtype=dtype): + with torch.amp.autocast(chunk_imgs.device.type, dtype=dtype, enabled=chunk_imgs.device.type != 'cpu'): pred = self.model(chunk_imgs, **model_kwargs) curr_local_depth = pred['local_points'][..., 2] @@ -133,7 +133,11 @@ def __call__(self, imgs, chunk_size=16, overlap=6, conf_thre=0.05, inject_condit if 'poses' in model_kwargs: 
del model_kwargs['poses'] if 'depths' in model_kwargs: del model_kwargs['depths'] if 'rays' in model_kwargs: del model_kwargs['rays'] - torch.cuda.empty_cache() + + if imgs.device.type == 'cuda': + torch.cuda.empty_cache() + elif imgs.device.type == 'mps': + torch.mps.empty_cache() if end_idx == T: break