Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@ ckpts
/utils/utils_ceph.py

resize_demo_imgs.py
img_dir_to_video.py
img_dir_to_video.py.DS_Store
examples/.DS_Store
examples/room/.DS_Store
model.safetensors
32 changes: 25 additions & 7 deletions demo_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import glob
import gc
import time

# Set MPS fallback
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
Comment on lines +13 to +14

# import spaces # only for web demo

from pi3.utils.geometry import se3_inverse, homogenize_points, depth_edge
Expand Down Expand Up @@ -274,9 +278,12 @@ def run_model(target_dir, model) -> dict:
print(f"Processing images from {target_dir}")

# Device check
device = "cuda" if torch.cuda.is_available() else "cpu"
if not torch.cuda.is_available():
raise ValueError("CUDA is not available. Check your environment.")
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"

# Move model to device
model = model.to(device)
Expand All @@ -295,9 +302,15 @@ def run_model(target_dir, model) -> dict:

# 3. Infer
print("Running model inference...")
dtype = torch.bfloat16
if device == 'cuda':
dtype = torch.bfloat16 if torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16
elif device == 'mps':
dtype = torch.float16
else:
dtype = torch.float32

with torch.no_grad():
with torch.amp.autocast('cuda', dtype=dtype):
with torch.amp.autocast(device, dtype=dtype, enabled=device != 'cpu'):
predictions = model(imgs[None]) # Add batch dimension
predictions['images'] = imgs[None].permute(0, 1, 3, 4, 2)
predictions['conf'] = torch.sigmoid(predictions['conf'])
Expand Down Expand Up @@ -562,9 +575,14 @@ def update_visualization(

if __name__ == '__main__':

device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"

print("Initializing and loading Pi3 model...")
print(f"Initializing and loading Pi3 model on {device}...")

model = Pi3.from_pretrained("yyfz233/Pi3")
# model = Pi3()
Expand Down
28 changes: 24 additions & 4 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import torch
import argparse
import os

# Set MPS fallback before importing other modules that might use torch
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

Comment on lines 1 to +7
from pi3.utils.basic import load_images_as_tensor, write_ply
from pi3.utils.geometry import depth_edge
from pi3.models.pi3 import Pi3
Expand All @@ -16,8 +21,16 @@
help="Interval to sample image. Default: 1 for images dir, 10 for video")
parser.add_argument("--ckpt", type=str, default=None,
help="Path to the model checkpoint file. Default: None")
parser.add_argument("--device", type=str, default='cuda',
help="Device to run inference on ('cuda' or 'cpu'). Default: 'cuda'")

if torch.cuda.is_available():
default_device = 'cuda'
elif torch.backends.mps.is_available():
default_device = 'mps'
else:
default_device = 'cpu'

parser.add_argument("--device", type=str, default=default_device,
help=f"Device to run inference on ('cuda', 'mps' or 'cpu'). Default: '{default_device}'")

args = parser.parse_args()
if args.interval < 0:
Expand All @@ -27,6 +40,7 @@
# 1. Prepare model
print(f"Loading model...")
device = torch.device(args.device)
print(f"Using device: {device}")
if args.ckpt is not None:
model = Pi3().to(device).eval()
if args.ckpt.endswith('.safetensors'):
Expand All @@ -46,9 +60,15 @@

# 3. Infer
print("Running model inference...")
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
if device.type == 'cuda':
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
elif device.type == 'mps':
dtype = torch.float16
else:
dtype = torch.float32

with torch.no_grad():
with torch.amp.autocast('cuda', dtype=dtype):
with torch.amp.autocast(device.type, dtype=dtype, enabled=device.type != 'cpu'):
res = model(imgs[None]) # Add batch dimension

# 4. process mask
Expand Down
25 changes: 21 additions & 4 deletions example_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import argparse
import numpy as np
import os

# Set MPS fallback before importing other modules that might use torch
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
Comment on lines +6 to +7

from pi3.utils.basic import load_multimodal_data, write_ply
from pi3.utils.geometry import depth_edge, recover_intrinsic_from_rays_d
from pi3.models.pi3x import Pi3X
Expand All @@ -23,8 +27,15 @@
help="Interval to sample image. Default: 1 for images dir, 10 for video")
parser.add_argument("--ckpt", type=str, default=None,
help="Path to the model checkpoint file. Default: None")
parser.add_argument("--device", type=str, default='cuda',
help="Device to run inference on ('cuda' or 'cpu'). Default: 'cuda'")
if torch.cuda.is_available():
default_device = 'cuda'
elif torch.backends.mps.is_available():
default_device = 'mps'
else:
default_device = 'cpu'

parser.add_argument("--device", type=str, default=default_device,
help=f"Device to run inference on ('cuda', 'mps' or 'cpu'). Default: '{default_device}'")

args = parser.parse_args()
if args.interval < 0:
Expand All @@ -33,6 +44,7 @@

# 1. Prepare input data
device = torch.device(args.device)
print(f"Using device: {device}")

# Load optional conditions from .npz
poses = None
Expand Down Expand Up @@ -105,10 +117,15 @@

# 3. Infer
print("Running model inference...")
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
if device.type == 'cuda':
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
elif device.type == 'mps':
dtype = torch.float16
else:
dtype = torch.float32

with torch.no_grad():
with torch.amp.autocast('cuda', dtype=dtype):
with torch.amp.autocast(device.type, dtype=dtype, enabled=device.type != 'cpu'):
res = model(
imgs=imgs,
**conditions
Expand Down
26 changes: 22 additions & 4 deletions example_vo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import argparse
import numpy as np
import os

# Set MPS fallback before importing other modules that might use torch
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
Comment on lines +6 to +7

from pi3.utils.basic import load_multimodal_data, write_ply
from pi3.models.pi3x import Pi3X
from pi3.pipe.pi3x_vo import Pi3XVO
Expand All @@ -18,8 +22,16 @@
help="Interval to sample image. Default: 1 for images dir, 10 for video")
parser.add_argument("--ckpt", type=str, default=None,
help="Path to the model checkpoint file. Default: None")
parser.add_argument("--device", type=str, default='cuda',
help="Device to run inference on ('cuda' or 'cpu'). Default: 'cuda'")

if torch.cuda.is_available():
default_device = 'cuda'
elif torch.backends.mps.is_available():
default_device = 'mps'
else:
default_device = 'cpu'

parser.add_argument("--device", type=str, default=default_device,
help=f"Device to run inference on ('cuda', 'mps' or 'cpu'). Default: '{default_device}'")

args = parser.parse_args()
if args.interval < 0:
Expand All @@ -29,6 +41,7 @@
# 1. Prepare model
print(f"Loading model...")
device = torch.device(args.device)
print(f"Using device: {device}")
if args.ckpt is not None:
model = Pi3X().to(device).eval()
if args.ckpt.endswith('.safetensors'):
Expand All @@ -50,7 +63,12 @@

# 3. Infer
print("Running model inference...")
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
if device.type == 'cuda':
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
elif device.type == 'mps':
dtype = torch.float16
else:
dtype = torch.float32

with torch.no_grad():
res = pipe(
Expand All @@ -67,4 +85,4 @@
os.makedirs(os.path.dirname(args.save_path), exist_ok=True)

write_ply(res['points'][0][masks].cpu(), imgs[0].permute(0, 2, 3, 1)[masks], args.save_path)
print("Done.")
print("Done.")
2 changes: 1 addition & 1 deletion pi3/models/layers/camera_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def forward(self, feat, patch_h, patch_w):
feat = feat.view(feat.size(0), -1)

feat = self.more_mlps(feat) # [B, D_]
with torch.amp.autocast(device_type='cuda', enabled=False):
with torch.amp.autocast(device_type=feat.device.type, enabled=False):
out_t = self.fc_t(feat.float()) # [B,3]
out_r = self.fc_rot(feat.float()) # [B,9]
pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)
Expand Down
2 changes: 1 addition & 1 deletion pi3/models/pi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def forward(self, imgs):
conf_hidden = self.conf_decoder(hidden, xpos=pos)
camera_hidden = self.camera_decoder(hidden, xpos=pos)

with torch.amp.autocast(device_type='cuda', enabled=False):
with torch.amp.autocast(device_type=point_hidden.device.type, enabled=False):
# local points
point_hidden = point_hidden.float()
ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
Expand Down
4 changes: 2 additions & 2 deletions pi3/models/pi3x.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def encode(
hidden = self.encoder(imgs, is_training=True)["x_norm_patchtokens"]

if self.use_multimodal:
with torch.amp.autocast(device_type='cuda', enabled=False):
with torch.amp.autocast(device_type=hidden.device.type, enabled=False):
if with_prior is True:
p_depth = p_ray = p_pose = 1.0
else:
Expand Down Expand Up @@ -410,7 +410,7 @@ def forward_head(self, hidden, pos, B, N, H, W, patch_h, patch_w):
# decode conf
ret_conf = self.conf_decoder(hidden, xpos=pos)

with torch.amp.autocast(device_type='cuda', enabled=False):
with torch.amp.autocast(device_type=hidden.device.type, enabled=False):
point_feat = ret_point[:, self.patch_start_idx:].float()
xy, z = self._chunked_conv_head(self.point_head, point_feat, patch_h, patch_w)
del point_feat
Expand Down
8 changes: 6 additions & 2 deletions pi3/pipe/pi3x_vo.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __call__(self, imgs, chunk_size=16, overlap=6, conf_thre=0.05, inject_condit
model_kwargs['mask_add_ray'] = mask_ray
model_kwargs['with_prior'] = True

with torch.amp.autocast('cuda', dtype=dtype):
with torch.amp.autocast(chunk_imgs.device.type, dtype=dtype, enabled=chunk_imgs.device.type != 'cpu'):
pred = self.model(chunk_imgs, **model_kwargs)

curr_local_depth = pred['local_points'][..., 2]
Expand Down Expand Up @@ -133,7 +133,11 @@ def __call__(self, imgs, chunk_size=16, overlap=6, conf_thre=0.05, inject_condit
if 'poses' in model_kwargs: del model_kwargs['poses']
if 'depths' in model_kwargs: del model_kwargs['depths']
if 'rays' in model_kwargs: del model_kwargs['rays']
torch.cuda.empty_cache()

if imgs.device.type == 'cuda':
torch.cuda.empty_cache()
elif imgs.device.type == 'mps':
torch.mps.empty_cache()

if end_idx == T:
break
Expand Down
Loading