From 44bc4b2cff00172ba9a8b7fa2bf088f792aafdc2 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 00:57:32 -0700 Subject: [PATCH 01/32] support finetuning --- examples/convert_jax_model_to_pytorch.py | 6 + scripts/train_pytorch.py | 619 +++++++++++++++++++++ scripts/train_single_example.py | 548 ++++++++++++++++++ src/openpi/models/pi0.py | 20 +- src/openpi/models_pytorch/gemma_pytorch.py | 124 ++++- src/openpi/models_pytorch/pi0_pytorch.py | 17 +- src/openpi/training/config.py | 24 +- src/openpi/training/data_loader.py | 1 + uv.lock | 356 +++++++----- 9 files changed, 1565 insertions(+), 150 deletions(-) create mode 100644 scripts/train_pytorch.py create mode 100644 scripts/train_single_example.py diff --git a/examples/convert_jax_model_to_pytorch.py b/examples/convert_jax_model_to_pytorch.py index 03df445..8db8f1d 100644 --- a/examples/convert_jax_model_to_pytorch.py +++ b/examples/convert_jax_model_to_pytorch.py @@ -643,6 +643,12 @@ def __init__(self): action_horizon=10, pi05=True, ) + elif "pi05_base" in checkpoint_dir: + pi0_config = Pi0Config( + action_dim=32, + action_horizon=50, + pi05=True, + ) else: print("Warning: Could not determine PI0 config from checkpoint path. Using base config.") pi0_config = Pi0Config( diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py new file mode 100644 index 0000000..1a46fe1 --- /dev/null +++ b/scripts/train_pytorch.py @@ -0,0 +1,619 @@ +""" +PyTorch training entrypoint for PI0 with multi-GPU and multi-node (DDP) support. +This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs +entirely in PyTorch using the `PI0Pytorch` model and your existing config/data +pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. +Key features +- Uses the same TrainConfig/tyro CLI as the JAX script (see available configs in + `src/openpi/training/config.py`). +- Supports multi-GPU and multi-node training via DistributedDataParallel (DDP). 
+- Cosine LR with warmup (parameters read from the selected config). +- AdamW optimizer and gradient clipping. +- Comprehensive checkpoint saving and resume mechanism with configurable intervals. +- Checkpoints saved on rank 0 to `config.checkpoint_dir/{step}/` containing model, optimizer, and metadata. +- Memory optimizations: mixed precision training, gradient accumulation, and efficient data handling. +Requirements +- PyTorch >= 2.0, torch.distributed (NCCL for CUDA, Gloo for CPU). +- Multiple GPUs for DDP (optional). +- Network connectivity between nodes for multi-node training. +Usage +Single GPU: + python scripts/train_pytorch.py CONFIG_NAME --exp_name RUN_NAME --ckpt_save_interval INTERVAL + Example: + python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test +Multi-GPU (single node): + torchrun --standalone --nnodes=1 --nproc_per_node=NUM_GPUS scripts/train_pytorch.py CONFIG_NAME --exp_name RUN_NAME + Example: + torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test +Multi-Node Training: + # On master node (node 0): + torchrun --nnodes=NUM_NODES --nproc_per_node=GPUS_PER_NODE --rdzv_id=JOB_ID --rdzv_backend=c10d --rdzv_endpoint=MASTER_IP:PORT scripts/train_pytorch.py CONFIG_NAME --exp_name RUN_NAME + + # On worker nodes (node 1, 2, ...): + torchrun --nnodes=NUM_NODES --nproc_per_node=GPUS_PER_NODE --rdzv_id=JOB_ID --rdzv_backend=c10d --rdzv_endpoint=MASTER_IP:PORT scripts/train_pytorch.py CONFIG_NAME --exp_name RUN_NAME + + Example (2 nodes, 4 GPUs each): + # Master node (192.168.1.100): + torchrun --nnodes=2 --nproc_per_node=4 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=192.168.1.100:29400 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_multi_node + + # Worker node (192.168.1.101): + torchrun --nnodes=2 --nproc_per_node=4 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=192.168.1.100:29400 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_multi_node +Multi-Node Setup Requirements: +1. Network connectivity: All nodes must be able to communicate on the specified port +2. 
Shared filesystem: All nodes must have access to the same dataset and checkpoint directories +3. Environment consistency: Same Python environment and dependencies on all nodes +4. Firewall configuration: Ensure the rendezvous port (e.g., 29400) is open between nodes +5. SSH access: Nodes should be able to SSH to each other (for torchrun coordination) +Environment Variables for Multi-Node: +- MASTER_ADDR: IP address of the master node (auto-set by torchrun) +- MASTER_PORT: Port for rendezvous (auto-set by torchrun) +- WORLD_SIZE: Total number of processes across all nodes +- RANK: Global rank of the process (0 to WORLD_SIZE-1) +- LOCAL_RANK: Local rank within the node (0 to nproc_per_node-1) +- NODE_RANK: Rank of the node (0 to nnodes-1) +Checkpoint Parameters: +- --ckpt_save_interval: Override the checkpoint save interval from config (e.g., --ckpt_save_interval 500) +- --resume: Resume training from the latest checkpoint in the checkpoint directory +- --overwrite: Overwrite existing checkpoint directory (cannot be used with --resume) +Memory Optimization Parameters: +- --gradient_accumulation_steps: Number of steps to accumulate gradients (default: 1) +- --mixed_precision: Enable mixed precision training (off by default; as parsed in main(), --no_mixed_precision defaults to True, so this flag currently has no effect) +- --max_memory_usage: Maximum GPU memory usage in GB (default: None, no limit applied) +Notes +- The global batch size must be divisible by world size (number of processes). +- The data pipeline and transforms are identical to the JAX version and are controlled + by the selected TrainConfig (e.g., `LeRobot*` configs for real datasets or `FakeDataConfig`). +- Supports Weights & Biases (wandb) logging for experiment tracking and visualization. +- Checkpoints include model state, optimizer state, and training metadata for complete resume capability. +- For optimal multi-node performance, ensure high-bandwidth network connectivity (e.g., InfiniBand). +- Monitor GPU utilization and network bandwidth during multi-node training. 
+- Memory optimizations can significantly reduce GPU memory usage while maintaining training quality. +""" +import argparse +import dataclasses +import logging +import os +import platform +import time +from dataclasses import dataclass +from typing import Any, Dict, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader as TorchDataLoader +from torch.utils.data.distributed import DistributedSampler +import wandb +from tqdm import tqdm + +import openpi.training.config as _config +import openpi.training.data_loader as _data +import openpi.models.model as _model +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch +from openpi.models.pi0_config import Pi0Config + + +def init_logging(): + level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"} + + class CustomFormatter(logging.Formatter): + def format(self, record): + record.levelname = level_mapping.get(record.levelname, record.levelname) + return super().format(record) + + formatter = CustomFormatter( + fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)", + datefmt="%H:%M:%S", + ) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + if not logger.handlers: + ch = logging.StreamHandler() + ch.setFormatter(formatter) + logger.addHandler(ch) + else: + logger.handlers[0].setFormatter(formatter) + + +def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True): + """Initialize wandb logging.""" + if not enabled: + wandb.init(mode="disabled") + return + + ckpt_dir = config.checkpoint_dir + if not ckpt_dir.exists(): + raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.") + + if resuming: + run_id = (ckpt_dir / "wandb_id.txt").read_text().strip() + wandb.init(id=run_id, resume="must", project=config.project_name) + else: + wandb.init( + name=config.exp_name, + 
config=dataclasses.asdict(config), + project=config.project_name, + ) + (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id) + + +def setup_ddp(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + use_ddp = world_size > 1 + if use_ddp and not dist.is_initialized(): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method="env://") + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) + device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): + torch.cuda.set_device(device) + return use_ddp, local_rank, device + + +def cleanup_ddp(): + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +def set_seed(seed: int, local_rank: int): + torch.manual_seed(seed + local_rank) + np.random.seed(seed + local_rank) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed + local_rank) + + +def build_datasets(config: _config.TrainConfig): + # Reuse existing dataset + transforms pipeline + data_conf = config.data.create(config.assets_dirs, config.model) + dataset = _data.create_torch_dataset(data_conf, config.model.action_horizon, config.model) + print(f"data_conf: {data_conf}") + dataset = _data.transform_dataset(dataset, data_conf) + return dataset, data_conf + + +def collate_to_numpy(batch_list: list[Dict[str, Any]]) -> Dict[str, Any]: + # Recursively stack leaves with numpy + def stack_leaf(*xs): + return np.stack([np.asarray(x) for x in xs], axis=0) + + # Memory-efficient collation + result = torch.utils.data.default_collate(batch_list) if not isinstance(batch_list[0], dict) else _tree_map_multi(stack_leaf, batch_list) + + # Clear batch list from memory + del batch_list + + return result + + +def _tree_map_multi(func, batch_list): + # batch_list is a list of dicts with same structure; reduce by zipping leaves + def recurse(keys, items): + if isinstance(items[0], dict): + return {k: recurse(keys + 
[k], [it[k] for it in items]) for k in items[0].keys()} + return func(*items) + return recurse([], batch_list) + + +def batch_to_torch(batch: Dict[str, Any], device: torch.device) -> Tuple[list[torch.Tensor], list[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + # Maintain canonical image key order + image_keys = _model.IMAGE_KEYS + import jax + + # Memory-efficient conversion: convert to torch tensors and move to device in one step + batch = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(device), batch) + + # Convert to float32 for memory efficiency (avoid float64) + batch['state'] = batch['state'].to(dtype=torch.float32) + batch['actions'] = batch['actions'].to(dtype=torch.float32) + + # Clear numpy arrays from memory if they exist + del jax + + return batch + + +def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_interval=None, ema_model=None): + """Save a checkpoint with model state, optimizer state, EMA state, and metadata.""" + if not is_main: + return + + # Use ckpt_save_interval if provided, otherwise use config.save_interval + save_interval = ckpt_save_interval if ckpt_save_interval is not None else config.save_interval + + # Only save if it's time to save or if it's the final step + if (global_step % save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: + ckpt_dir = os.path.join(config.checkpoint_dir, f"{global_step}") + os.makedirs(ckpt_dir, exist_ok=True) + + # Save model state + state_dict = (model.module if isinstance(model, DDP) else model).state_dict() + torch.save(state_dict, os.path.join(ckpt_dir, "pytorch_model.pt")) + + # Save optimizer state + torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt")) + + # Save EMA state if available + if ema_model is not None: + torch.save(ema_model.state_dict(), os.path.join(ckpt_dir, "ema_model.pt")) + + # Save training metadata + metadata = { + "global_step": global_step, + "config": dataclasses.asdict(config), + 
"timestamp": time.time(), + } + torch.save(metadata, os.path.join(ckpt_dir, "metadata.pt")) + + logging.info(f"Saved checkpoint at step {global_step} -> {ckpt_dir}") + + # Log checkpoint to wandb + if config.wandb_enabled: + wandb.log({"checkpoint_step": global_step}, step=global_step) + + +def load_checkpoint(model, optimizer, config, device, ema_model=None): + """Load the latest checkpoint and return the global step.""" + checkpoint_steps = [] + for d in config.checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + if not checkpoint_steps: + raise FileNotFoundError(f"No checkpoints found in {config.checkpoint_dir}") + + latest_step = max(checkpoint_steps) + ckpt_dir = os.path.join(config.checkpoint_dir, f"{latest_step}") + + # Load model state + model_state_dict = torch.load(os.path.join(ckpt_dir, "pytorch_model.pt"), map_location=device) + (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) + + # Load optimizer state + optimizer_state_dict = torch.load(os.path.join(ckpt_dir, "optimizer.pt"), map_location=device) + optimizer.load_state_dict(optimizer_state_dict) + + # Load EMA state if available + if ema_model is not None and os.path.exists(os.path.join(ckpt_dir, "ema_model.pt")): + ema_state_dict = torch.load(os.path.join(ckpt_dir, "ema_model.pt"), map_location=device) + ema_model.load_state_dict(ema_state_dict) + logging.info(f"Loaded EMA state from checkpoint") + + # Load metadata + metadata = torch.load(os.path.join(ckpt_dir, "metadata.pt"), map_location=device) + + logging.info(f"Loaded checkpoint from step {latest_step} -> {ckpt_dir}") + return metadata["global_step"] + + +def get_latest_checkpoint_step(config): + """Get the latest checkpoint step number.""" + checkpoint_steps = [] + for d in config.checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + return max(checkpoint_steps) if checkpoint_steps else None + + +def 
setup_memory_optimizations(model, device, enable_gradient_checkpointing=False): + """Setup memory optimization techniques for the model.""" + if enable_gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'): + model.gradient_checkpointing_enable() + logging.info("Enabled gradient checkpointing for memory optimization") + + # Enable memory efficient attention if available + if hasattr(model, 'config') and hasattr(model.config, 'attention_mode'): + model.config.attention_mode = 'flash_attention_2' + logging.info("Enabled Flash Attention 2 for memory efficiency") + + # Set memory efficient settings + if torch.cuda.is_available(): + # Enable memory efficient algorithms + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + # Set memory fraction if needed + if device.index is not None: + torch.cuda.empty_cache() + logging.info(f"Cleared CUDA cache for device {device.index}") + + +def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, enable_gradient_checkpointing: bool = False): + use_ddp, local_rank, device = setup_ddp() + is_main = (not use_ddp) or (dist.get_rank() == 0) + set_seed(config.seed, local_rank) + + # Memory optimization: Set memory fraction if specified + if max_memory_usage is not None and torch.cuda.is_available(): + torch.cuda.set_per_process_memory_fraction(max_memory_usage / torch.cuda.get_device_properties(device).total_memory * 1e-9) + + # Initialize checkpoint directory and wandb + resuming = False + if config.resume: + # Check if checkpoint directory exists and has checkpoints + if config.checkpoint_dir.exists(): + latest_step = get_latest_checkpoint_step(config) + if latest_step is not None: + resuming = True + logging.info(f"Resuming from checkpoint directory: {config.checkpoint_dir} at step {latest_step}") + else: + raise FileNotFoundError(f"No checkpoints found in 
{config.checkpoint_dir} for resume") + else: + raise FileNotFoundError(f"Checkpoint directory {config.checkpoint_dir} does not exist for resume") + elif config.overwrite and config.checkpoint_dir.exists(): + import shutil + shutil.rmtree(config.checkpoint_dir) + logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") + + # Create checkpoint directory + config.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # Initialize wandb (only on main process) + if is_main: + init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) + + # Build dataset + sampler + loader + dataset, data_conf = build_datasets(config) + sampler = None + if use_ddp: + sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True, drop_last=True) + + # Reduce batch size for gradient accumulation + effective_batch_size = config.batch_size // (dist.get_world_size() if use_ddp else 1) + + # Memory-efficient data loading with reduced pin_memory for large datasets + pin_memory = True + if effective_batch_size > 16: # Reduce pin_memory for large batches + pin_memory = False + logging.info("Disabled pin_memory for large batch size to reduce memory usage") + + loader = TorchDataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=pin_memory, drop_last=True, collate_fn=collate_to_numpy) + + # Log sample images to wandb on first batch + if is_main and config.wandb_enabled and not resuming: + sample_batch = next(iter(loader)) + sample_batch = batch_to_torch(sample_batch, device) + + # Create sample images for wandb + images_to_log = [] + # Get batch size from the first image tensor + batch_size = next(iter(sample_batch['image'].values())).shape[0] + for i in range(min(5, batch_size)): + # Concatenate all camera views horizontally for this batch item + img_concatenated = torch.cat([img[i] for img in sample_batch['image'].values()], axis=1) + 
img_concatenated = img_concatenated.cpu().numpy() + images_to_log.append(wandb.Image(img_concatenated)) + + wandb.log({"camera_views": images_to_log}, step=0) + + # Clear sample batch from memory + del sample_batch, images_to_log + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + # Reset the loader iterator + loader = TorchDataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=pin_memory, drop_last=True, collate_fn=collate_to_numpy) + + # Build model + if not isinstance(config.model, Pi0Config): + # Convert dataclass to Pi0Config if needed + model_cfg = Pi0Config( + action_dim=config.model.action_dim, + action_horizon=config.model.action_horizon, + max_token_len=config.model.max_token_len, + paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"), + action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"), + pi05=getattr(config.model, "pi05", False), + ) + else: + model_cfg = config.model + + model = PI0Pytorch(model_cfg).to(device) + + # Apply memory optimizations + setup_memory_optimizations(model, device, enable_gradient_checkpointing) + + if use_ddp: + model = DDP(model, device_ids=[device.index] if device.type == "cuda" else None, find_unused_parameters=False) + + # Load weights from weight_loader if specified (for fine-tuning) + if isinstance(config.weight_loader, str): + weight_path = config.weight_loader + logging.info(f"Loading weights from: {weight_path}") + + model_path = os.path.join(weight_path, "model.safetensors") + from safetensors.torch import load_model + load_model((model.module if isinstance(model, DDP) else model), model_path) + logging.info(f"Loaded PyTorch weights from {weight_path}") + + # Optimizer + learning rate schedule from config + warmup_steps = config.lr_schedule.warmup_steps + peak_lr = config.lr_schedule.peak_lr + decay_steps = config.lr_schedule.decay_steps + end_lr = 
config.lr_schedule.decay_lr + + # Create optimizer with config parameters + optim = torch.optim.AdamW( + model.parameters(), + lr=peak_lr, + betas=(config.optimizer.b1, config.optimizer.b2), + eps=config.optimizer.eps, + weight_decay=config.optimizer.weight_decay + ) + + # Initialize EMA if specified in config + ema_model = None + if config.ema_decay is not None: + ema_model = PI0Pytorch(model_cfg).to(device) + ema_model.load_state_dict(model.state_dict()) + ema_model.eval() + logging.info(f"Initialized EMA with decay {config.ema_decay}") + + # Load checkpoint if resuming + global_step = 0 + if resuming: + global_step = load_checkpoint(model, optim, config, device, ema_model) + logging.info(f"Resumed training from step {global_step}") + + def lr_schedule(step: int): + if step < warmup_steps: + return peak_lr * (step + 1) / warmup_steps + # cosine decay + progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) + cos = 0.5 * (1 + np.cos(np.pi * progress)) + return end_lr + (peak_lr - end_lr) * cos + + # Enable mixed precision training for memory optimization + scaler = torch.amp.GradScaler(enabled=mixed_precision and torch.cuda.is_available()) + + model.train() + start_time = time.time() + infos = [] # Collect stats over log interval + if is_main: + logging.info(f"Running on: {platform.node()} | world_size={dist.get_world_size() if use_ddp else 1}") + logging.info(f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}") + logging.info(f"Memory optimizations: gradient_accumulation_steps={gradient_accumulation_steps}, mixed_precision={mixed_precision}, gradient_checkpointing={enable_gradient_checkpointing}") + logging.info(f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}") + logging.info(f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, 
clip_norm={config.optimizer.clip_gradient_norm}") + if config.ema_decay is not None: + logging.info(f"EMA decay: {config.ema_decay}") + + # Training loop - iterate until we reach num_train_steps + pbar = tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) if is_main else None + + while global_step < config.num_train_steps: + if use_ddp: + sampler.set_epoch(global_step // len(loader)) + + for batch in loader: + # Check if we've reached the target number of steps + if global_step >= config.num_train_steps: + break + + # Convert dict batch directly to torch tensors (bypass Observation.from_dict for PyTorch) + batch = batch_to_torch(batch, device) + actions = batch["actions"] + + # Update LR + for pg in optim.param_groups: + pg["lr"] = lr_schedule(global_step) + + # Forward pass with mixed precision + observation = _model.Observation.from_dict(batch) + with torch.amp.autocast('cuda', enabled=mixed_precision and torch.cuda.is_available()): + losses = model(observation, actions) + loss = losses.mean() / gradient_accumulation_steps # Scale loss for gradient accumulation + + # Backward pass with gradient scaling + scaler.scale(loss).backward() + + # Gradient accumulation logic + if (global_step + 1) % gradient_accumulation_steps == 0: + # Unscale gradients for clipping + scaler.unscale_(optim) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) + + # Optimizer step + scaler.step(optim) + scaler.update() + optim.zero_grad(set_to_none=True) + + # Update EMA if enabled + if ema_model is not None: + with torch.no_grad(): + for param, ema_param in zip(model.parameters(), ema_model.parameters()): + ema_param.data.mul_(config.ema_decay).add_(param.data, alpha=1 - config.ema_decay) + + # Collect stats (only on accumulation steps) + if (global_step + 1) % gradient_accumulation_steps == 0 and is_main: + infos.append({ + "loss": loss.item() * gradient_accumulation_steps, # Unscale for logging + 
"learning_rate": optim.param_groups[0]['lr'], + }) + + if is_main and (global_step % config.log_interval == 0) and (global_step + 1) % gradient_accumulation_steps == 0: + elapsed = time.time() - start_time + + # Average stats over log interval + avg_loss = sum(info["loss"] for info in infos) / len(infos) + avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) + + logging.info(f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s") + + # Log to wandb + if config.wandb_enabled: + wandb.log({ + "loss": avg_loss, + "learning_rate": avg_lr, + "step": global_step, + "time_per_step": elapsed / config.log_interval, + }, step=global_step) + + start_time = time.time() + infos = [] # Reset stats collection + + # Save checkpoint using the new mechanism + save_checkpoint(model, optim, global_step, config, is_main, ckpt_save_interval, ema_model) + + global_step += 1 + + # Update progress bar + if pbar is not None: + pbar.update(1) + pbar.set_postfix({ + 'loss': f'{loss.item() * gradient_accumulation_steps:.4f}', + 'lr': f'{optim.param_groups[0]["lr"]:.2e}', + 'step': global_step + }) + + # Memory cleanup after each batch + del batch, actions, observation, losses, loss + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Close progress bar + if pbar is not None: + pbar.close() + + # Finish wandb run + if is_main and config.wandb_enabled: + wandb.finish() + + cleanup_ddp() + + +def main(): + init_logging() + config = _config.cli() + + # Parse additional command line arguments for memory optimization + import argparse + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--ckpt_save_interval", type=int, default=None, + help="Interval for saving checkpoints (overrides config.save_interval)") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, + help="Number of steps to accumulate gradients (default: 1)") + parser.add_argument("--mixed_precision", action="store_true", default=False, + 
help="Enable mixed precision training (default: True)") + parser.add_argument("--no_mixed_precision", action="store_true", default=True, + help="Disable mixed precision training") + parser.add_argument("--max_memory_usage", type=float, default=None, + help="Maximum GPU memory usage in GB (default: None, auto-detect)") + parser.add_argument("--enable_gradient_checkpointing", action="store_true", default=False, + help="Enable gradient checkpointing for memory optimization") + args, _ = parser.parse_known_args() + + # Handle mixed precision flag + mixed_precision = args.mixed_precision and not args.no_mixed_precision + + train_loop(config, + ckpt_save_interval=args.ckpt_save_interval, + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=mixed_precision, + max_memory_usage=args.max_memory_usage, + enable_gradient_checkpointing=args.enable_gradient_checkpointing) + + +if __name__ == "__main__": + main() diff --git a/scripts/train_single_example.py b/scripts/train_single_example.py new file mode 100644 index 0000000..2a23d48 --- /dev/null +++ b/scripts/train_single_example.py @@ -0,0 +1,548 @@ +""" +Train on a single example for debugging JAX vs PyTorch comparison. +This script creates a deterministic dataset with one example and trains on it +to help debug differences between JAX and PyTorch implementations. 
+""" + +import logging +import numpy as np +import torch +import jax +import jax.numpy as jnp +import flax.nnx as nnx +import flax + +from openpi.models import model as _model +from openpi.models.pi0_config import Pi0Config +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch + + +def setup_logging(): + """Setup logging for debugging.""" + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' + ) + + +def create_fixed_example(): + """Create a fixed example for debugging.""" + np.random.seed(42) + + batch_size = 1 + action_dim = 32 + action_horizon = 10 + image_size = 224 + max_token_len = 48 + + # Create fixed images + images = {} + for key in _model.IMAGE_KEYS: + img = np.zeros((batch_size, image_size, image_size, 3), dtype=np.float32) + + # Simple gradient pattern + for i in range(image_size): + for j in range(image_size): + val = (i + j) / (2 * image_size) * 2 - 1 + img[0, i, j, :] = [val, val * 0.5, val * 0.25] + + images[key] = img + + # Create fixed state and actions + state = np.random.randn(batch_size, action_dim).astype(np.float32) * 0.1 + actions = np.random.randn(batch_size, action_horizon, action_dim).astype(np.float32) * 0.1 + + # Create fixed language tokens + tokenized_prompt = np.random.randint(0, 1000, (batch_size, max_token_len), dtype=np.int32) + tokenized_prompt_mask = np.ones((batch_size, max_token_len), dtype=bool) + + # Create image masks + image_masks = {key: np.ones(batch_size, dtype=bool) for key in _model.IMAGE_KEYS} + + return { + "image": images, + "image_mask": image_masks, + "state": state, + "actions": actions, + "tokenized_prompt": tokenized_prompt, + "tokenized_prompt_mask": tokenized_prompt_mask, + } + + +def create_fixed_noise_and_time(batch_size, action_horizon, action_dim): + """Create fixed noise and time values for deterministic comparison.""" + np.random.seed(42) # Use same seed for consistency + + # Create fixed noise + noise = np.random.randn(batch_size, action_horizon, 
action_dim).astype(np.float32) * 0.1 + + # Create fixed time values (beta distribution like in the models) + time_beta = np.random.beta(1.5, 1.0, batch_size).astype(np.float32) + time = time_beta * 0.999 + 0.001 + + return noise, time + + +def test_pytorch_single_example(noise, time): + """Test PyTorch training on single example.""" + print("\n=== Testing PyTorch on Single Example ===") + + # Create model + config = Pi0Config(action_dim=32, action_horizon=10, pi05=True) + model = PI0Pytorch(config) + + # Load pre-trained weights + weight_path = "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2/model.safetensors" + print(f"Loading PyTorch weights from: {weight_path}") + + from safetensors.torch import load_model + load_model(model, weight_path) + + # Create fixed example + example = create_fixed_example() + + # Convert to PyTorch tensors + pytorch_example = {} + for key, value in example.items(): + if key == "image": + # Convert channels-last [B, H, W, C] to channels-first [B, C, H, W] for PyTorch + pytorch_example[key] = {} + for k, v in value.items(): + # v is [B, H, W, C], convert to [B, C, H, W] + v_tensor = torch.from_numpy(v) + v_tensor = v_tensor.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + pytorch_example[key][k] = v_tensor + elif key == "image_mask": + pytorch_example[key] = {k: torch.from_numpy(v) for k, v in value.items()} + else: + pytorch_example[key] = torch.from_numpy(value) + + # Convert noise and time to PyTorch tensors + noise_tensor = torch.from_numpy(noise) + time_tensor = torch.from_numpy(time) + + # Create observation + observation = _model.Observation.from_dict(pytorch_example) + actions = pytorch_example["actions"] + + print(f"Observation state shape: {observation.state.shape}") + print(f"Observation state dtype: {observation.state.dtype}") + print(f"Actions shape: {actions.shape}") + print(f"Actions dtype: {actions.dtype}") + print(f"Noise shape: {noise_tensor.shape}, dtype: {noise_tensor.dtype}") + 
print(f"Time shape: {time_tensor.shape}, dtype: {time_tensor.dtype}") + + # Test forward pass with fixed noise and time + model.eval() + with torch.no_grad(): + #try: + losses = model(observation, actions, noise=noise_tensor, time=time_tensor) + print(f"PyTorch forward pass successful!") + print(f"Losses shape: {losses.shape}") + print(f"Losses dtype: {losses.dtype}") + # mean_loss = losses.to(torch.float32).mean().item() + # print(f"Mean loss: {mean_loss:.6f}") + return True, losses + # except Exception as e: + # print(f"PyTorch forward pass failed: {e}") + # return False, None + + +def test_jax_single_example(noise, time, debug_single_layer=False): + """Test JAX training on single example.""" + print("\n=== Testing JAX on Single Example ===") + + # Create model + config = Pi0Config(action_dim=32, action_horizon=10, pi05=True) + if debug_single_layer: + print("šŸ”§ Debug mode: Using only 1 encoder layer") + + # Create a custom model with modified siglip depth for debugging + if debug_single_layer: + # Import the Pi0 model class + from openpi.models.pi0 import Pi0 + import openpi.models.gemma as _gemma + import openpi.models.siglip as _siglip + import flax.nnx.bridge as nnx_bridge + + # Create the model manually with custom siglip variant + rng = jax.random.key(42) + rngs = flax.nnx.Rngs(rng) + + paligemma_config = _gemma.get_config(config.paligemma_variant) + action_expert_config = _gemma.get_config(config.action_expert_variant) + + # Create LLM + llm = nnx_bridge.ToNNX( + _gemma.Module( + configs=[paligemma_config, action_expert_config], + embed_dtype=config.dtype, + adarms=config.pi05, + ) + ) + llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False]) + + # Create custom siglip model with depth=1 + # We'll use the same variant but override the depth parameter + siglip_params = _siglip.decode_variant("So400m/14") + siglip_params["depth"] = 1 # Override depth to 1 for debugging + + img = nnx_bridge.ToNNX( + 
_siglip.Module( + num_classes=paligemma_config.width, + variant=None, # Don't use variant, use explicit params + pool_type="none", + scan=False, # Disable scan for single layer + dtype_mm=config.dtype, + **siglip_params, # Pass the modified parameters + ) + ) + img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs) + + # Create the full model + model = Pi0(config, rngs) + # Replace the siglip model with our custom one + model.PaliGemma.img = img + + print("šŸ”§ Created single-layer SigLIP model (depth=1) for debugging...") + else: + rng = jax.random.key(42) + model = config.create(rng) + + # Load pre-trained weights + weight_path = "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" + print(f"Loading JAX weights from: {weight_path}") + + # try: + # Use the same approach as in policy_config.py + params = _model.restore_params(weight_path, dtype=jnp.bfloat16) + + # Filter params to only include the first encoder layer for debugging + if debug_single_layer: + filtered_params = {} + + # The parameters are nested, so we need to traverse the structure + def filter_nested_params(params_dict, key_path=""): + result = {} + for key, value in params_dict.items(): + current_path = f"{key_path}.{key}" if key_path else key + + if isinstance(value, dict): + # Recursive case - traverse deeper + filtered_sub = filter_nested_params(value, current_path) + if filtered_sub: # Only include if there are sub-parameters + result[key] = filtered_sub + else: + # Leaf case - check if this parameter should be included + if 'Transformer' in current_path: + # Only keep the first encoder block (encoderblock_0) and encoder_norm + if 'encoderblock_0' in current_path or 'encoder_norm' in current_path: + result[key] = value + else: + # Keep all non-Transformer params + result[key] = value + return result + + filtered_params = filter_nested_params(params) + params_to_use = filtered_params + print("āœ… JAX weights loaded successfully (first 
layer only)!") + print("āš ļø Note: Using So400m variant with depth=1 (modified from depth=27)") + print("āš ļø Only the first layer weights will be used, others will be randomly initialized") + + # Debug: Show what parameters we have + print(f"šŸ“‹ Available parameters for single-layer model:") + transformer_params = [] + for key in sorted(params_to_use.keys()): + if 'Transformer' in key: + transformer_params.append(key) + print(f" {key}: {params_to_use[key].shape}") + + if not transformer_params: + print(" No Transformer parameters found! Let's see all keys:") + for key in sorted(params_to_use.keys())[:20]: # Show first 20 keys + print(f" {key}") + + # Let's also check if PaliGemma has nested structure + if 'PaliGemma' in params_to_use: + print(" Checking PaliGemma structure:") + paligemma_params = params_to_use['PaliGemma'] + if hasattr(paligemma_params, 'keys'): + for subkey in sorted(paligemma_params.keys()): + print(f" PaliGemma.{subkey}") + if hasattr(paligemma_params[subkey], 'keys'): + for subsubkey in sorted(paligemma_params[subkey].keys()): + print(f" PaliGemma.{subkey}.{subsubkey}") + if hasattr(paligemma_params[subkey][subsubkey], 'keys'): + for subsubsubkey in sorted(paligemma_params[subkey][subsubkey].keys()): + if 'Transformer' in subsubsubkey: + print(f" PaliGemma.{subkey}.{subsubkey}.{subsubsubkey}") + + # The issue is that with scan=False, the model expects different parameter names + # We need to map from encoderblock_0 to encoderblock in the nested structure + def adapt_nested_params(params_dict, key_path=""): + result = {} + for key, value in params_dict.items(): + current_path = f"{key_path}.{key}" if key_path else key + + if isinstance(value, dict): + # Recursive case - traverse deeper + result[key] = adapt_nested_params(value, current_path) + else: + # Leaf case - adapt the key if needed + new_key = key + if 'Transformer' in current_path and 'encoderblock_0' in key: + # Map encoderblock_0 to encoderblock for non-scan mode + new_key = 
key.replace('encoderblock_0', 'encoderblock') + result[new_key] = value + return result + + adapted_params = adapt_nested_params(params_to_use) + params_to_use = adapted_params + print("šŸ”„ Adapted parameter names for non-scan mode") + print(f" Example mapping: encoderblock_0 -> encoderblock") + else: + params_to_use = params + print("āœ… JAX weights loaded successfully!") + + # Apply the params to the model using NNX state management + import flax.nnx as nnx + graphdef, model_state = nnx.split(model) + + # Debug: Let me check what the model actually expects first + print(f"šŸ” Checking what the model expects...") + try: + print(f"šŸ“‹ Model parameter structure:") + model_transformer_params = [] + for key in sorted(model_state.keys()): + if 'Transformer' in key: + model_transformer_params.append(key) + print(f" {key}: shape {getattr(model_state[key], 'shape', 'no shape')}") + + if not model_transformer_params: + print(" No Transformer parameters found in model! Let's see all keys:") + for key in sorted(model_state.keys())[:20]: # Show first 20 keys + print(f" {key}") + + # Let's also check if PaliGemma has nested structure in model + if 'PaliGemma' in model_state: + print(" Checking PaliGemma structure in model:") + paligemma_state = model_state['PaliGemma'] + if hasattr(paligemma_state, 'keys'): + for subkey in sorted(paligemma_state.keys()): + print(f" PaliGemma.{subkey}") + if hasattr(paligemma_state[subkey], 'keys'): + for subsubkey in sorted(paligemma_state[subkey].keys()): + print(f" PaliGemma.{subkey}.{subsubkey}") + if hasattr(paligemma_state[subkey][subsubkey], 'keys'): + for subsubsubkey in sorted(paligemma_state[subkey][subsubkey].keys()): + if 'Transformer' in subsubsubkey: + print(f" PaliGemma.{subkey}.{subsubkey}.{subsubsubkey}") + except Exception as e: + print(f" Could not inspect model parameters: {e}") + + # Now try to load parameters + try: + model_state.replace_by_pure_dict(params_to_use) + model = nnx.merge(graphdef, model_state) + print("āœ… 
Parameters loaded successfully!") + except Exception as e: + print(f"āŒ Parameter loading failed: {e}") + print("šŸ”„ Continuing with random initialization...") + model = nnx.merge(graphdef, model_state) + # except Exception as e: + # print(f"āŒ Failed to load JAX weights: {e}") + # print("Continuing with random initialization...") + + # Create fixed example + example = create_fixed_example() + + # Convert to JAX arrays + jax_example = {} + for key, value in example.items(): + if key == "image": + jax_example[key] = {k: jnp.array(v) for k, v in value.items()} + elif key == "image_mask": + jax_example[key] = {k: jnp.array(v) for k, v in value.items()} + else: + jax_example[key] = jnp.array(value) + + # Convert noise and time to JAX arrays + noise_jax = jnp.array(noise) + time_jax = jnp.array(time) + + # Create observation + observation = _model.Observation.from_dict(jax_example) + actions = jax_example["actions"] + + print(f"Observation state shape: {observation.state.shape}") + print(f"Observation state dtype: {observation.state.dtype}") + print(f"Actions shape: {actions.shape}") + print(f"Actions dtype: {actions.dtype}") + print(f"Noise shape: {noise_jax.shape}, dtype: {noise_jax.dtype}") + print(f"Time shape: {time_jax.shape}, dtype: {time_jax.dtype}") + + # Test forward pass with fixed noise and time + # try: + # Use the modified compute_loss method that accepts external noise and time + losses = model.compute_loss(rng, observation, actions, train=False, noise=noise_jax, time=time_jax) + print(f"JAX forward pass successful!") + print(f"Losses shape: {losses.shape}") + print(f"Losses dtype: {losses.dtype}") + mean_loss = jnp.mean(losses).item() + print(f"Mean loss: {mean_loss:.6f}") + return True, losses + # except Exception as e: + # print(f"JAX forward pass failed: {e}") + # return False, None + + +def compare_losses(pytorch_loss, jax_loss): + """Compare losses and compute relative differences.""" + if pytorch_loss is None or jax_loss is None: + return + + 
print("\n" + "=" * 70) + print("šŸ“Š LOSS COMPARISON") + print("=" * 70) + + # # Handle tensor inputs by computing mean if needed + # if hasattr(pytorch_loss, 'mean'): + # pytorch_mean = pytorch_loss.to(torch.float32).mean().item() + # pytorch_std = pytorch_loss.to(torch.float32).std().item() + # print(f"PyTorch loss tensor - Mean: {pytorch_mean:.8f}, Std: {pytorch_std:.8f}") + # print(f"PyTorch loss shape: {pytorch_loss.shape}") + # else: + # pytorch_mean = float(pytorch_loss) + # pytorch_std = 0.0 + # print(f"PyTorch loss scalar: {pytorch_mean:.8f}") + + # if hasattr(jax_loss, 'mean'): + # jax_mean = jax_loss.mean().item() + # jax_std = jax_loss.std().item() + # print(f"JAX loss tensor - Mean: {jax_mean:.8f}, Std: {jax_std:.8f}") + # print(f"JAX loss shape: {jax_loss.shape}") + # else: + # jax_mean = float(jax_loss) + # jax_std = 0.0 + # print(f"JAX loss scalar: {jax_mean:.8f}") + + + + # Additional tensor analysis if both are tensors + pytorch_loss = pytorch_loss.to(torch.float32) + jax_loss = jax_loss.astype(jnp.float32) + if hasattr(pytorch_loss, 'shape') and hasattr(jax_loss, 'shape'): + print(f"\nšŸ“ Tensor Analysis:") + + # Check if shapes match + if pytorch_loss.shape == jax_loss.shape: + print(f"āœ… Tensor shapes match: {pytorch_loss.shape}") + + # Element-wise comparison + if hasattr(pytorch_loss, 'flatten') and hasattr(jax_loss, 'flatten'): + # Convert to numpy for element-wise analysis + try: + pytorch_flat = pytorch_loss.detach().cpu().numpy().flatten() + jax_flat = jax_loss.flatten() + + # Element-wise differences + element_diff = np.abs(pytorch_flat - jax_flat) + print(f"element_diff[0]: {element_diff[0:2048*816:2048]}") + max_element_diff = np.max(element_diff) + mean_element_diff = np.mean(element_diff) + + print(f" Max element-wise difference: {max_element_diff:.8f}") + print(f" Mean element-wise difference: {mean_element_diff:.8f}") + + # Element-wise relative differences + # Avoid division by zero by adding small epsilon + epsilon = 1e-12 + 
pytorch_flat_safe = pytorch_flat + epsilon + jax_flat_safe = jax_flat + epsilon + + # Compute relative differences for each element + rel_diff_pytorch_elements = (element_diff / np.abs(pytorch_flat_safe)) * 100 + rel_diff_jax_elements = (element_diff / np.abs(jax_flat_safe)) * 100 + + # Compute mean of relative differences + mean_rel_diff_pytorch = np.mean(rel_diff_pytorch_elements) + mean_rel_diff_jax = np.mean(rel_diff_jax_elements) + + print(f" Mean relative difference (w.r.t. PyTorch elements): {mean_rel_diff_pytorch:.4f}%") + print(f" Mean relative difference (w.r.t. JAX elements): {mean_rel_diff_jax:.4f}%") + + # Count elements with significant differences + significant_threshold = 1e-4 + significant_count = np.sum(element_diff > significant_threshold) + total_elements = len(element_diff) + significant_percentage = (significant_count / total_elements) * 100 + + print(f" Elements with diff > {significant_threshold}: {significant_count}/{total_elements} ({significant_percentage:.2f}%)") + + # Additional relative difference analysis + significant_rel_threshold = 1.0 # 1% + significant_rel_count_pytorch = np.sum(rel_diff_pytorch_elements > significant_rel_threshold) + significant_rel_count_jax = np.sum(rel_diff_jax_elements > significant_rel_threshold) + + print(f" Elements with rel diff > {significant_rel_threshold}% (w.r.t. PyTorch): {significant_rel_count_pytorch}/{total_elements} ({(significant_rel_count_pytorch/total_elements)*100:.2f}%)") + print(f" Elements with rel diff > {significant_rel_threshold}% (w.r.t. 
JAX): {significant_rel_count_jax}/{total_elements} ({(significant_rel_count_jax/total_elements)*100:.2f}%)") + + except Exception as e: + print(f" āš ļø Could not perform element-wise analysis: {e}") + else: + print(f"āŒ Tensor shapes don't match: PyTorch {pytorch_loss.shape} vs JAX {jax_loss.shape}") + + +def main(): + """Main function to test both implementations.""" + setup_logging() + + print("šŸš€ Testing Single Example Training for JAX vs PyTorch Comparison") + print("=" * 70) + print("šŸ“ Loading pre-trained weights for both models...") + print("šŸŽÆ Using fixed noise and time values for deterministic comparison...") + print("šŸ”§ Debug mode: JAX model will use only 1 encoder layer for faster debugging...") + + # Generate fixed noise and time + noise, time = create_fixed_noise_and_time( + batch_size=1, + action_horizon=10, + action_dim=32 + ) + + # Test PyTorch + pytorch_success, pytorch_losses = test_pytorch_single_example(noise, time) + torch.cuda.empty_cache() + + # Test JAX + jax_success, jax_losses = test_jax_single_example(noise, time, debug_single_layer=False) + + # Compare losses + if pytorch_success and jax_success: + compare_losses(pytorch_losses, jax_losses) + + # Summary + print("\n" + "=" * 70) + print("šŸ“Š SUMMARY") + print("=" * 70) + + if pytorch_success and jax_success: + print("āœ… Both JAX and PyTorch implementations work on the single example!") + print("šŸ” Loss comparison completed above.") + elif pytorch_success: + print("āŒ PyTorch works but JAX failed. Check JAX implementation.") + elif jax_success: + print("āŒ JAX works but PyTorch failed. Check PyTorch implementation.") + else: + print("āŒ Both implementations failed. Check the error messages above.") + + print("\nšŸ’” Next steps:") + print("1. Run this script to verify both implementations work") + print("2. Analyze the loss comparison results above") + print("3. If losses differ significantly, investigate the differences") + print("4. 
Check if the noise and time handling is consistent between implementations") + print("5. Use the same example in full training runs") + + +if __name__ == "__main__": + main() diff --git a/src/openpi/models/pi0.py b/src/openpi/models/pi0.py index ae7c459..c84f5c1 100644 --- a/src/openpi/models/pi0.py +++ b/src/openpi/models/pi0.py @@ -187,14 +187,24 @@ def embed_suffix( @override def compute_loss( - self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False + self, + rng: at.KeyArrayLike, + observation: _model.Observation, + actions: _model.Actions, + *, + train: bool = False, + noise: at.Float[at.Array, "*b ah ad"] | None = None, + time: at.Float[at.Array, "*b"] | None = None ) -> at.Float[at.Array, "*b ah"]: preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3) - observation = _model.preprocess_observation(preprocess_rng, observation, train=train) + #observation = _model.preprocess_observation(preprocess_rng, observation, train=train) batch_shape = actions.shape[:-2] - noise = jax.random.normal(noise_rng, actions.shape) - time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001 + # Use provided noise and time if available, otherwise generate them + if noise is None: + noise = jax.random.normal(noise_rng, actions.shape) + if time is None: + time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001 time_expanded = time[..., None, None] x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions @@ -211,7 +221,7 @@ def compute_loss( ) v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :]) - return jnp.mean(jnp.square(v_t - u_t), axis=-1) + return jnp.square(v_t - u_t) @override def sample_actions( diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 6c6de37..9272a34 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -2,10 +2,38 @@ import torch 
from torch import nn from transformers import GemmaForCausalLM, PaliGemmaForConditionalGeneration +from transformers.models.gemma import modeling_gemma from transformers.models.auto import CONFIG_MAPPING +# TODO: compare this rope vs gemma rope +def apply_rope(x, positions, max_wavelength=10_000): + """ + Applies RoPE positions [B, L] to x [B, L, H, D]. + """ + d_half = x.shape[-1] // 2 + device = x.device + dtype = x.dtype + x = x.to(torch.float32) + + freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) + timescale = max_wavelength**freq_exponents + radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) + + radians = radians[..., None, :] + + sin = torch.sin(radians) # .to(dtype=dtype) + cos = torch.cos(radians) # .to(dtype=dtype) + + x1, x2 = x.split(d_half, dim=-1) + res = torch.empty_like(x) + res[..., :d_half] = x1 * cos - x2 * sin + res[..., d_half:] = x2 * cos + x1 * sin + + return res.to(dtype) + + class PaliGemmaWithExpertModel(nn.Module): def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): super().__init__() @@ -80,7 +108,7 @@ def forward( use_cache: bool | None = None, adarms_cond: list[torch.Tensor] | None = None, ): - if inputs_embeds[0] is not None: + if inputs_embeds[1] is None: prefix_output = self.paligemma.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, @@ -91,11 +119,8 @@ def forward( ) prefix_past_key_values = prefix_output.past_key_values prefix_output = prefix_output.last_hidden_state - else: - prefix_output = None - prefix_past_key_values = None - - if inputs_embeds[1] is not None: + suffix_output = None + if inputs_embeds[0] is None: suffix_output = self.gemma_expert.model.forward( inputs_embeds=inputs_embeds[1], attention_mask=attention_mask, @@ -105,7 +130,92 @@ def forward( adarms_cond=adarms_cond[1] if adarms_cond is not None else None, ) suffix_output = suffix_output.last_hidden_state + 
prefix_output = None + prefix_past_key_values = None else: - suffix_output = None + models = [self.paligemma.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers + for layer_idx in range(num_layers): + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) + # hidden_states = hidden_states.to(dtype=torch.bfloat16) + # if gate is not None: + # gate = gate.to(dtype=torch.bfloat16) + gates.append(gate) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # B,L,H,D with L sequence length, H number of heads, D head dim + # concatenate on the number of embeddings/tokens + query_states = torch.cat(query_states, dim=1) + key_states = torch.cat(key_states, dim=1) + value_states = torch.cat(value_states, dim=1) + + query_states = apply_rope(query_states, position_ids) + key_states = apply_rope(key_states, position_ids) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # dummy_tensor = torch.zeros(query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype) + # cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + # query_states, key_states = modeling_gemma.apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) + + batch_size = query_states.shape[0] + scaling = 
self.paligemma.language_model.layers[layer_idx].self_attn.scaling + att_output, _ = modeling_gemma.eager_attention_forward( + self.paligemma.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, attention_mask, scaling + ) + #att_output = att_output.to(dtype=torch.bfloat16) + att_output = att_output.reshape(batch_size, -1, 1 * 8 * layer.self_attn.head_dim) + + + # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) + outputs_embeds = [] + start = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end = start + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start:end]) + + # first residual + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) + after_first_residual = out_emb.clone() + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + out_emb = out_emb.to(dtype=torch.bfloat16) + + out_emb = layer.mlp(out_emb) + # second residual + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) + outputs_embeds.append(out_emb) + start = end + inputs_embeds = outputs_embeds + + # final norm + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None return [prefix_output, suffix_output], prefix_past_key_values \ No newline at end of file diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 9a6c0ea..89bc0dd 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -230,8 +230,15 @@ def embed_suffix(self, state, noisy_actions, timestep): return 
embs, pad_masks, att_masks, adarms_cond - def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None) -> Tensor: + def forward(self, observation, actions, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + # observation = _model.preprocess_observation_pytorch(observation, train=True) + images = list(observation.images.values()) + img_masks = list(observation.image_masks.values()) + lang_tokens = observation.tokenized_prompt + lang_masks = observation.tokenized_prompt_mask + state = observation.state + if noise is None: noise = self.sample_noise(actions.shape, actions.device) @@ -244,6 +251,7 @@ def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, no prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) @@ -253,6 +261,7 @@ def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, no # Add head dimension to attention mask: [B, seq_len, seq_len] -> [B, 1, seq_len, seq_len] att_2d_masks_4d = att_2d_masks[:, None, :, :] + att_2d_masks_4d = torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) (_, suffix_out), _ = self.paligemma_with_expert.forward( attention_mask=att_2d_masks_4d, @@ -260,7 +269,6 @@ def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, no past_key_values=None, inputs_embeds=[prefix_embs, suffix_embs], use_cache=False, - fill_kv_cache=False, adarms_cond=[None, adarms_cond] ) suffix_out = suffix_out[:, -self.config.action_horizon :] @@ -268,12 +276,15 @@ def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, no 
v_t = self.action_out_proj(suffix_out) - losses = F.mse_loss(u_t, v_t, reduction="none") + #losses = F.mse_loss(u_t, v_t, reduction="none") + + losses = torch.square(v_t - u_t) return losses @torch.no_grad() def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor: """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + # observation = _model.preprocess_observation(observation, train=False) bsize = observation.state.shape[0] if noise is None: actions_shape = (bsize, self.config.action_horizon, self.config.action_dim) diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index a4caa3c..723d212 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -724,7 +724,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=256, + batch_size=1, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, @@ -734,10 +734,30 @@ def __post_init__(self) -> None: optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, weight_loader=weight_loaders.CheckpointWeightLoader( - "gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/params" + "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" ), num_train_steps=30_000, ), + TrainConfig( + name="pi05_libero_pytorch", + model=pi0_config.Pi0Config(pi05=True, action_horizon=10, discrete_state_input=False), + data=LeRobotLiberoDataConfig( + repo_id="physical-intelligence/libero", + base_config=DataConfig(prompt_from_task=True), + extra_delta_transform=False, + ), + batch_size=1, + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=10_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), + ema_decay=0.999, + weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2", + 
num_train_steps=30_000, + ), # # Fine-tuning Aloha configs. # diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index 1f97f27..8355336 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -228,6 +228,7 @@ def create_data_loader( ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: """Create a data loader for training.""" data_config = config.data.create(config.assets_dirs, config.model) + print(f"data_config: {data_config}") if data_config.rlds_data_dir is not None: return create_rlds_data_loader( diff --git a/uv.lock b/uv.lock index 1ad8e94..13bd8fc 100644 --- a/uv.lock +++ b/uv.lock @@ -1771,24 +1771,33 @@ wheels = [ [package.optional-dependencies] with-cuda = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cublas-cu12", version = "12.9.0.13", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = 
"nvidia-cuda-cupti-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-cupti-cu12", version = "12.9.19", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "nvidia-cuda-nvcc-cu12" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-runtime-cu12", version = "12.9.37", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, { name = "nvidia-cudnn-cu12", version = "9.10.1.4", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and 
sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cudnn-cu12", version = "9.10.2.21", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cufft-cu12", version = "11.3.3.83", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cufft-cu12", version = "11.4.0.6", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cusolver-cu12", version = "11.7.3.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cusolver-cu12", version = "11.7.4.40", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 
'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", version = "12.5.9.5", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-nccl-cu12", version = "2.26.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nccl-cu12", version = "2.26.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, { name = "nvidia-nccl-cu12", version = "2.26.5", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nccl-cu12", version = "2.27.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", version = "12.9.41", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 
'darwin'" }, ] @@ -2608,22 +2617,31 @@ name = "nvidia-cublas-cu12" version = "12.6.4.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb", size = 393138322 }, - { url = "https://files.pythonhosted.org/packages/97/0d/f1f0cadbf69d5b9ef2e4f744c9466cb0a850741d08350736dfdb4aa89569/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668", size = 390794615 }, { url = "https://files.pythonhosted.org/packages/84/f7/985e9bdbe3e0ac9298fcc8cfa51a392862a46a0ffaccbbd56939b62a9c83/nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8", size = 434535301 }, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = 
"https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124 }, + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921 }, +] + [[package]] name = "nvidia-cublas-cu12" version = "12.9.0.13" @@ -2646,24 +2664,31 @@ name = "nvidia-cuda-cupti-cu12" version = "12.6.80" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 
'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/8b/2f6230cb715646c3a9425636e513227ce5c93c4d65823a734f4bb86d43c3/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:166ee35a3ff1587f2490364f90eeeb8da06cd867bd5b701bf7f9a02b78bc63fc", size = 8236764 }, - { url = "https://files.pythonhosted.org/packages/25/0f/acb326ac8fd26e13c799e0b4f3b2751543e1834f04d62e729485872198d4/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.whl", hash = "sha256:358b4a1d35370353d52e12f0a7d1769fc01ff74a191689d3870b2123156184c4", size = 8236756 }, - { url = "https://files.pythonhosted.org/packages/49/60/7b6497946d74bcf1de852a21824d63baad12cd417db4195fc1bfe59db953/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6768bad6cab4f19e8292125e5f1ac8aa7d1718704012a0e3272a6f61c4bce132", size = 8917980 }, - { url = "https://files.pythonhosted.org/packages/a5/24/120ee57b218d9952c379d1e026c4479c9ece9997a4fb46303611ee48f038/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73", size = 8917972 }, { url = "https://files.pythonhosted.org/packages/1c/81/7796f096afaf726796b1b648f3bc80cafc61fe7f77f44a483c89e6c5ef34/nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a", size = 5724175 }, ] +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318 }, + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621 }, +] + [[package]] name = "nvidia-cuda-cupti-cu12" version = "12.9.19" @@ -2693,10 +2718,10 @@ wheels = [ [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.6.77" +version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/75/2e/46030320b5a80661e88039f59060d1790298b4718944a65a7f2aeda3d9e9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53", size = 23650380 }, + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029 }, ] [[package]] @@ -2704,24 +2729,31 @@ name = "nvidia-cuda-runtime-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 
'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/ea/590b2ac00d772a8abd1c387a92b46486d2679ca6622fd25c18ff76265663/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6116fad3e049e04791c0256a9778c16237837c08b27ed8c8401e2e45de8d60cd", size = 908052 }, - { url = "https://files.pythonhosted.org/packages/b7/3d/159023799677126e20c8fd580cca09eeb28d5c5a624adc7f793b9aa8bbfa/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d461264ecb429c84c8879a7153499ddc7b19b5f8d84c204307491989a365588e", size = 908040 }, - { url = "https://files.pythonhosted.org/packages/e1/23/e717c5ac26d26cf39a27fbc076240fad2e3b817e5889d671b67f4f9f49c5/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ba3b56a4f896141e25e19ab287cd71e52a6a0f4b29d0d31609f60e3b4d5219b7", size = 897690 }, - { url = "https://files.pythonhosted.org/packages/f0/62/65c05e161eeddbafeca24dc461f47de550d9fa8a7e04eb213e32b55cfd99/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8", size = 897678 }, { url = "https://files.pythonhosted.org/packages/fa/76/4c80fa138333cc975743fd0687a745fccb30d167f906f13c1c7f9a85e5ea/nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f", size = 891773 }, ] +[[package]] +name = 
"nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265 }, + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765 }, +] + [[package]] name = "nvidia-cuda-runtime-cu12" version = "12.9.37" @@ -2744,22 +2776,17 @@ name = "nvidia-cudnn-cu12" version = "9.5.1.17" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' 
and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/99/93/a201a12d3ec1caa8c6ac34c1c2f9eeb696b886f0c36ff23c638b46603bd0/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9fd4584468533c61873e5fda8ca41bac3a38bcb2d12350830c69b0a96a7e4def", size = 570523509 }, - { url = "https://files.pythonhosted.org/packages/2a/78/4535c9c7f859a64781e43c969a3a7e84c54634e319a996d43ef32ce46f83/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2", size = 570988386 }, { url = "https://files.pythonhosted.org/packages/b6/b2/3f60d15f037fa5419d9d7f788b100ef33ea913ae5315c87ca6d6fa606c35/nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8", size = 565440743 }, ] @@ -2783,32 +2810,59 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/ec/79464a7371a028d1f443b8516b55cb2f70bb91bd3b2f2a831d707c003ccf/nvidia_cudnn_cu12-9.10.1.4-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:df73c4dab84df2c54f0a40e6427cde26e8d80feeffef02d749ee42d7da3c8204", size = 706752133 }, ] +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == 
'3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878 }, + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467 }, +] + [[package]] name = "nvidia-cufft-cu12" version = "11.3.0.4" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = 
"nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/1f/37/c50d2b2f2c07e146776389e3080f4faf70bcc4fa6e19d65bb54ca174ebc3/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d16079550df460376455cba121db6564089176d9bac9e4f360493ca4741b22a6", size = 200164144 }, - { url = "https://files.pythonhosted.org/packages/ce/f5/188566814b7339e893f8d210d3a5332352b1409815908dad6a363dcceac1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8510990de9f96c803a051822618d42bf6cb8f069ff3f48d93a8486efdacb48fb", size = 200164135 }, - { url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632 }, - { url = "https://files.pythonhosted.org/packages/60/de/99ec247a07ea40c969d904fc14f3a356b3e2a704121675b75c366b694ee1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca", size = 200221622 }, { url = "https://files.pythonhosted.org/packages/b4/38/36fd800cec8f6e89b7c1576edaaf8076e69ec631644cdbc1b5f2e2b5a9df/nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464", size = 199356881 }, ] +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } 
+resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211 }, + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695 }, +] + [[package]] name = "nvidia-cufft-cu12" version = "11.4.0.6" @@ -2831,19 +2885,18 @@ wheels = [ [[package]] name = "nvidia-cufile-cu12" -version = "1.11.1.6" +version = "1.13.1.3" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/66/cc9876340ac68ae71b15c743ddb13f8b30d5244af344ec8322b449e35426/nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159", size = 1142103 }, + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834 }, 
] [[package]] name = "nvidia-curand-cu12" -version = "10.3.7.77" +version = "10.3.9.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/73/1b/44a01c4e70933637c93e6e1a8063d1e998b50213a6b65ac5a9169c47e98e/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf", size = 56279010 }, - { url = "https://files.pythonhosted.org/packages/4a/aa/2c7ff0b5ee02eaef890c0ce7d4f74bc30901871c5e45dee1ae6d0083cd80/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:99f1a32f1ac2bd134897fc7a203f779303261268a65762a623bf30cc9fe79117", size = 56279000 }, + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976 }, ] [[package]] @@ -2851,29 +2904,41 @@ name = "nvidia-cusolver-cu12" version = "11.7.1.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", 
"python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/93/17/dbe1aa865e4fdc7b6d4d0dd308fdd5aaab60f939abfc0ea1954eac4fb113/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0ce237ef60acde1efc457335a2ddadfd7610b892d94efee7b776c64bb1cac9e0", size = 157833628 }, - { url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790 }, - { url = 
"https://files.pythonhosted.org/packages/9f/81/baba53585da791d043c10084cf9553e074548408e04ae884cfe9193bd484/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6cf28f17f64107a0c4d7802be5ff5537b2130bfc112f25d5a30df227058ca0e6", size = 158229780 }, - { url = "https://files.pythonhosted.org/packages/7c/5f/07d0ba3b7f19be5a5ec32a8679fc9384cfd9fc6c869825e93be9f28d6690/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e", size = 157833630 }, { url = "https://files.pythonhosted.org/packages/d4/53/fff50a0808df7113d77e3bbc7c2b7eaed6f57d5eb80fbe93ead2aea1e09a/nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7", size = 149287877 }, ] +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841 }, + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905 }, +] + [[package]] name = "nvidia-cusolver-cu12" version = "11.7.4.40" @@ -2901,27 +2966,37 @@ name = "nvidia-cusparse-cu12" version = "12.5.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = 
"https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/eb/6681efd0aa7df96b4f8067b3ce7246833dd36830bb4cec8896182773db7d/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d25b62fb18751758fe3c93a4a08eff08effedfe4edf1c6bb5afd0890fe88f887", size = 216451147 }, - { url = "https://files.pythonhosted.org/packages/d3/56/3af21e43014eb40134dea004e8d0f1ef19d9596a39e4d497d5a7de01669f/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7aa32fa5470cf754f72d1116c7cbc300b4e638d3ae5304cfa4a638a5b87161b1", size = 216451135 }, - { url = "https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367 }, - { url = "https://files.pythonhosted.org/packages/43/ac/64c4316ba163e8217a99680c7605f779accffc6a4bcd0c778c12948d3707/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f", size = 216561357 }, { url = "https://files.pythonhosted.org/packages/45/ef/876ad8e4260e1128e6d4aac803d9d51baf3791ebdb4a9b8d9b8db032b4b0/nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20", size = 213712630 }, ] +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 
'linux'", +] +dependencies = [ + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129 }, + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466 }, +] + [[package]] name = "nvidia-cusparse-cu12" version = "12.5.9.5" @@ -2944,10 +3019,10 @@ wheels = [ [[package]] name = "nvidia-cusparselt-cu12" -version = "0.6.3" +version = "0.7.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796 }, + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691 }, ] [[package]] @@ -2964,20 +3039,13 @@ name = "nvidia-nccl-cu12" version = "2.26.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 
'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] -wheels = [ - { url = "https://files.pythonhosted.org/packages/69/5b/ca2f213f637305633814ae8c36b153220e40a07ea001966dcd87391f3acb/nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c196e95e832ad30fbbb50381eb3cbd1fadd5675e587a548563993609af19522", size = 291671495 }, - { url = "https://files.pythonhosted.org/packages/67/ca/f42388aed0fddd64ade7493dbba36e1f534d4e6fdbdd355c6a90030ae028/nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6", size = 201319755 }, -] [[package]] name = "nvidia-nccl-cu12" @@ -2996,27 +3064,50 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/fb/ec4ac065d9b0d56f72eaf1d9b0df601e33da28197b32ca351dc05b342611/nvidia_nccl_cu12-2.26.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea5ed3e053c735f16809bee7111deac62ac35b10128a8c102960a0462ce16cbe", size = 318069637 }, ] +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform 
== 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/7b/8354b784cf73b0ba51e566b4baba3ddd44fe8288a3d39ef1e06cd5417226/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9ddf1a245abc36c550870f26d537a9b6087fb2e2e3d6e0ef03374c6fd19d984f", size = 322397768 }, + { url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134 }, +] + [[package]] name = "nvidia-nvjitlink-cu12" version = "12.6.85" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = 
"sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971 }, - { url = "https://files.pythonhosted.org/packages/31/db/dc71113d441f208cdfe7ae10d4983884e13f464a6252450693365e166dcf/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41", size = 19270338 }, { url = "https://files.pythonhosted.org/packages/89/76/93c1467b1387387440a4d25102d86b7794535449b689f8e2dc22c1c8ff7f/nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c", size = 161908572 }, ] +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836 }, + { url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204 }, +] + [[package]] name = "nvidia-nvjitlink-cu12" version = "12.9.41" @@ -3036,11 +3127,10 @@ wheels = [ [[package]] name = "nvidia-nvtx-cu12" -version = "12.6.77" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b90bed3df379fa79afbd21be8e04a0314336b8ae16768b58f2d34cb1d04cd7d2", size = 89276 }, - { url = "https://files.pythonhosted.org/packages/9e/4e/0d0c945463719429b7bd21dece907ad0bde437a2ff12b9b12fee94722ab0/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1", size = 89265 }, + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954 }, ] [[package]] @@ -3180,7 +3270,7 @@ requires-dist = [ { name = "polars", specifier = ">=1.30.0" }, { name = "rich", specifier = ">=14.0.0" }, { name = "sentencepiece", specifier = ">=0.2.0" }, - { name = "torch", specifier = ">=2.7.0" }, + { name = "torch", specifier = ">=2.7.1" }, { name = "tqdm-loggable", specifier = ">=0.2" }, { name = "transformers", specifier = "==4.53.2" }, { name = "treescope", specifier = ">=0.1.7" }, @@ -4852,26 +4942,26 @@ wheels = [ [[package]] name = "torch" -version = "2.7.0" +version = "2.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = 
"https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", version = "9.10.2.21", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", version = "11.3.3.83", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and 
sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", version = "11.7.3.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", version = "2.26.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", version = "2.27.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, @@ -4879,22 +4969,22 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/40/da/7378d16cc636697f2a94f791cb496939b60fb8580ddbbef22367db2c2274/torch-2.7.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2b7813e904757b125faf1a9a3154e1d50381d539ced34da1992f52440567c156", size = 99159397 }, - { url = "https://files.pythonhosted.org/packages/0e/6b/87fcddd34df9f53880fa1f0c23af7b6b96c935856473faf3914323588c40/torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:fd5cfbb4c3bbadd57ad1b27d56a28008f8d8753733411a140fcfb84d7f933a25", size 
= 865183681 }, - { url = "https://files.pythonhosted.org/packages/13/85/6c1092d4b06c3db1ed23d4106488750917156af0b24ab0a2d9951830b0e9/torch-2.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:58df8d5c2eeb81305760282b5069ea4442791a6bbf0c74d9069b7b3304ff8a37", size = 212520100 }, - { url = "https://files.pythonhosted.org/packages/aa/3f/85b56f7e2abcfa558c5fbf7b11eb02d78a4a63e6aeee2bbae3bb552abea5/torch-2.7.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:0a8d43caa342b9986101ec5feb5bbf1d86570b5caa01e9cb426378311258fdde", size = 68569377 }, - { url = "https://files.pythonhosted.org/packages/aa/5e/ac759f4c0ab7c01feffa777bd68b43d2ac61560a9770eeac074b450f81d4/torch-2.7.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:36a6368c7ace41ad1c0f69f18056020b6a5ca47bedaca9a2f3b578f5a104c26c", size = 99013250 }, - { url = "https://files.pythonhosted.org/packages/9c/58/2d245b6f1ef61cf11dfc4aceeaacbb40fea706ccebac3f863890c720ab73/torch-2.7.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:15aab3e31c16feb12ae0a88dba3434a458874636f360c567caa6a91f6bfba481", size = 865042157 }, - { url = "https://files.pythonhosted.org/packages/44/80/b353c024e6b624cd9ce1d66dcb9d24e0294680f95b369f19280e241a0159/torch-2.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f56d4b2510934e072bab3ab8987e00e60e1262fb238176168f5e0c43a1320c6d", size = 212482262 }, - { url = "https://files.pythonhosted.org/packages/ee/8d/b2939e5254be932db1a34b2bd099070c509e8887e0c5a90c498a917e4032/torch-2.7.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:30b7688a87239a7de83f269333651d8e582afffce6f591fff08c046f7787296e", size = 68574294 }, - { url = "https://files.pythonhosted.org/packages/14/24/720ea9a66c29151b315ea6ba6f404650834af57a26b2a04af23ec246b2d5/torch-2.7.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:868ccdc11798535b5727509480cd1d86d74220cfdc42842c4617338c1109a205", size = 99015553 }, - { url = 
"https://files.pythonhosted.org/packages/4b/27/285a8cf12bd7cd71f9f211a968516b07dcffed3ef0be585c6e823675ab91/torch-2.7.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b52347118116cf3dff2ab5a3c3dd97c719eb924ac658ca2a7335652076df708", size = 865046389 }, - { url = "https://files.pythonhosted.org/packages/74/c8/2ab2b6eadc45554af8768ae99668c5a8a8552e2012c7238ded7e9e4395e1/torch-2.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:434cf3b378340efc87c758f250e884f34460624c0523fe5c9b518d205c91dd1b", size = 212490304 }, - { url = "https://files.pythonhosted.org/packages/28/fd/74ba6fde80e2b9eef4237fe668ffae302c76f0e4221759949a632ca13afa/torch-2.7.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:edad98dddd82220465b106506bb91ee5ce32bd075cddbcf2b443dfaa2cbd83bf", size = 68856166 }, - { url = "https://files.pythonhosted.org/packages/cb/b4/8df3f9fe6bdf59e56a0e538592c308d18638eb5f5dc4b08d02abb173c9f0/torch-2.7.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a885fc25afefb6e6eb18a7d1e8bfa01cc153e92271d980a49243b250d5ab6d9", size = 99091348 }, - { url = "https://files.pythonhosted.org/packages/9d/f5/0bd30e9da04c3036614aa1b935a9f7e505a9e4f1f731b15e165faf8a4c74/torch-2.7.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:176300ff5bc11a5f5b0784e40bde9e10a35c4ae9609beed96b4aeb46a27f5fae", size = 865104023 }, - { url = "https://files.pythonhosted.org/packages/d1/b7/2235d0c3012c596df1c8d39a3f4afc1ee1b6e318d469eda4c8bb68566448/torch-2.7.0-cp313-cp313t-win_amd64.whl", hash = "sha256:d0ca446a93f474985d81dc866fcc8dccefb9460a29a456f79d99c29a78a66993", size = 212750916 }, - { url = "https://files.pythonhosted.org/packages/90/48/7e6477cf40d48cc0a61fa0d41ee9582b9a316b12772fcac17bc1a40178e7/torch-2.7.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:27f5007bdf45f7bb7af7f11d1828d5c2487e030690afb3d89a651fd7036a390e", size = 68575074 }, + { url = 
"https://files.pythonhosted.org/packages/8f/c4/3e7a3887eba14e815e614db70b3b529112d1513d9dae6f4d43e373360b7f/torch-2.8.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:220a06fd7af8b653c35d359dfe1aaf32f65aa85befa342629f716acb134b9710", size = 102073391 }, + { url = "https://files.pythonhosted.org/packages/5a/63/4fdc45a0304536e75a5e1b1bbfb1b56dd0e2743c48ee83ca729f7ce44162/torch-2.8.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c12fa219f51a933d5f80eeb3a7a5d0cbe9168c0a14bbb4055f1979431660879b", size = 888063640 }, + { url = "https://files.pythonhosted.org/packages/84/57/2f64161769610cf6b1c5ed782bd8a780e18a3c9d48931319f2887fa9d0b1/torch-2.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:8c7ef765e27551b2fbfc0f41bcf270e1292d9bf79f8e0724848b1682be6e80aa", size = 241366752 }, + { url = "https://files.pythonhosted.org/packages/a4/5e/05a5c46085d9b97e928f3f037081d3d2b87fb4b4195030fc099aaec5effc/torch-2.8.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:5ae0524688fb6707c57a530c2325e13bb0090b745ba7b4a2cd6a3ce262572916", size = 73621174 }, + { url = "https://files.pythonhosted.org/packages/49/0c/2fd4df0d83a495bb5e54dca4474c4ec5f9c62db185421563deeb5dabf609/torch-2.8.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e2fab4153768d433f8ed9279c8133a114a034a61e77a3a104dcdf54388838705", size = 101906089 }, + { url = "https://files.pythonhosted.org/packages/99/a8/6acf48d48838fb8fe480597d98a0668c2beb02ee4755cc136de92a0a956f/torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2aca0939fb7e4d842561febbd4ffda67a8e958ff725c1c27e244e85e982173c", size = 887913624 }, + { url = "https://files.pythonhosted.org/packages/af/8a/5c87f08e3abd825c7dfecef5a0f1d9aa5df5dd0e3fd1fa2f490a8e512402/torch-2.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:2f4ac52f0130275d7517b03a33d2493bab3693c83dcfadf4f81688ea82147d2e", size = 241326087 }, + { url = 
"https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478 }, + { url = "https://files.pythonhosted.org/packages/10/4e/469ced5a0603245d6a19a556e9053300033f9c5baccf43a3d25ba73e189e/torch-2.8.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b2f96814e0345f5a5aed9bf9734efa913678ed19caf6dc2cddb7930672d6128", size = 101936856 }, + { url = "https://files.pythonhosted.org/packages/16/82/3948e54c01b2109238357c6f86242e6ecbf0c63a1af46906772902f82057/torch-2.8.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:65616ca8ec6f43245e1f5f296603e33923f4c30f93d65e103d9e50c25b35150b", size = 887922844 }, + { url = "https://files.pythonhosted.org/packages/e3/54/941ea0a860f2717d86a811adf0c2cd01b3983bdd460d0803053c4e0b8649/torch-2.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:659df54119ae03e83a800addc125856effda88b016dfc54d9f65215c3975be16", size = 241330968 }, + { url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128 }, + { url = "https://files.pythonhosted.org/packages/15/0e/8a800e093b7f7430dbaefa80075aee9158ec22e4c4fc3c1a66e4fb96cb4f/torch-2.8.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:83c13411a26fac3d101fe8035a6b0476ae606deb8688e904e796a3534c197def", size = 102020139 }, + { url = "https://files.pythonhosted.org/packages/4a/15/5e488ca0bc6162c86a33b58642bc577c84ded17c7b72d97e49b5833e2d73/torch-2.8.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:8f0a9d617a66509ded240add3754e462430a6c1fc5589f86c17b433dd808f97a", size = 887990692 }, + { url = 
"https://files.pythonhosted.org/packages/b4/a8/6a04e4b54472fc5dba7ca2341ab219e529f3c07b6941059fbf18dccac31f/torch-2.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a7242b86f42be98ac674b88a4988643b9bc6145437ec8f048fea23f72feb5eca", size = 241603453 }, + { url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395 }, ] [[package]] @@ -4914,7 +5004,7 @@ wheels = [ [[package]] name = "torchvision" -version = "0.22.0" +version = "0.23.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, @@ -4922,22 +5012,22 @@ dependencies = [ { name = "torch" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/43/28bc858b022f6337326d75f4027d2073aad5432328f01ee1236d847f1b82/torchvision-0.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:191ea28321fc262d8aa1a7fe79c41ff2848864bf382f9f6ea45c41dde8313792", size = 1947828 }, - { url = "https://files.pythonhosted.org/packages/7e/71/ce9a303b94e64fe25d534593522ffc76848c4e64c11e4cbe9f6b8d537210/torchvision-0.22.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6c5620e10ffe388eb6f4744962106ed7cf1508d26e6fdfa0c10522d3249aea24", size = 2514016 }, - { url = "https://files.pythonhosted.org/packages/09/42/6908bff012a1dcc4fc515e52339652d7f488e208986542765c02ea775c2f/torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ce292701c77c64dd3935e3e31c722c3b8b176a75f76dc09b804342efc1db5494", size = 7447546 }, - { url = "https://files.pythonhosted.org/packages/e4/cf/8f9305cc0ea26badbbb3558ecae54c04a245429f03168f7fad502f8a5b25/torchvision-0.22.0-cp311-cp311-win_amd64.whl", hash = "sha256:e4017b5685dbab4250df58084f07d95e677b2f3ed6c2e507a1afb8eb23b580ca", size = 1716472 }, - { url = 
"https://files.pythonhosted.org/packages/cb/ea/887d1d61cf4431a46280972de665f350af1898ce5006cd046326e5d0a2f2/torchvision-0.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:31c3165418fe21c3d81fe3459e51077c2f948801b8933ed18169f54652796a0f", size = 1947826 }, - { url = "https://files.pythonhosted.org/packages/72/ef/21f8b6122e13ae045b8e49658029c695fd774cd21083b3fa5c3f9c5d3e35/torchvision-0.22.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8f116bc82e0c076e70ba7776e611ed392b9666aa443662e687808b08993d26af", size = 2514571 }, - { url = "https://files.pythonhosted.org/packages/7c/48/5f7617f6c60d135f86277c53f9d5682dfa4e66f4697f505f1530e8b69fb1/torchvision-0.22.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ce4dc334ebd508de2c534817c9388e928bc2500cf981906ae8d6e2ca3bf4727a", size = 7446522 }, - { url = "https://files.pythonhosted.org/packages/99/94/a015e93955f5d3a68689cc7c385a3cfcd2d62b84655d18b61f32fb04eb67/torchvision-0.22.0-cp312-cp312-win_amd64.whl", hash = "sha256:24b8c9255c209ca419cc7174906da2791c8b557b75c23496663ec7d73b55bebf", size = 1716664 }, - { url = "https://files.pythonhosted.org/packages/e1/2a/9b34685599dcb341d12fc2730055155623db7a619d2415a8d31f17050952/torchvision-0.22.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ece17995857dd328485c9c027c0b20ffc52db232e30c84ff6c95ab77201112c5", size = 1947823 }, - { url = "https://files.pythonhosted.org/packages/77/77/88f64879483d66daf84f1d1c4d5c31ebb08e640411139042a258d5f7dbfe/torchvision-0.22.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:471c6dd75bb984c6ebe4f60322894a290bf3d4b195e769d80754f3689cd7f238", size = 2471592 }, - { url = "https://files.pythonhosted.org/packages/f7/82/2f813eaae7c1fae1f9d9e7829578f5a91f39ef48d6c1c588a8900533dd3d/torchvision-0.22.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:2b839ac0610a38f56bef115ee5b9eaca5f9c2da3c3569a68cc62dbcc179c157f", size = 7446333 }, - { url = 
"https://files.pythonhosted.org/packages/58/19/ca7a4f8907a56351dfe6ae0a708f4e6b3569b5c61d282e3e7f61cf42a4ce/torchvision-0.22.0-cp313-cp313-win_amd64.whl", hash = "sha256:4ada1c08b2f761443cd65b7c7b4aec9e2fc28f75b0d4e1b1ebc9d3953ebccc4d", size = 1716693 }, - { url = "https://files.pythonhosted.org/packages/6f/a7/f43e9c8d13118b4ffbaebea664c9338ab20fa115a908125afd2238ff16e7/torchvision-0.22.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:cdc96daa4658b47ce9384154c86ed1e70cba9d972a19f5de6e33f8f94a626790", size = 2137621 }, - { url = "https://files.pythonhosted.org/packages/6a/9a/2b59f5758ba7e3f23bc84e16947493bbce97392ec6d18efba7bdf0a3b10e/torchvision-0.22.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:753d3c84eeadd5979a33b3b73a25ecd0aa4af44d6b45ed2c70d44f5e0ac68312", size = 2476555 }, - { url = "https://files.pythonhosted.org/packages/7d/40/a7bc2ab9b1e56d10a7fd9ae83191bb425fa308caa23d148f1c568006e02c/torchvision-0.22.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:b30e3ed29e4a61f7499bca50f57d8ebd23dfc52b14608efa17a534a55ee59a03", size = 7617924 }, - { url = "https://files.pythonhosted.org/packages/c1/7b/30d423bdb2546250d719d7821aaf9058cc093d165565b245b159c788a9dd/torchvision-0.22.0-cp313-cp313t-win_amd64.whl", hash = "sha256:e5d680162694fac4c8a374954e261ddfb4eb0ce103287b0f693e4e9c579ef957", size = 1638621 }, + { url = "https://files.pythonhosted.org/packages/f0/d7/15d3d7bd8d0239211b21673d1bac7bc345a4ad904a8e25bb3fd8a9cf1fbc/torchvision-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:49aa20e21f0c2bd458c71d7b449776cbd5f16693dd5807195a820612b8a229b7", size = 1856884 }, + { url = "https://files.pythonhosted.org/packages/dd/14/7b44fe766b7d11e064c539d92a172fa9689a53b69029e24f2f1f51e7dc56/torchvision-0.23.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:01dc33ee24c79148aee7cdbcf34ae8a3c9da1674a591e781577b716d233b1fa6", size = 2395543 }, + { url = 
"https://files.pythonhosted.org/packages/79/9c/fcb09aff941c8147d9e6aa6c8f67412a05622b0c750bcf796be4c85a58d4/torchvision-0.23.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:35c27941831b653f5101edfe62c03d196c13f32139310519e8228f35eae0e96a", size = 8628388 }, + { url = "https://files.pythonhosted.org/packages/93/40/3415d890eb357b25a8e0a215d32365a88ecc75a283f75c4e919024b22d97/torchvision-0.23.0-cp311-cp311-win_amd64.whl", hash = "sha256:09bfde260e7963a15b80c9e442faa9f021c7e7f877ac0a36ca6561b367185013", size = 1600741 }, + { url = "https://files.pythonhosted.org/packages/df/1d/0ea0b34bde92a86d42620f29baa6dcbb5c2fc85990316df5cb8f7abb8ea2/torchvision-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e0e2c04a91403e8dd3af9756c6a024a1d9c0ed9c0d592a8314ded8f4fe30d440", size = 1856885 }, + { url = "https://files.pythonhosted.org/packages/e2/00/2f6454decc0cd67158c7890364e446aad4b91797087a57a78e72e1a8f8bc/torchvision-0.23.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6dd7c4d329a0e03157803031bc856220c6155ef08c26d4f5bbac938acecf0948", size = 2396614 }, + { url = "https://files.pythonhosted.org/packages/e4/b5/3e580dcbc16f39a324f3dd71b90edbf02a42548ad44d2b4893cc92b1194b/torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4e7d31c43bc7cbecbb1a5652ac0106b436aa66e26437585fc2c4b2cf04d6014c", size = 8627108 }, + { url = "https://files.pythonhosted.org/packages/82/c1/c2fe6d61e110a8d0de2f94276899a2324a8f1e6aee559eb6b4629ab27466/torchvision-0.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:a2e45272abe7b8bf0d06c405e78521b5757be1bd0ed7e5cd78120f7fdd4cbf35", size = 1600723 }, + { url = "https://files.pythonhosted.org/packages/91/37/45a5b9407a7900f71d61b2b2f62db4b7c632debca397f205fdcacb502780/torchvision-0.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1c37e325e09a184b730c3ef51424f383ec5745378dc0eca244520aca29722600", size = 1856886 }, + { url = 
"https://files.pythonhosted.org/packages/ac/da/a06c60fc84fc849377cf035d3b3e9a1c896d52dbad493b963c0f1cdd74d0/torchvision-0.23.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2f7fd6c15f3697e80627b77934f77705f3bc0e98278b989b2655de01f6903e1d", size = 2353112 }, + { url = "https://files.pythonhosted.org/packages/a0/27/5ce65ba5c9d3b7d2ccdd79892ab86a2f87ac2ca6638f04bb0280321f1a9c/torchvision-0.23.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:a76fafe113b2977be3a21bf78f115438c1f88631d7a87203acb3dd6ae55889e6", size = 8627658 }, + { url = "https://files.pythonhosted.org/packages/1f/e4/028a27b60aa578a2fa99d9d7334ff1871bb17008693ea055a2fdee96da0d/torchvision-0.23.0-cp313-cp313-win_amd64.whl", hash = "sha256:07d069cb29691ff566e3b7f11f20d91044f079e1dbdc9d72e0655899a9b06938", size = 1600749 }, + { url = "https://files.pythonhosted.org/packages/05/35/72f91ad9ac7c19a849dedf083d347dc1123f0adeb401f53974f84f1d04c8/torchvision-0.23.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:2df618e1143805a7673aaf82cb5720dd9112d4e771983156aaf2ffff692eebf9", size = 2047192 }, + { url = "https://files.pythonhosted.org/packages/1d/9d/406cea60a9eb9882145bcd62a184ee61e823e8e1d550cdc3c3ea866a9445/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a3299d2b1d5a7aed2d3b6ffb69c672ca8830671967eb1cee1497bacd82fe47b", size = 2359295 }, + { url = "https://files.pythonhosted.org/packages/2b/f4/34662f71a70fa1e59de99772142f22257ca750de05ccb400b8d2e3809c1d/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:76bc4c0b63d5114aa81281390f8472a12a6a35ce9906e67ea6044e5af4cab60c", size = 8800474 }, + { url = "https://files.pythonhosted.org/packages/6e/f5/b5a2d841a8d228b5dbda6d524704408e19e7ca6b7bb0f24490e081da1fa1/torchvision-0.23.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b9e2dabf0da9c8aa9ea241afb63a8f3e98489e706b22ac3f30416a1be377153b", size = 1527667 }, ] [[package]] @@ -5039,16 +5129,16 @@ wheels = [ [[package]] name = "triton" -version = 
"3.3.0" +version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "setuptools", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/c5/4874a81131cc9e934d88377fbc9d24319ae1fb540f3333b4e9c696ebc607/triton-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3161a2bf073d6b22c4e2f33f951f3e5e3001462b2570e6df9cd57565bdec2984", size = 156528461 }, - { url = "https://files.pythonhosted.org/packages/11/53/ce18470914ab6cfbec9384ee565d23c4d1c55f0548160b1c7b33000b11fd/triton-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b68c778f6c4218403a6bd01be7484f6dc9e20fe2083d22dd8aef33e3b87a10a3", size = 156504509 }, - { url = "https://files.pythonhosted.org/packages/7d/74/4bf2702b65e93accaa20397b74da46fb7a0356452c1bb94dbabaf0582930/triton-3.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47bc87ad66fa4ef17968299acacecaab71ce40a238890acc6ad197c3abe2b8f1", size = 156516468 }, - { url = "https://files.pythonhosted.org/packages/0a/93/f28a696fa750b9b608baa236f8225dd3290e5aff27433b06143adc025961/triton-3.3.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce4700fc14032af1e049005ae94ba908e71cd6c2df682239aed08e49bc71b742", size = 156580729 }, + { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138 }, + { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068 }, + { url = 
"https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223 }, + { url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780 }, ] [[package]] From 468e5f2ee345f38f0ab1b9fc6debe41d8d8a843a Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 09:37:48 -0700 Subject: [PATCH 02/32] add preprocess --- scripts/train_pytorch.py | 152 +++++++++++++++++++---- src/openpi/models/model.py | 147 ++++++++++++++++++++++ src/openpi/models/pi0.py | 2 +- src/openpi/models_pytorch/pi0_pytorch.py | 3 +- src/openpi/shared/image_tools.py | 79 ++++++++++++ src/openpi/training/config.py | 4 +- 6 files changed, 361 insertions(+), 26 deletions(-) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 1a46fe1..78bd389 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -144,6 +144,11 @@ def setup_ddp(): if use_ddp and not dist.is_initialized(): backend = "nccl" if torch.cuda.is_available() else "gloo" dist.init_process_group(backend=backend, init_method="env://") + + # Set up debugging environment variables for DDP issues + if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None: + os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO" + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): @@ -214,6 +219,16 @@ def batch_to_torch(batch: Dict[str, Any], device: torch.device) -> Tuple[list[to return batch +def get_model_state_dict(model): + """Get state dict from model, handling DDP 
wrapper.""" + return model.module.state_dict() if isinstance(model, DDP) else model.state_dict() + + +def get_model_parameters(model): + """Get parameters from model, handling DDP wrapper.""" + return model.module.parameters() if isinstance(model, DDP) else model.parameters() + + def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_interval=None, ema_model=None): """Save a checkpoint with model state, optimizer state, EMA state, and metadata.""" if not is_main: @@ -228,7 +243,7 @@ def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_in os.makedirs(ckpt_dir, exist_ok=True) # Save model state - state_dict = (model.module if isinstance(model, DDP) else model).state_dict() + state_dict = get_model_state_dict(model) torch.save(state_dict, os.path.join(ckpt_dir, "pytorch_model.pt")) # Save optimizer state @@ -259,30 +274,30 @@ def load_checkpoint(model, optimizer, config, device, ema_model=None): for d in config.checkpoint_dir.iterdir(): if d.is_dir() and d.name.isdigit(): checkpoint_steps.append(int(d.name)) - + if not checkpoint_steps: raise FileNotFoundError(f"No checkpoints found in {config.checkpoint_dir}") - + latest_step = max(checkpoint_steps) ckpt_dir = os.path.join(config.checkpoint_dir, f"{latest_step}") - + # Load model state model_state_dict = torch.load(os.path.join(ckpt_dir, "pytorch_model.pt"), map_location=device) (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) - + # Load optimizer state optimizer_state_dict = torch.load(os.path.join(ckpt_dir, "optimizer.pt"), map_location=device) optimizer.load_state_dict(optimizer_state_dict) - + # Load EMA state if available if ema_model is not None and os.path.exists(os.path.join(ckpt_dir, "ema_model.pt")): ema_state_dict = torch.load(os.path.join(ckpt_dir, "ema_model.pt"), map_location=device) ema_model.load_state_dict(ema_state_dict) logging.info(f"Loaded EMA state from checkpoint") - + # Load metadata metadata = 
torch.load(os.path.join(ckpt_dir, "metadata.pt"), map_location=device) - + logging.info(f"Loaded checkpoint from step {latest_step} -> {ckpt_dir}") return metadata["global_step"] @@ -297,6 +312,57 @@ def get_latest_checkpoint_step(config): return max(checkpoint_steps) if checkpoint_steps else None +def debug_unused_parameters(model, device): + """Debug function to identify unused parameters in the model.""" + if isinstance(model, DDP): + model = model.module + + logging.info("Checking for potentially unused parameters...") + + # Get all parameter names and their indices + param_info = {} + idx = 0 + for name, param in model.named_parameters(): + if param.requires_grad: + param_info[idx] = name + idx += 1 + + logging.info(f"Total trainable parameters: {len(param_info)}") + + # Check which parameters have gradients after a forward pass + # This is a diagnostic function that can be called if needed + return param_info + + +def check_model_parameters(model, device): + """Check for unused parameters and provide debugging information.""" + if isinstance(model, DDP): + model = model.module + + total_params = 0 + used_params = 0 + + for name, param in model.named_parameters(): + total_params += param.numel() + if param.requires_grad: + used_params += param.numel() + + logging.info(f"Model parameters: {total_params:,} total, {used_params:,} trainable") + + # Check for parameters that might be unused + unused_params = [] + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is None: + unused_params.append(name) + + if unused_params: + logging.warning(f"Found {len(unused_params)} parameters that might be unused:") + for name in unused_params[:10]: # Show first 10 + logging.warning(f" - {name}") + if len(unused_params) > 10: + logging.warning(f" ... 
and {len(unused_params) - 10} more") + + def setup_memory_optimizations(model, device, enable_gradient_checkpointing=False): """Setup memory optimization techniques for the model.""" if enable_gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'): @@ -410,12 +476,17 @@ def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, grad model_cfg = config.model model = PI0Pytorch(model_cfg).to(device) - + # Apply memory optimizations setup_memory_optimizations(model, device, enable_gradient_checkpointing) - + + # Check model parameters for debugging + if is_main: + check_model_parameters(model, device) + if use_ddp: - model = DDP(model, device_ids=[device.index] if device.type == "cuda" else None, find_unused_parameters=False) + # Enable unused parameter detection to handle cases where some parameters don't participate in loss + model = DDP(model, device_ids=[device.index] if device.type == "cuda" else None, find_unused_parameters=True) # Load weights from weight_loader if specified (for fine-tuning) if isinstance(config.weight_loader, str): @@ -445,10 +516,20 @@ def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, grad # Initialize EMA if specified in config ema_model = None if config.ema_decay is not None: - ema_model = PI0Pytorch(model_cfg).to(device) - ema_model.load_state_dict(model.state_dict()) - ema_model.eval() - logging.info(f"Initialized EMA with decay {config.ema_decay}") + try: + ema_model = PI0Pytorch(model_cfg).to(device) + + # Get the correct state dict from the main model + main_model_state_dict = get_model_state_dict(model) + + # Load the state dict into EMA model + ema_model.load_state_dict(main_model_state_dict) + ema_model.eval() + logging.info(f"Initialized EMA with decay {config.ema_decay}") + except Exception as e: + logging.error(f"Failed to initialize EMA model: {e}") + logging.error("Continuing without EMA...") + ema_model = None # Load checkpoint if resuming global_step = 0 @@ 
-501,9 +582,24 @@ def lr_schedule(step: int): # Forward pass with mixed precision observation = _model.Observation.from_dict(batch) - with torch.amp.autocast('cuda', enabled=mixed_precision and torch.cuda.is_available()): - losses = model(observation, actions) - loss = losses.mean() / gradient_accumulation_steps # Scale loss for gradient accumulation + try: + with torch.amp.autocast('cuda', enabled=mixed_precision and torch.cuda.is_available()): + losses = model(observation, actions) + # Ensure losses is a tensor and handle different return types + if isinstance(losses, (list, tuple)): + losses = torch.stack(losses) + elif not isinstance(losses, torch.Tensor): + losses = torch.tensor(losses, device=device, dtype=torch.float32) + + loss = losses.mean() / gradient_accumulation_steps # Scale loss for gradient accumulation + except RuntimeError as e: + if "Expected to have finished reduction" in str(e) or "did not receive grad" in str(e): + logging.error(f"DDP error on rank {dist.get_rank() if use_ddp else 0}: {e}") + logging.error("This usually indicates unused parameters in the model.") + logging.error("Try setting TORCH_DISTRIBUTED_DEBUG=DETAIL for more information.") + raise + else: + raise # Backward pass with gradient scaling scaler.scale(loss).backward() @@ -521,9 +617,15 @@ def lr_schedule(step: int): # Update EMA if enabled if ema_model is not None: - with torch.no_grad(): - for param, ema_param in zip(model.parameters(), ema_model.parameters()): - ema_param.data.mul_(config.ema_decay).add_(param.data, alpha=1 - config.ema_decay) + try: + with torch.no_grad(): + # Get parameters from the correct model structure + main_model_params = get_model_parameters(model) + for param, ema_param in zip(main_model_params, ema_model.parameters()): + ema_param.data.mul_(config.ema_decay).add_(param.data, alpha=1 - config.ema_decay) + except Exception as e: + logging.warning(f"Failed to update EMA model: {e}") + # Continue training without EMA update # Collect stats (only on 
accumulation steps) if (global_step + 1) % gradient_accumulation_steps == 0 and is_main: @@ -602,10 +704,16 @@ def main(): help="Maximum GPU memory usage in GB (default: None, auto-detect)") parser.add_argument("--enable_gradient_checkpointing", action="store_true", default=False, help="Enable gradient checkpointing for memory optimization") + parser.add_argument("--ddp_debug_level", type=str, default="INFO", choices=["INFO", "DETAIL", "OFF"], + help="DDP debugging level (default: INFO)") args, _ = parser.parse_known_args() - + # Handle mixed precision flag mixed_precision = args.mixed_precision and not args.no_mixed_precision + + # Set DDP debug level + if args.ddp_debug_level != "OFF": + os.environ["TORCH_DISTRIBUTED_DEBUG"] = args.ddp_debug_level train_loop(config, ckpt_save_interval=args.ckpt_save_interval, diff --git a/src/openpi/models/model.py b/src/openpi/models/model.py index d03c745..81028ae 100644 --- a/src/openpi/models/model.py +++ b/src/openpi/models/model.py @@ -207,6 +207,153 @@ def preprocess_observation( ) +def preprocess_observation_pytorch( + observation: Observation, + *, + train: bool = False, + image_keys: Sequence[str] = IMAGE_KEYS, + image_resolution: tuple[int, int] = IMAGE_RESOLUTION, +) -> Observation: + """PyTorch version of preprocess_observation. Preprocesses observations with PyTorch tensors by performing + image resizing (if necessary) and filling in a default image mask (if necessary). + + Note: Image augmentation is not implemented for PyTorch tensors as augmax is JAX-specific. 
+ """ + + if not set(image_keys).issubset(observation.images): + raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") + + batch_shape = observation.state.shape[:-1] + + out_images = {} + for key in image_keys: + image = observation.images[key] + + # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats + # Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + image = image.permute(0, 2, 3, 1) + + if image.shape[1:3] != image_resolution: + logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") + image = image_tools.resize_with_pad_torch(image, *image_resolution) + + if train: + # Convert from [-1, 1] to [0, 1] for PyTorch augmentations + image = image / 2.0 + 0.5 + + # Apply PyTorch-based augmentations + if "wrist" not in key: + # Geometric augmentations for non-wrist cameras + height, width = image.shape[1:3] + + # Random crop and resize + crop_height = int(height * 0.95) + crop_width = int(width * 0.95) + + # Random crop + max_h = height - crop_height + max_w = width - crop_width + if max_h > 0 and max_w > 0: + start_h = torch.randint(0, max_h + 1, (1,)).item() + start_w = torch.randint(0, max_w + 1, (1,)).item() + image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :] + + # Resize back to original size + image = torch.nn.functional.interpolate( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + size=(height, width), + mode='bilinear', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Random rotation (small angles) + angle = torch.rand(1).item() * 10 - 5 # Random angle between -5 and 5 degrees + if abs(angle) > 0.1: # Only rotate if angle is significant + # Convert to radians + angle_rad = torch.tensor(angle * torch.pi / 180.0, 
device=image.device) + + # Create rotation matrix + cos_a = torch.cos(angle_rad) + sin_a = torch.sin(angle_rad) + + # Apply rotation using grid_sample + grid_x = torch.linspace(-1, 1, width, device=image.device) + grid_y = torch.linspace(-1, 1, height, device=image.device) + + # Create meshgrid + grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij') + + # Expand to batch dimension + grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) + grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) + + # Apply rotation transformation + grid_x_rot = grid_x * cos_a - grid_y * sin_a + grid_y_rot = grid_x * sin_a + grid_y * cos_a + + # Stack and reshape for grid_sample + grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) + + image = torch.nn.functional.grid_sample( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + grid, + mode='bilinear', + padding_mode='zeros', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Color augmentations for all cameras + # Random brightness + brightness_factor = 0.7 + torch.rand(1).item() * 0.6 # Random factor between 0.7 and 1.3 + image = image * brightness_factor + + # Random contrast + contrast_factor = 0.6 + torch.rand(1).item() * 0.8 # Random factor between 0.6 and 1.4 + mean = image.mean(dim=[1, 2, 3], keepdim=True) + image = (image - mean) * contrast_factor + mean + + # Random saturation (convert to HSV, modify S, convert back) + # For simplicity, we'll just apply a random scaling to the color channels + saturation_factor = 0.5 + torch.rand(1).item() * 1.0 # Random factor between 0.5 and 1.5 + gray = image.mean(dim=-1, keepdim=True) + image = gray + (image - gray) * saturation_factor + + # Clamp values to [0, 1] + image = torch.clamp(image, 0, 1) + + # Back to [-1, 1] + image = image * 2.0 - 1.0 + + # Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + 
out_images[key] = image + + # obtain mask + out_masks = {} + for key in out_images: + if key not in observation.image_masks: + # do not mask by default + out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) + else: + out_masks[key] = observation.image_masks[key] + + return Observation( + images=out_images, + image_masks=out_masks, + state=observation.state, + tokenized_prompt=observation.tokenized_prompt, + tokenized_prompt_mask=observation.tokenized_prompt_mask, + token_ar_mask=observation.token_ar_mask, + token_loss_mask=observation.token_loss_mask, + ) + + @dataclasses.dataclass(frozen=True) class BaseModelConfig(abc.ABC): """Configuration shared by all models. Specific models should inherit from this class, and implement the `create` diff --git a/src/openpi/models/pi0.py b/src/openpi/models/pi0.py index c84f5c1..7160983 100644 --- a/src/openpi/models/pi0.py +++ b/src/openpi/models/pi0.py @@ -197,7 +197,7 @@ def compute_loss( time: at.Float[at.Array, "*b"] | None = None ) -> at.Float[at.Array, "*b ah"]: preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3) - #observation = _model.preprocess_observation(preprocess_rng, observation, train=train) + observation = _model.preprocess_observation(preprocess_rng, observation, train=train) batch_shape = actions.shape[:-2] # Use provided noise and time if available, otherwise generate them diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 89bc0dd..6c08f69 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -7,6 +7,7 @@ import openpi.models.gemma as _gemma from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel +import openpi.models.model as _model def get_safe_dtype(target_dtype, device_type): @@ -232,7 +233,7 @@ def embed_suffix(self, state, noisy_actions, timestep): def forward(self, observation, actions, noise=None, time=None) -> Tensor: """Do a full 
training forward pass and compute the loss (batch_size x num_steps x num_motors)""" - # observation = _model.preprocess_observation_pytorch(observation, train=True) + observation = _model.preprocess_observation_pytorch(observation, train=True) images = list(observation.images.values()) img_masks = list(observation.image_masks.values()) lang_tokens = observation.tokenized_prompt diff --git a/src/openpi/shared/image_tools.py b/src/openpi/shared/image_tools.py index 4d63e1c..78f0e17 100644 --- a/src/openpi/shared/image_tools.py +++ b/src/openpi/shared/image_tools.py @@ -2,6 +2,8 @@ import jax import jax.numpy as jnp +import torch +import torch.nn.functional as F import openpi.shared.array_typing as at @@ -48,3 +50,80 @@ def resize_with_pad( if not has_batch_dim: padded_images = padded_images[0] return padded_images + + +def resize_with_pad_torch( + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. + + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) 
+ + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + # Convert to channels-first for torch operations + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode='constant', + value=constant_value + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + if batch_size == 1 and images.shape[0] == 1: + padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added + + return padded_images 
\ No newline at end of file diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 723d212..6235e32 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -724,7 +724,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=1, + batch_size=64, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, @@ -746,7 +746,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=1, + batch_size=64, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, From eb6ebb1ad5612227aee9359a1f51dad08a131bd0 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 12:28:55 -0700 Subject: [PATCH 03/32] fixes --- scripts/train_single_example.py | 48 +++++++++++++++------- src/openpi/models_pytorch/gemma_pytorch.py | 32 +++++++-------- src/openpi/models_pytorch/pi0_pytorch.py | 5 +-- 3 files changed, 52 insertions(+), 33 deletions(-) diff --git a/scripts/train_single_example.py b/scripts/train_single_example.py index 2a23d48..9418342 100644 --- a/scripts/train_single_example.py +++ b/scripts/train_single_example.py @@ -11,6 +11,7 @@ import jax.numpy as jnp import flax.nnx as nnx import flax +from unittest.mock import patch from openpi.models import model as _model from openpi.models.pi0_config import Pi0Config @@ -83,6 +84,16 @@ def create_fixed_noise_and_time(batch_size, action_horizon, action_dim): return noise, time +def mock_preprocess_observation(rng, observation, **kwargs): + """Mock function that returns observation unchanged to disable preprocessing.""" + return observation + + +def mock_preprocess_observation_pytorch(observation, **kwargs): + """Mock function that returns observation unchanged to disable preprocessing.""" + return observation + + def test_pytorch_single_example(noise, time): """Test PyTorch training on single 
example.""" print("\n=== Testing PyTorch on Single Example ===") @@ -136,13 +147,15 @@ def test_pytorch_single_example(noise, time): model.eval() with torch.no_grad(): #try: - losses = model(observation, actions, noise=noise_tensor, time=time_tensor) - print(f"PyTorch forward pass successful!") - print(f"Losses shape: {losses.shape}") - print(f"Losses dtype: {losses.dtype}") - # mean_loss = losses.to(torch.float32).mean().item() - # print(f"Mean loss: {mean_loss:.6f}") - return True, losses + # Use mock to disable preprocessing + with patch('openpi.models.model.preprocess_observation_pytorch', side_effect=mock_preprocess_observation_pytorch): + losses = model(observation, actions, noise=noise_tensor, time=time_tensor) + print(f"PyTorch forward pass successful!") + print(f"Losses shape: {losses.shape}") + print(f"Losses dtype: {losses.dtype}") + mean_loss = losses.to(torch.float32).mean().item() + print(f"Mean loss: {mean_loss:.6f}") + return True, losses # except Exception as e: # print(f"PyTorch forward pass failed: {e}") # return False, None @@ -384,13 +397,15 @@ def adapt_nested_params(params_dict, key_path=""): # Test forward pass with fixed noise and time # try: # Use the modified compute_loss method that accepts external noise and time - losses = model.compute_loss(rng, observation, actions, train=False, noise=noise_jax, time=time_jax) - print(f"JAX forward pass successful!") - print(f"Losses shape: {losses.shape}") - print(f"Losses dtype: {losses.dtype}") - mean_loss = jnp.mean(losses).item() - print(f"Mean loss: {mean_loss:.6f}") - return True, losses + # Use mock to disable preprocessing + with patch('openpi.models.model.preprocess_observation', side_effect=mock_preprocess_observation): + losses = model.compute_loss(rng, observation, actions, train=False, noise=noise_jax, time=time_jax) + print(f"JAX forward pass successful!") + print(f"Losses shape: {losses.shape}") + print(f"Losses dtype: {losses.dtype}") + mean_loss = jnp.mean(losses).item() + 
print(f"Mean loss: {mean_loss:.6f}") + return True, losses # except Exception as e: # print(f"JAX forward pass failed: {e}") # return False, None @@ -405,6 +420,9 @@ def compare_losses(pytorch_loss, jax_loss): print("šŸ“Š LOSS COMPARISON") print("=" * 70) + print(f"PyTorch loss: {pytorch_loss}") + print(f"JAX loss: {jax_loss}") + # # Handle tensor inputs by computing mean if needed # if hasattr(pytorch_loss, 'mean'): # pytorch_mean = pytorch_loss.to(torch.float32).mean().item() @@ -502,6 +520,7 @@ def main(): print("šŸ“ Loading pre-trained weights for both models...") print("šŸŽÆ Using fixed noise and time values for deterministic comparison...") print("šŸ”§ Debug mode: JAX model will use only 1 encoder layer for faster debugging...") + print("🚫 Preprocessing disabled: Image augmentations and resizing are bypassed for fair comparison...") # Generate fixed noise and time noise, time = create_fixed_noise_and_time( @@ -542,6 +561,7 @@ def main(): print("3. If losses differ significantly, investigate the differences") print("4. Check if the noise and time handling is consistent between implementations") print("5. Use the same example in full training runs") + print("6. 
Note: Preprocessing (image augmentations) is disabled for this comparison") if __name__ == "__main__": diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 9272a34..5a27cfa 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -106,7 +106,7 @@ def forward( past_key_values: list[torch.FloatTensor] | Cache | None = None, inputs_embeds: list[torch.FloatTensor] = None, use_cache: bool | None = None, - adarms_cond: list[torch.Tensor] | None = None, + adarms_cond: list[torch.Tensor] = [None, None], ): if inputs_embeds[1] is None: prefix_output = self.paligemma.language_model.forward( @@ -120,7 +120,7 @@ def forward( prefix_past_key_values = prefix_output.past_key_values prefix_output = prefix_output.last_hidden_state suffix_output = None - if inputs_embeds[0] is None: + elif inputs_embeds[0] is None: suffix_output = self.gemma_expert.model.forward( inputs_embeds=inputs_embeds[1], attention_mask=attention_mask, @@ -150,9 +150,9 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) - query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) - key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) - value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states.append(query_state) key_states.append(key_state) @@ -160,20 +160,20 @@ def forward( # B,L,H,D with L sequence length, H number of heads, D head dim # concatenate on the number of embeddings/tokens - query_states = torch.cat(query_states, dim=1) - key_states = torch.cat(key_states, dim=1) - value_states = torch.cat(value_states, dim=1) + 
query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) - query_states = apply_rope(query_states, position_ids) - key_states = apply_rope(key_states, position_ids) + # query_states = apply_rope(query_states, position_ids) + # key_states = apply_rope(key_states, position_ids) - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) + # query_states = query_states.transpose(1, 2) + # key_states = key_states.transpose(1, 2) + # value_states = value_states.transpose(1, 2) - # dummy_tensor = torch.zeros(query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype) - # cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) - # query_states, key_states = modeling_gemma.apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) + dummy_tensor = torch.zeros(query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype) + cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) batch_size = query_states.shape[0] scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 6c08f69..13f42b2 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -105,6 +105,7 @@ def __init__(self, config): torch.set_float32_matmul_precision('high') self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") + self.forward = torch.compile(self.forward, mode="reduce-overhead") def sample_noise(self, shape, device): noise = torch.normal( @@ -277,9 
+278,7 @@ def forward(self, observation, actions, noise=None, time=None) -> Tensor: v_t = self.action_out_proj(suffix_out) - #losses = F.mse_loss(u_t, v_t, reduction="none") - - losses = torch.square(v_t - u_t) + losses = F.mse_loss(u_t, v_t, reduction="none") return losses @torch.no_grad() From e0d76fbf909c4f801ccd85065a343b233c442639 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 13:48:07 -0700 Subject: [PATCH 04/32] fix resume --- examples/convert_jax_model_to_pytorch.py | 10 +- scripts/train_pytorch.py | 94 +++++++----- src/openpi/models/model.py | 177 ++++++++++++++++++++++- src/openpi/models_pytorch/pi0_pytorch.py | 9 +- 4 files changed, 239 insertions(+), 51 deletions(-) diff --git a/examples/convert_jax_model_to_pytorch.py b/examples/convert_jax_model_to_pytorch.py index 8db8f1d..34c4d97 100644 --- a/examples/convert_jax_model_to_pytorch.py +++ b/examples/convert_jax_model_to_pytorch.py @@ -8,20 +8,20 @@ Usage: # Just inspect keys: - python convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only # Convert to PyTorch: - python convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output Example: # pi0_droid - python convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch2 + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch2 # pi0_aloha_sim - python convert_jax_model_to_pytorch.py --checkpoint_dir 
/home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch2 + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch2 # pi05_droid - python convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid_pytorch2 + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid_pytorch2 """ import argparse diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 78bd389..cdaeb17 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -21,10 +21,12 @@ python scripts/train_pytorch.py --exp_name --ckpt_save_interval Example: python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test + python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint Multi-GPU (single node): torchrun --standalone --nnodes=1 --nproc_per_node= scripts/train_pytorch.py --exp_name Example: torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test + torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume Multi-Node Training: # On master node (node 0): torchrun --nnodes= --nproc_per_node= --rdzv_id= --rdzv_backend=c10d --rdzv_endpoint=: scripts/train_pytorch.py --exp_name @@ -65,6 +67,9 @@ by the selected TrainConfig (e.g., `LeRobot*` configs for real datasets or `FakeDataConfig`). 
- Supports Weights & Biases (wandb) logging for experiment tracking and visualization. - Checkpoints include model state, optimizer state, and training metadata for complete resume capability. +- Checkpoints are saved in experiment-specific directories: // +- Resume functionality automatically finds the latest checkpoint for the specified experiment name. +- Checkpoint loading handles both PyTorch and JAX/Flax checkpoints for compatibility. - For optimal multi-node performance, ensure high-bandwidth network connectivity (e.g., InfiniBand). - Monitor GPU utilization and network bandwidth during multi-node training. - Memory optimizations can significantly reduce GPU memory usage while maintaining training quality. @@ -239,27 +244,28 @@ def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_in # Only save if it's time to save or if it's the final step if (global_step % save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: - ckpt_dir = os.path.join(config.checkpoint_dir, f"{global_step}") - os.makedirs(ckpt_dir, exist_ok=True) + # Ensure checkpoint_dir is a Path object and create the step-specific directory + ckpt_dir = config.checkpoint_dir / f"{global_step}" + ckpt_dir.mkdir(parents=True, exist_ok=True) # Save model state state_dict = get_model_state_dict(model) - torch.save(state_dict, os.path.join(ckpt_dir, "pytorch_model.pt")) + torch.save(state_dict, ckpt_dir / "pytorch_model.pt") # Save optimizer state - torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt")) + torch.save(optimizer.state_dict(), ckpt_dir / "optimizer.pt") # Save EMA state if available if ema_model is not None: - torch.save(ema_model.state_dict(), os.path.join(ckpt_dir, "ema_model.pt")) + torch.save(ema_model.state_dict(), ckpt_dir / "ema_model.pt") - # Save training metadata + # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues) metadata = { "global_step": global_step, "config": 
dataclasses.asdict(config), "timestamp": time.time(), } - torch.save(metadata, os.path.join(ckpt_dir, "metadata.pt")) + torch.save(metadata, ckpt_dir / "metadata.pt") logging.info(f"Saved checkpoint at step {global_step} -> {ckpt_dir}") @@ -268,44 +274,49 @@ def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_in wandb.log({"checkpoint_step": global_step}, step=global_step) -def load_checkpoint(model, optimizer, config, device, ema_model=None): +def load_checkpoint(model, optimizer, checkpoint_dir, device, ema_model=None): """Load the latest checkpoint and return the global step.""" checkpoint_steps = [] - for d in config.checkpoint_dir.iterdir(): + for d in checkpoint_dir.iterdir(): if d.is_dir() and d.name.isdigit(): checkpoint_steps.append(int(d.name)) if not checkpoint_steps: - raise FileNotFoundError(f"No checkpoints found in {config.checkpoint_dir}") + raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") latest_step = max(checkpoint_steps) - ckpt_dir = os.path.join(config.checkpoint_dir, f"{latest_step}") + ckpt_dir = checkpoint_dir / f"{latest_step}" # Load model state - model_state_dict = torch.load(os.path.join(ckpt_dir, "pytorch_model.pt"), map_location=device) + model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device) (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) # Load optimizer state - optimizer_state_dict = torch.load(os.path.join(ckpt_dir, "optimizer.pt"), map_location=device) + optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device) optimizer.load_state_dict(optimizer_state_dict) # Load EMA state if available - if ema_model is not None and os.path.exists(os.path.join(ckpt_dir, "ema_model.pt")): - ema_state_dict = torch.load(os.path.join(ckpt_dir, "ema_model.pt"), map_location=device) + if ema_model is not None and (ckpt_dir / "ema_model.pt").exists(): + ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", 
map_location=device) ema_model.load_state_dict(ema_state_dict) logging.info(f"Loaded EMA state from checkpoint") - # Load metadata - metadata = torch.load(os.path.join(ckpt_dir, "metadata.pt"), map_location=device) - - logging.info(f"Loaded checkpoint from step {latest_step} -> {ckpt_dir}") - return metadata["global_step"] - - -def get_latest_checkpoint_step(config): - """Get the latest checkpoint step number.""" + # Load metadata (weights_only=False needed for older checkpoints that might contain JAX/Flax objects) + try: + metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) + global_step = metadata.get("global_step", latest_step) + logging.info(f"Loaded checkpoint from step {latest_step} -> {ckpt_dir}") + return global_step + except Exception as e: + logging.warning(f"Failed to load metadata from checkpoint: {e}") + logging.warning("Using checkpoint step number as global step") + return latest_step + + +def get_latest_checkpoint_step(checkpoint_dir): + """Get the latest checkpoint step number from a checkpoint directory.""" checkpoint_steps = [] - for d in config.checkpoint_dir.iterdir(): + for d in checkpoint_dir.iterdir(): if d.is_dir() and d.name.isdigit(): checkpoint_steps.append(int(d.name)) @@ -386,7 +397,7 @@ def setup_memory_optimizations(model, device, enable_gradient_checkpointing=Fals logging.info(f"Cleared CUDA cache for device {device.index}") -def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, enable_gradient_checkpointing: bool = False): +def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, enable_gradient_checkpointing: bool = False): use_ddp, local_rank, device = setup_ddp() is_main = (not use_ddp) or (dist.get_rank() == 0) set_seed(config.seed, 
local_rank) @@ -397,24 +408,32 @@ def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, grad # Initialize checkpoint directory and wandb resuming = False - if config.resume: - # Check if checkpoint directory exists and has checkpoints - if config.checkpoint_dir.exists(): - latest_step = get_latest_checkpoint_step(config) + if resume: + # Find checkpoint directory based on experiment name + exp_checkpoint_dir = config.checkpoint_dir + if exp_checkpoint_dir.exists(): + latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) if latest_step is not None: resuming = True - logging.info(f"Resuming from checkpoint directory: {config.checkpoint_dir} at step {latest_step}") + logging.info(f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}") else: - raise FileNotFoundError(f"No checkpoints found in {config.checkpoint_dir} for resume") + raise FileNotFoundError(f"No checkpoints found in {exp_checkpoint_dir} for resume") else: - raise FileNotFoundError(f"Checkpoint directory {config.checkpoint_dir} does not exist for resume") + raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") elif config.overwrite and config.checkpoint_dir.exists(): import shutil shutil.rmtree(config.checkpoint_dir) logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") - # Create checkpoint directory - config.checkpoint_dir.mkdir(parents=True, exist_ok=True) + # Create checkpoint directory with experiment name + if not resuming: + # For new runs, create experiment-specific checkpoint directory + exp_checkpoint_dir = config.checkpoint_dir + exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) + logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") + else: + # For resume, checkpoint_dir is already set to the experiment directory + logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") # Initialize wandb (only 
on main process) if is_main: @@ -534,7 +553,7 @@ def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, grad # Load checkpoint if resuming global_step = 0 if resuming: - global_step = load_checkpoint(model, optim, config, device, ema_model) + global_step = load_checkpoint(model, optim, config.checkpoint_dir, device, ema_model) logging.info(f"Resumed training from step {global_step}") def lr_schedule(step: int): @@ -692,6 +711,8 @@ def main(): # Parse additional command line arguments for memory optimization import argparse parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--resume", action="store_true", default=False, + help="Resume training from the latest checkpoint for the experiment (handles both PyTorch and JAX checkpoints)") parser.add_argument("--ckpt_save_interval", type=int, default=None, help="Interval for saving checkpoints (overrides config.save_interval)") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, @@ -716,6 +737,7 @@ def main(): os.environ["TORCH_DISTRIBUTED_DEBUG"] = args.ddp_debug_level train_loop(config, + resume=args.resume, ckpt_save_interval=args.ckpt_save_interval, gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=mixed_precision, diff --git a/src/openpi/models/model.py b/src/openpi/models/model.py index 81028ae..ce3c7de 100644 --- a/src/openpi/models/model.py +++ b/src/openpi/models/model.py @@ -207,6 +207,162 @@ def preprocess_observation( ) +def preprocess_observation_pytorch_torch_compile( + observation, + *, + train: bool = False, + image_keys: Sequence[str] = IMAGE_KEYS, + image_resolution: tuple[int, int] = IMAGE_RESOLUTION, +): + """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations. + + This function avoids complex type annotations that can cause torch.compile issues. 
+ """ + if not set(image_keys).issubset(observation.images): + raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") + + batch_shape = observation.state.shape[:-1] + + out_images = {} + for key in image_keys: + image = observation.images[key] + + # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats + # Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + image = image.permute(0, 2, 3, 1) + + if image.shape[1:3] != image_resolution: + logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") + image = image_tools.resize_with_pad_torch(image, *image_resolution) + + if train: + # Convert from [-1, 1] to [0, 1] for PyTorch augmentations + image = image / 2.0 + 0.5 + + # Apply PyTorch-based augmentations + if "wrist" not in key: + # Geometric augmentations for non-wrist cameras + height, width = image.shape[1:3] + + # Random crop and resize + crop_height = int(height * 0.95) + crop_width = int(width * 0.95) + + # Random crop + max_h = height - crop_height + max_w = width - crop_width + if max_h > 0 and max_w > 0: + # Use tensor operations instead of .item() for torch.compile compatibility + start_h = torch.randint(0, max_h + 1, (1,), device=image.device) + start_w = torch.randint(0, max_w + 1, (1,), device=image.device) + image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :] + + # Resize back to original size + image = torch.nn.functional.interpolate( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + size=(height, width), + mode='bilinear', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Random rotation (small angles) + # Use tensor operations instead of .item() for torch.compile compatibility + angle = torch.rand(1, device=image.device) * 10 - 
5 # Random angle between -5 and 5 degrees + if torch.abs(angle) > 0.1: # Only rotate if angle is significant + # Convert to radians + angle_rad = angle * torch.pi / 180.0 + + # Create rotation matrix + cos_a = torch.cos(angle_rad) + sin_a = torch.sin(angle_rad) + + # Apply rotation using grid_sample + grid_x = torch.linspace(-1, 1, width, device=image.device) + grid_y = torch.linspace(-1, 1, height, device=image.device) + + # Create meshgrid + grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij') + + # Expand to batch dimension + grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) + grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) + + # Apply rotation transformation + grid_x_rot = grid_x * cos_a - grid_y * sin_a + grid_y_rot = grid_x * sin_a + grid_y * cos_a + + # Stack and reshape for grid_sample + grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) + + image = torch.nn.functional.grid_sample( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + grid, + mode='bilinear', + padding_mode='zeros', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Color augmentations for all cameras + # Random brightness + # Use tensor operations instead of .item() for torch.compile compatibility + brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3 + image = image * brightness_factor + + # Random contrast + # Use tensor operations instead of .item() for torch.compile compatibility + contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4 + mean = image.mean(dim=[1, 2, 3], keepdim=True) + image = (image - mean) * contrast_factor + mean + + # Random saturation (convert to HSV, modify S, convert back) + # For simplicity, we'll just apply a random scaling to the color channels + # Use tensor operations instead of .item() for torch.compile compatibility + saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # 
Random factor between 0.5 and 1.5 + gray = image.mean(dim=-1, keepdim=True) + image = gray + (image - gray) * saturation_factor + + # Clamp values to [0, 1] + image = torch.clamp(image, 0, 1) + + # Back to [-1, 1] + image = image * 2.0 - 1.0 + + # Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + out_images[key] = image + + # obtain mask + out_masks = {} + for key in out_images: + if key not in observation.image_masks: + # do not mask by default + out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) + else: + out_masks[key] = observation.image_masks[key] + + # Create a simple object with the required attributes instead of using the complex Observation class + class SimpleProcessedObservation: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + return SimpleProcessedObservation( + images=out_images, + image_masks=out_masks, + state=observation.state, + tokenized_prompt=observation.tokenized_prompt, + tokenized_prompt_mask=observation.tokenized_prompt_mask, + token_ar_mask=observation.token_ar_mask, + token_loss_mask=observation.token_loss_mask, + ) + + def preprocess_observation_pytorch( observation: Observation, *, @@ -258,8 +414,9 @@ def preprocess_observation_pytorch( max_h = height - crop_height max_w = width - crop_width if max_h > 0 and max_w > 0: - start_h = torch.randint(0, max_h + 1, (1,)).item() - start_w = torch.randint(0, max_w + 1, (1,)).item() + # Use tensor operations instead of .item() for torch.compile compatibility + start_h = torch.randint(0, max_h + 1, (1,), device=image.device) + start_w = torch.randint(0, max_w + 1, (1,), device=image.device) image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :] # Resize back to original size @@ -271,10 +428,11 @@ def preprocess_observation_pytorch( ).permute(0, 2, 3, 1) # [b, c, h, 
w] -> [b, h, w, c] # Random rotation (small angles) - angle = torch.rand(1).item() * 10 - 5 # Random angle between -5 and 5 degrees - if abs(angle) > 0.1: # Only rotate if angle is significant + # Use tensor operations instead of .item() for torch.compile compatibility + angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees + if torch.abs(angle) > 0.1: # Only rotate if angle is significant # Convert to radians - angle_rad = torch.tensor(angle * torch.pi / 180.0, device=image.device) + angle_rad = angle * torch.pi / 180.0 # Create rotation matrix cos_a = torch.cos(angle_rad) @@ -308,17 +466,20 @@ def preprocess_observation_pytorch( # Color augmentations for all cameras # Random brightness - brightness_factor = 0.7 + torch.rand(1).item() * 0.6 # Random factor between 0.7 and 1.3 + # Use tensor operations instead of .item() for torch.compile compatibility + brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3 image = image * brightness_factor # Random contrast - contrast_factor = 0.6 + torch.rand(1).item() * 0.8 # Random factor between 0.6 and 1.4 + # Use tensor operations instead of .item() for torch.compile compatibility + contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4 mean = image.mean(dim=[1, 2, 3], keepdim=True) image = (image - mean) * contrast_factor + mean # Random saturation (convert to HSV, modify S, convert back) # For simplicity, we'll just apply a random scaling to the color channels - saturation_factor = 0.5 + torch.rand(1).item() * 1.0 # Random factor between 0.5 and 1.5 + # Use tensor operations instead of .item() for torch.compile compatibility + saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # Random factor between 0.5 and 1.5 gray = image.mean(dim=-1, keepdim=True) image = gray + (image - gray) * saturation_factor diff --git a/src/openpi/models_pytorch/pi0_pytorch.py 
b/src/openpi/models_pytorch/pi0_pytorch.py index 13f42b2..73828e7 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -105,7 +105,7 @@ def __init__(self, config): torch.set_float32_matmul_precision('high') self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") - self.forward = torch.compile(self.forward, mode="reduce-overhead") + #self.forward = torch.compile(self.forward, mode="reduce-overhead") def sample_noise(self, shape, device): noise = torch.normal( @@ -234,7 +234,12 @@ def embed_suffix(self, state, noisy_actions, timestep): def forward(self, observation, actions, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" - observation = _model.preprocess_observation_pytorch(observation, train=True) + # # Use torch.compile-compatible preprocessing if we're in a compiled context + # if torch._dynamo.is_compiling(): + # observation = _model.preprocess_observation_pytorch_torch_compile(observation, train=True) + # else: + # observation = _model.preprocess_observation_pytorch(observation, train=True) + observation = _model.preprocess_observation_pytorch_torch_compile(observation, train=True) images = list(observation.images.values()) img_masks = list(observation.image_masks.values()) lang_tokens = observation.tokenized_prompt From 642bb4f5ff14815bd5a474c23f43bc8a70be8de9 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 14:31:07 -0700 Subject: [PATCH 05/32] add gradient checkpointing --- scripts/train_pytorch.py | 8 +++ src/openpi/models_pytorch/pi0_pytorch.py | 66 ++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index cdaeb17..822d8de 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -499,6 +499,14 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte # Apply memory optimizations 
setup_memory_optimizations(model, device, enable_gradient_checkpointing) + # Log gradient checkpointing status if enabled + if enable_gradient_checkpointing and is_main: + if hasattr(model, 'get_gradient_checkpointing_status'): + status = model.get_gradient_checkpointing_status() + logging.info(f"Gradient checkpointing status: {status}") + else: + logging.info("Gradient checkpointing enabled but status check not available") + # Check model parameters for debugging if is_main: check_model_parameters(model, device) diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 73828e7..5f13a66 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -1,4 +1,5 @@ import math +import logging import torch from torch import Tensor @@ -106,6 +107,71 @@ def __init__(self, config): torch.set_float32_matmul_precision('high') self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") #self.forward = torch.compile(self.forward, mode="reduce-overhead") + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + + # Enable gradient checkpointing in the underlying models + if hasattr(self.paligemma_with_expert, 'paligemma'): + if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing in PaliGemma language model") + + if hasattr(self.paligemma_with_expert, 'gemma_expert'): + if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing in Gemma expert model") + + logging.info("Enabled gradient checkpointing for PI0Pytorch model") + + 
def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + + # Disable gradient checkpointing in the underlying models + if hasattr(self.paligemma_with_expert, 'paligemma'): + if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + + if hasattr(self.paligemma_with_expert, 'gemma_expert'): + if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + + logging.info("Disabled gradient checkpointing for PI0Pytorch model") + + def is_gradient_checkpointing_enabled(self): + """Check if gradient checkpointing is enabled.""" + return self.gradient_checkpointing_enabled + + def get_gradient_checkpointing_status(self): + """Get detailed gradient checkpointing status of underlying models.""" + status = { + 'main_model': self.gradient_checkpointing_enabled, + 'paligemma_language_model': False, + 'gemma_expert_model': False + } + + if hasattr(self.paligemma_with_expert, 'paligemma'): + if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): + status['paligemma_language_model'] = getattr( + self.paligemma_with_expert.paligemma.language_model, + 'gradient_checkpointing', + False + ) + + if hasattr(self.paligemma_with_expert, 'gemma_expert'): + if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): + status['gemma_expert_model'] = getattr( + self.paligemma_with_expert.gemma_expert.model, + 'gradient_checkpointing', + False + ) + + return status def sample_noise(self, shape, device): noise = torch.normal( From 3bd6bdec75978b79c054baba2cd4b15578ea0a12 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 14:35:09 -0700 Subject: [PATCH 06/32] fix gradient checkpointing --- scripts/train_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/train_pytorch.py 
b/scripts/train_pytorch.py index 822d8de..c8475a7 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -731,7 +731,7 @@ def main(): help="Disable mixed precision training") parser.add_argument("--max_memory_usage", type=float, default=None, help="Maximum GPU memory usage in GB (default: None, auto-detect)") - parser.add_argument("--enable_gradient_checkpointing", action="store_true", default=False, + parser.add_argument("--gradckpt", action="store_true", default=False, help="Enable gradient checkpointing for memory optimization") parser.add_argument("--ddp_debug_level", type=str, default="INFO", choices=["INFO", "DETAIL", "OFF"], help="DDP debugging level (default: INFO)") @@ -750,7 +750,7 @@ def main(): gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=mixed_precision, max_memory_usage=args.max_memory_usage, - enable_gradient_checkpointing=args.enable_gradient_checkpointing) + enable_gradient_checkpointing=args.gradckpt) if __name__ == "__main__": From 5789f0dc9a370da4ddc24938374e557ff378e239 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Sun, 24 Aug 2025 23:27:21 -0700 Subject: [PATCH 07/32] batch size 16 working --- scripts/train_pytorch.py | 336 +++++++++++++++++++-- src/openpi/models_pytorch/gemma_pytorch.py | 106 +++++-- src/openpi/models_pytorch/pi0_pytorch.py | 226 +++++++++++--- src/openpi/training/config.py | 6 +- 4 files changed, 577 insertions(+), 97 deletions(-) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index c8475a7..438f28e 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -56,11 +56,13 @@ Checkpoint Parameters: - --ckpt_save_interval: Override the checkpoint save interval from config (e.g., --save_interval 500) - --resume: Resume training from the latest checkpoint in the checkpoint directory +- --cleanup_checkpoints: Clean up corrupted checkpoints during resume (keeps last 3 valid ones) - --overwrite: Overwrite existing checkpoint directory (cannot be used with 
--resume) Memory Optimization Parameters: - --gradient_accumulation_steps: Number of steps to accumulate gradients (default: 1) - --mixed_precision: Enable mixed precision training (default: True) - --max_memory_usage: Maximum GPU memory usage in GB (default: None, auto-detect) +- --gradckpt: Enable gradient checkpointing for memory optimization Notes - The global batch size must be divisible by world size (number of processes). - The data pipeline and transforms are identical to the JAX version and are controlled @@ -80,6 +82,7 @@ import os import platform import time +import gc from dataclasses import dataclass from typing import Any, Dict, Tuple @@ -287,25 +290,49 @@ def load_checkpoint(model, optimizer, checkpoint_dir, device, ema_model=None): latest_step = max(checkpoint_steps) ckpt_dir = checkpoint_dir / f"{latest_step}" - # Load model state - model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device) - (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) + # Load model state with error handling + try: + model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device, weights_only=False) + (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) + logging.info(f"Successfully loaded model state from step {latest_step}") + except Exception as e: + logging.error(f"Failed to load model state from step {latest_step}: {e}") + raise RuntimeError(f"Model checkpoint corrupted at step {latest_step}. 
Cannot resume training.") - # Load optimizer state - optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device) - optimizer.load_state_dict(optimizer_state_dict) + # Load optimizer state with error handling and fallback + optimizer_loaded = False + try: + optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=False) + optimizer.load_state_dict(optimizer_state_dict) + optimizer_loaded = True + logging.info(f"Successfully loaded optimizer state from step {latest_step}") + except Exception as e: + logging.warning(f"Failed to load optimizer state from step {latest_step}: {e}") + logging.warning("Optimizer state corrupted. Will continue with fresh optimizer state.") + # Reset optimizer to fresh state + for param_group in optimizer.param_groups: + param_group['lr'] = param_group.get('lr', 1e-4) # Use default LR or current LR + optimizer.zero_grad() + optimizer_loaded = False # Load EMA state if available + ema_loaded = False if ema_model is not None and (ckpt_dir / "ema_model.pt").exists(): - ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", map_location=device) - ema_model.load_state_dict(ema_state_dict) - logging.info(f"Loaded EMA state from checkpoint") + try: + ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", map_location=device, weights_only=False) + ema_model.load_state_dict(ema_state_dict) + ema_loaded = True + logging.info(f"Successfully loaded EMA state from step {latest_step}") + except Exception as e: + logging.warning(f"Failed to load EMA state from step {latest_step}: {e}") + logging.warning("EMA state corrupted. 
Will continue without EMA.") + ema_loaded = False # Load metadata (weights_only=False needed for older checkpoints that might contain JAX/Flax objects) try: metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) global_step = metadata.get("global_step", latest_step) - logging.info(f"Loaded checkpoint from step {latest_step} -> {ckpt_dir}") + logging.info(f"Successfully loaded metadata from step {latest_step}") return global_step except Exception as e: logging.warning(f"Failed to load metadata from checkpoint: {e}") @@ -323,6 +350,110 @@ def get_latest_checkpoint_step(checkpoint_dir): return max(checkpoint_steps) if checkpoint_steps else None +def validate_checkpoint_integrity(checkpoint_dir, step): + """Validate that a checkpoint at the given step is complete and uncorrupted.""" + ckpt_dir = checkpoint_dir / f"{step}" + + required_files = ["pytorch_model.pt", "optimizer.pt", "metadata.pt"] + optional_files = ["ema_model.pt"] + + # Check if all required files exist + for file_name in required_files: + file_path = ckpt_dir / file_name + if not file_path.exists(): + logging.warning(f"Required checkpoint file missing: {file_path}") + return False + + # Try to validate file integrity by attempting to load them + try: + # Test model file + device = torch.device("cpu") # Use CPU for validation to avoid GPU memory issues + model_state = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device, weights_only=False) + if not isinstance(model_state, dict): + logging.warning(f"Model checkpoint file corrupted at step {step}") + return False + + # Test optimizer file + optimizer_state = torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=False) + if not isinstance(optimizer_state, dict): + logging.warning(f"Optimizer checkpoint file corrupted at step {step}") + return False + + # Test metadata file + metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) + if not isinstance(metadata, 
dict) or "global_step" not in metadata: + logging.warning(f"Metadata checkpoint file corrupted at step {step}") + return False + + logging.info(f"Checkpoint at step {step} validated successfully") + return True + + except Exception as e: + logging.warning(f"Checkpoint validation failed at step {step}: {e}") + return False + + +def find_latest_valid_checkpoint(checkpoint_dir): + """Find the latest checkpoint that passes integrity validation.""" + checkpoint_steps = [] + for d in checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + if not checkpoint_steps: + return None + + # Sort steps in descending order to check latest first + checkpoint_steps.sort(reverse=True) + + for step in checkpoint_steps: + if validate_checkpoint_integrity(checkpoint_dir, step): + return step + + logging.error("No valid checkpoints found in directory") + return None + + +def cleanup_corrupted_checkpoints(checkpoint_dir, keep_last_n=3): + """Clean up corrupted checkpoints, keeping only the last N valid ones.""" + checkpoint_steps = [] + for d in checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + if not checkpoint_steps: + return + + # Sort steps in descending order + checkpoint_steps.sort(reverse=True) + + valid_checkpoints = [] + corrupted_checkpoints = [] + + # Validate all checkpoints + for step in checkpoint_steps: + if validate_checkpoint_integrity(checkpoint_dir, step): + valid_checkpoints.append(step) + else: + corrupted_checkpoints.append(step) + + # Keep only the last N valid checkpoints + checkpoints_to_keep = valid_checkpoints[:keep_last_n] + checkpoints_to_remove = valid_checkpoints[keep_last_n:] + corrupted_checkpoints + + # Remove old valid checkpoints and all corrupted ones + for step in checkpoints_to_remove: + checkpoint_path = checkpoint_dir / f"{step}" + try: + import shutil + shutil.rmtree(checkpoint_path) + logging.info(f"Removed checkpoint at step {step}") + 
except Exception as e: + logging.warning(f"Failed to remove checkpoint at step {step}: {e}") + + logging.info(f"Checkpoint cleanup complete. Kept {len(checkpoints_to_keep)} valid checkpoints: {checkpoints_to_keep}") + + def debug_unused_parameters(model, device): """Debug function to identify unused parameters in the model.""" if isinstance(model, DDP): @@ -368,14 +499,41 @@ def check_model_parameters(model, device): if unused_params: logging.warning(f"Found {len(unused_params)} parameters that might be unused:") - for name in unused_params[:10]: # Show first 10 + for name in unused_params: # Show first 10 logging.warning(f" - {name}") - if len(unused_params) > 10: - logging.warning(f" ... and {len(unused_params) - 10} more") + # if len(unused_params) > 10: + # logging.warning(f" ... and {len(unused_params) - 10} more") + + +def log_memory_usage(device, step, phase="unknown"): + """Log detailed memory usage information.""" + if not torch.cuda.is_available(): + return + + memory_allocated = torch.cuda.memory_allocated(device) / 1e9 + memory_reserved = torch.cuda.memory_reserved(device) / 1e9 + memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device) + memory_free = memory_free / 1e9 + + # Get more detailed memory info + memory_stats = torch.cuda.memory_stats(device) + max_memory_allocated = memory_stats.get('allocated_bytes.all.peak', 0) / 1e9 + max_memory_reserved = memory_stats.get('reserved_bytes.all.peak', 0) / 1e9 + + # Get DDP info if available + ddp_info = "" + if dist.is_initialized(): + ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}" + + logging.info(f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}") def setup_memory_optimizations(model, device, enable_gradient_checkpointing=False): """Setup memory optimization 
techniques for the model.""" + # Set memory optimization environment variables + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + os.environ["CUDA_LAUNCH_BLOCKING"] = "0" + if enable_gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'): model.gradient_checkpointing_enable() logging.info("Enabled gradient checkpointing for memory optimization") @@ -388,16 +546,18 @@ def setup_memory_optimizations(model, device, enable_gradient_checkpointing=Fals # Set memory efficient settings if torch.cuda.is_available(): # Enable memory efficient algorithms - torch.backends.cudnn.benchmark = True - torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = False # Disable for memory efficiency + torch.backends.cudnn.deterministic = True # Enable for memory efficiency # Set memory fraction if needed if device.index is not None: torch.cuda.empty_cache() logging.info(f"Cleared CUDA cache for device {device.index}") + + logging.info("Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to reduce memory fragmentation") -def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, enable_gradient_checkpointing: bool = False): +def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, enable_gradient_checkpointing: bool = False, cleanup_checkpoints: bool = False): use_ddp, local_rank, device = setup_ddp() is_main = (not use_ddp) or (dist.get_rank() == 0) set_seed(config.seed, local_rank) @@ -412,12 +572,18 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte # Find checkpoint directory based on experiment name exp_checkpoint_dir = config.checkpoint_dir if exp_checkpoint_dir.exists(): - latest_step = 
get_latest_checkpoint_step(exp_checkpoint_dir) + # Use validation to find the latest working checkpoint + latest_step = find_latest_valid_checkpoint(exp_checkpoint_dir) if latest_step is not None: resuming = True logging.info(f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}") + + # Clean up corrupted checkpoints if requested + if cleanup_checkpoints and is_main: + logging.info("Cleaning up corrupted checkpoints...") + cleanup_corrupted_checkpoints(exp_checkpoint_dir, keep_last_n=3) else: - raise FileNotFoundError(f"No checkpoints found in {exp_checkpoint_dir} for resume") + raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") else: raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") elif config.overwrite and config.checkpoint_dir.exists(): @@ -449,10 +615,8 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte effective_batch_size = config.batch_size // (dist.get_world_size() if use_ddp else 1) # Memory-efficient data loading with reduced pin_memory for large datasets - pin_memory = True - if effective_batch_size > 16: # Reduce pin_memory for large batches - pin_memory = False - logging.info("Disabled pin_memory for large batch size to reduce memory usage") + pin_memory = False # Disable pin_memory to reduce memory usage + logging.info("Disabled pin_memory to reduce memory usage") loader = TorchDataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=pin_memory, drop_last=True, collate_fn=collate_to_numpy) @@ -480,6 +644,8 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte # Reset the loader iterator loader = TorchDataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=pin_memory, drop_last=True, 
collate_fn=collate_to_numpy) + # Test gradient checkpointing with a small forward pass (moved to after model creation) + # Build model if not isinstance(config.model, Pi0Config): # Convert dataclass to Pi0Config if needed @@ -499,17 +665,77 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte # Apply memory optimizations setup_memory_optimizations(model, device, enable_gradient_checkpointing) + # Log initial memory usage after model creation + if is_main and torch.cuda.is_available(): + log_memory_usage(device, 0, "after_model_creation") + # Log gradient checkpointing status if enabled if enable_gradient_checkpointing and is_main: if hasattr(model, 'get_gradient_checkpointing_status'): status = model.get_gradient_checkpointing_status() logging.info(f"Gradient checkpointing status: {status}") + + # Verify that gradient checkpointing is actually enabled + if hasattr(model, 'is_gradient_checkpointing_enabled'): + is_enabled = model.is_gradient_checkpointing_enabled() + logging.info(f"Gradient checkpointing is enabled: {is_enabled}") + + # Check if we're in training mode + logging.info(f"Model training mode: {model.training}") + + # Verify the underlying models have gradient checkpointing enabled + if hasattr(model, 'paligemma_with_expert'): + if hasattr(model.paligemma_with_expert, 'paligemma'): + if hasattr(model.paligemma_with_expert.paligemma, 'language_model'): + paligemma_gc = getattr(model.paligemma_with_expert.paligemma.language_model, 'gradient_checkpointing', False) + logging.info(f"PaliGemma language model gradient checkpointing: {paligemma_gc}") + + if hasattr(model.paligemma_with_expert.paligemma, 'vision_tower'): + vision_gc = getattr(model.paligemma_with_expert.paligemma.vision_tower, 'gradient_checkpointing', False) + logging.info(f"PaliGemma vision tower gradient checkpointing: {vision_gc}") + + if hasattr(model.paligemma_with_expert, 'gemma_expert'): + if hasattr(model.paligemma_with_expert.gemma_expert, 'model'): + 
gemma_gc = getattr(model.paligemma_with_expert.gemma_expert.model, 'gradient_checkpointing', False) + logging.info(f"Gemma expert model gradient checkpointing: {gemma_gc}") else: logging.info("Gradient checkpointing enabled but status check not available") - # Check model parameters for debugging - if is_main: - check_model_parameters(model, device) + # Test gradient checkpointing with a small forward pass + if is_main and enable_gradient_checkpointing: + logging.info("Testing gradient checkpointing with a small forward pass...") + try: + # Create a small test batch + test_batch = next(iter(loader)) + test_batch = batch_to_torch(test_batch, device) + test_actions = test_batch["actions"] + + # Record memory before forward pass + if torch.cuda.is_available(): + memory_before = torch.cuda.memory_allocated(device) / 1e9 + logging.info(f"Memory before test forward pass: {memory_before:.2f}GB") + + # Do a test forward pass + with torch.no_grad(): + test_observation = _model.Observation.from_dict(test_batch) + test_losses = model(test_observation, test_actions) + + # Record memory after forward pass + if torch.cuda.is_available(): + memory_after = torch.cuda.memory_allocated(device) / 1e9 + logging.info(f"Memory after test forward pass: {memory_after:.2f}GB") + logging.info(f"Memory difference: {memory_after - memory_before:.2f}GB") + + # Clear test data + del test_batch, test_actions, test_observation, test_losses + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + logging.info("Gradient checkpointing test completed successfully") + except Exception as e: + logging.warning(f"Gradient checkpointing test failed: {e}") + logging.warning("Continuing with training...") if use_ddp: # Enable unused parameter detection to handle cases where some parameters don't participate in loss @@ -574,6 +800,17 @@ def lr_schedule(step: int): # Enable mixed precision training for memory optimization scaler = torch.amp.GradScaler(enabled=mixed_precision and 
torch.cuda.is_available()) + + # Set memory efficient settings + if torch.cuda.is_available(): + # Enable memory efficient algorithms + torch.backends.cudnn.benchmark = False # Disable for memory efficiency + torch.backends.cudnn.deterministic = True # Enable for memory efficiency + + # Set memory fraction if needed + if device.index is not None: + torch.cuda.empty_cache() + logging.info(f"Cleared CUDA cache for device {device.index}") model.train() start_time = time.time() @@ -590,6 +827,9 @@ def lr_schedule(step: int): # Training loop - iterate until we reach num_train_steps pbar = tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) if is_main else None + # Check model parameters after first few steps when gradients are available + parameters_checked = False + while global_step < config.num_train_steps: if use_ddp: sampler.set_epoch(global_step // len(loader)) @@ -619,6 +859,14 @@ def lr_schedule(step: int): losses = torch.tensor(losses, device=device, dtype=torch.float32) loss = losses.mean() / gradient_accumulation_steps # Scale loss for gradient accumulation + + # Debug gradient checkpointing on first few steps + if global_step < 5 and is_main: + if hasattr(model, 'is_gradient_checkpointing_enabled'): + gc_enabled = model.is_gradient_checkpointing_enabled() + logging.info(f"Step {global_step}: Gradient checkpointing enabled: {gc_enabled}") + if torch.cuda.is_available(): + log_memory_usage(device, global_step, "after_forward") except RuntimeError as e: if "Expected to have finished reduction" in str(e) or "did not receive grad" in str(e): logging.error(f"DDP error on rank {dist.get_rank() if use_ddp else 0}: {e}") @@ -630,6 +878,16 @@ def lr_schedule(step: int): # Backward pass with gradient scaling scaler.scale(loss).backward() + + # Aggressive memory cleanup after backward pass + if torch.cuda.is_available(): + # Clear intermediate activations that might still be in memory + torch.cuda.empty_cache() + gc.collect() + 
+ # Log memory usage after backward pass for debugging + if global_step < 5 and is_main: + log_memory_usage(device, global_step, "after_backward") # Gradient accumulation logic if (global_step + 1) % gradient_accumulation_steps == 0: @@ -641,6 +899,12 @@ def lr_schedule(step: int): scaler.step(optim) scaler.update() optim.zero_grad(set_to_none=True) + + # Clear gradients more aggressively + for param in model.parameters(): + if param.grad is not None: + param.grad.detach_() + param.grad = None # Update EMA if enabled if ema_model is not None: @@ -654,6 +918,11 @@ def lr_schedule(step: int): logging.warning(f"Failed to update EMA model: {e}") # Continue training without EMA update + # # Check model parameters after first few steps when gradients are available + # if not parameters_checked and global_step >= 16510 and is_main: + # check_model_parameters(model, device) + # parameters_checked = True + # Collect stats (only on accumulation steps) if (global_step + 1) % gradient_accumulation_steps == 0 and is_main: infos.append({ @@ -671,7 +940,7 @@ def lr_schedule(step: int): logging.info(f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s") # Log to wandb - if config.wandb_enabled: + if config.wandb_enabled and len(infos) > 1: wandb.log({ "loss": avg_loss, "learning_rate": avg_lr, @@ -698,8 +967,18 @@ def lr_schedule(step: int): # Memory cleanup after each batch del batch, actions, observation, losses, loss + + # More aggressive memory cleanup if torch.cuda.is_available(): torch.cuda.empty_cache() + # Force garbage collection + gc.collect() + + # Log memory usage for debugging gradient checkpointing + if is_main and global_step % 100 == 0: + memory_allocated = torch.cuda.memory_allocated(device) / 1e9 + memory_reserved = torch.cuda.memory_reserved(device) / 1e9 + logging.info(f"Step {global_step}: GPU memory allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB") # Close progress bar if pbar is not None: @@ -721,6 +1000,8 @@ 
def main(): parser = argparse.ArgumentParser(add_help=False) parser.add_argument("--resume", action="store_true", default=False, help="Resume training from the latest checkpoint for the experiment (handles both PyTorch and JAX checkpoints)") + parser.add_argument("--cleanup_checkpoints", action="store_true", default=False, + help="Clean up corrupted checkpoints during resume (keeps last 3 valid checkpoints)") parser.add_argument("--ckpt_save_interval", type=int, default=None, help="Interval for saving checkpoints (overrides config.save_interval)") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, @@ -750,7 +1031,8 @@ def main(): gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=mixed_precision, max_memory_usage=args.max_memory_usage, - enable_gradient_checkpointing=args.gradckpt) + enable_gradient_checkpointing=True, + cleanup_checkpoints=args.cleanup_checkpoints) if __name__ == "__main__": diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 5a27cfa..5737a80 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -55,7 +55,8 @@ def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - vlm_config_hf.vision_config.torch_dtype = "float32" + vlm_config_hf.vision_config.torch_dtype = "bfloat16" + vlm_config_hf.vision_config._attn_implementation = "flash_attention_2" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, @@ -135,7 +136,38 @@ def forward( else: models = [self.paligemma.language_model, self.gemma_expert.model] num_layers = self.paligemma.config.text_config.num_hidden_layers - for layer_idx in range(num_layers): + + # Check if gradient checkpointing is enabled for any of the 
models + use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, 'gradient_checkpointing') and + self.gemma_expert.model.gradient_checkpointing and + self.training + ) or ( + hasattr(self, 'gradient_checkpointing') and + self.gradient_checkpointing and + self.training + ) + + # Force enable gradient checkpointing if we're in training mode and the model supports it + if self.training and hasattr(self.gemma_expert.model, 'gradient_checkpointing'): + if not self.gemma_expert.model.gradient_checkpointing: + print("Forcing gradient checkpointing to be enabled for Gemma expert model") + self.gemma_expert.model.gradient_checkpointing = True + use_gradient_checkpointing = True + + # Debug gradient checkpointing status + if hasattr(self, '_debug_gc_printed') and not self._debug_gc_printed: + print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}") + print(f"Model training mode: {self.training}") + print(f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}") + if hasattr(self.gemma_expert.model, 'gradient_checkpointing'): + print(f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}") + self._debug_gc_printed = True + + # Define the complete layer computation function for gradient checkpointing + def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond): + models = [self.paligemma.language_model, self.gemma_expert.model] + query_states = [] key_states = [] value_states = [] @@ -143,9 +175,6 @@ def forward( for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) - # hidden_states = hidden_states.to(dtype=torch.bfloat16) - # if gate is not None: - # gate = gate.to(dtype=torch.bfloat16) gates.append(gate) input_shape = hidden_states.shape[:-1] @@ -158,42 +187,36 @@ def forward( 
key_states.append(key_state) value_states.append(value_state) - # B,L,H,D with L sequence length, H number of heads, D head dim - # concatenate on the number of embeddings/tokens + # Concatenate and process attention query_states = torch.cat(query_states, dim=2) key_states = torch.cat(key_states, dim=2) value_states = torch.cat(value_states, dim=2) - # query_states = apply_rope(query_states, position_ids) - # key_states = apply_rope(key_states, position_ids) - - # query_states = query_states.transpose(1, 2) - # key_states = key_states.transpose(1, 2) - # value_states = value_states.transpose(1, 2) - dummy_tensor = torch.zeros(query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype) cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) query_states, key_states = modeling_gemma.apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) batch_size = query_states.shape[0] scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling + + # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( self.paligemma.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, attention_mask, scaling ) - #att_output = att_output.to(dtype=torch.bfloat16) - att_output = att_output.reshape(batch_size, -1, 1 * 8 * layer.self_attn.head_dim) - + # Get head_dim from the current layer, not from the model + head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) - # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) + # Process layer outputs outputs_embeds = [] - start = 0 + start_pos = 0 for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] - end = start + hidden_states.shape[1] + end_pos = start_pos + hidden_states.shape[1] if att_output.dtype != 
layer.self_attn.o_proj.weight.dtype: att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) - out_emb = layer.self_attn.o_proj(att_output[:, start:end]) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) @@ -205,14 +228,43 @@ def forward( # second residual out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) outputs_embeds.append(out_emb) - start = end - inputs_embeds = outputs_embeds + start_pos = end_pos + + return outputs_embeds + + # Process all layers with gradient checkpointing if enabled + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_layer_complete, + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, + use_reentrant=False, + preserve_rng_state=False + ) + else: + inputs_embeds = compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond) + + # Old code removed - now using compute_layer_complete function above # final norm - outputs_embeds = [] - for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) - outputs_embeds.append(out_emb) + # Define final norm computation function for gradient checkpointing + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, adarms_cond, + use_reentrant=False, + preserve_rng_state=False + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) prefix_output = outputs_embeds[0] suffix_output = 
outputs_embeds[1] diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 5f13a66..8dff228 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -119,14 +119,19 @@ def gradient_checkpointing_enable(self): if hasattr(self.paligemma_with_expert, 'paligemma'): if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True - logging.info("Enabled gradient checkpointing in PaliGemma language model") + print("Enabled gradient checkpointing in PaliGemma language model") + + # Enable gradient checkpointing in the vision model (SigLIP) + if hasattr(self.paligemma_with_expert.paligemma, 'vision_tower'): + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + print("Enabled gradient checkpointing in PaliGemma vision tower (SigLIP)") if hasattr(self.paligemma_with_expert, 'gemma_expert'): if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True - logging.info("Enabled gradient checkpointing in Gemma expert model") + print("Enabled gradient checkpointing in Gemma expert model") - logging.info("Enabled gradient checkpointing for PI0Pytorch model") + print("Enabled gradient checkpointing for PI0Pytorch model") def gradient_checkpointing_disable(self): """Disable gradient checkpointing.""" @@ -136,6 +141,10 @@ def gradient_checkpointing_disable(self): if hasattr(self.paligemma_with_expert, 'paligemma'): if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + + # Disable gradient checkpointing in the vision model (SigLIP) + if hasattr(self.paligemma_with_expert.paligemma, 'vision_tower'): + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False if 
hasattr(self.paligemma_with_expert, 'gemma_expert'): if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): @@ -152,6 +161,7 @@ def get_gradient_checkpointing_status(self): status = { 'main_model': self.gradient_checkpointing_enabled, 'paligemma_language_model': False, + 'paligemma_vision_model': False, 'gemma_expert_model': False } @@ -162,6 +172,13 @@ def get_gradient_checkpointing_status(self): 'gradient_checkpointing', False ) + + if hasattr(self.paligemma_with_expert.paligemma, 'vision_tower'): + status['paligemma_vision_model'] = getattr( + self.paligemma_with_expert.paligemma.vision_tower, + 'gradient_checkpointing', + False + ) if hasattr(self.paligemma_with_expert, 'gemma_expert'): if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): @@ -198,25 +215,64 @@ def embed_prefix( pad_masks = [] att_masks = [] - for ( - img, - img_mask, - ) in zip(images, img_masks): - img_emb = self.paligemma_with_expert.embed_image(img) - - bsize, num_img_embs = img_emb.shape[:2] - img_mask = img_mask[:, None].expand(bsize, num_img_embs) - - embs.append(img_emb) - pad_masks.append(img_mask) + # Apply gradient checkpointing to image embedding if enabled + if self.gradient_checkpointing_enabled and self.training: + for ( + img, + img_mask, + ) in zip(images, img_masks): + # Use checkpoint for image embedding + def checkpointed_image_embed(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = torch.utils.checkpoint.checkpoint( + checkpointed_image_embed, + img, + use_reentrant=False, + preserve_rng_state=False + ) - # Create attention masks so that image tokens attend to each other - att_masks += [0] * num_img_embs + bsize, num_img_embs = img_emb.shape[:2] + img_mask = img_mask[:, None].expand(bsize, num_img_embs) - lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + embs.append(img_emb) + pad_masks.append(img_mask) - lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) + # Create attention masks 
so that image tokens attend to each other + att_masks += [0] * num_img_embs + else: + for ( + img, + img_mask, + ) in zip(images, img_masks): + img_emb = self.paligemma_with_expert.embed_image(img) + + bsize, num_img_embs = img_emb.shape[:2] + img_mask = img_mask[:, None].expand(bsize, num_img_embs) + + embs.append(img_emb) + pad_masks.append(img_mask) + + # Create attention masks so that image tokens attend to each other + att_masks += [0] * num_img_embs + + # Apply gradient checkpointing to language embedding if enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_lang_embed(lang_tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = torch.utils.checkpoint.checkpoint( + checkpointed_lang_embed, + lang_tokens, + use_reentrant=False, + preserve_rng_state=False + ) + else: + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + lang_emb = lang_emb * math.sqrt(lang_emb_dim) embs.append(lang_emb) pad_masks.append(lang_masks) @@ -243,8 +299,20 @@ def embed_suffix(self, state, noisy_actions, timestep): att_masks = [] if not self.pi05: - # Embed state - state_emb = self.state_proj(state) + # Embed state with gradient checkpointing if enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_state_proj(state): + return self.state_proj(state) + + state_emb = torch.utils.checkpoint.checkpoint( + checkpointed_state_proj, + state, + use_reentrant=False, + preserve_rng_state=False + ) + else: + state_emb = self.state_proj(state) + embs.append(state_emb[:, None, :]) bsize = state_emb.shape[0] dtype = state_emb.dtype @@ -262,22 +330,66 @@ def embed_suffix(self, state, noisy_actions, timestep): ) time_emb = time_emb.type(dtype=timestep.dtype) - # Fuse timestep + action information using an MLP - action_emb = 
self.action_in_proj(noisy_actions) + # Fuse timestep + action information using an MLP with gradient checkpointing if enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_action_proj(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = torch.utils.checkpoint.checkpoint( + checkpointed_action_proj, + noisy_actions, + use_reentrant=False, + preserve_rng_state=False + ) + else: + action_emb = self.action_in_proj(noisy_actions) if not self.pi05: time_emb = time_emb[:, None, :].expand_as(action_emb) action_time_emb = torch.cat([action_emb, time_emb], dim=2) - action_time_emb = self.action_time_mlp_in(action_time_emb) - action_time_emb = F.silu(action_time_emb) # swish == silu - action_time_emb = self.action_time_mlp_out(action_time_emb) + + # Apply gradient checkpointing to MLP layers if enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_mlp(action_time_emb): + x = self.action_time_mlp_in(action_time_emb) + x = F.silu(x) # swish == silu + x = self.action_time_mlp_out(x) + return x + + action_time_emb = torch.utils.checkpoint.checkpoint( + checkpointed_mlp, + action_time_emb, + use_reentrant=False, + preserve_rng_state=False + ) + else: + action_time_emb = self.action_time_mlp_in(action_time_emb) + action_time_emb = F.silu(action_time_emb) # swish == silu + action_time_emb = self.action_time_mlp_out(action_time_emb) + adarms_cond = None else: - # time MLP (for adaRMS) - time_emb = self.time_mlp_in(time_emb) - time_emb = F.silu(time_emb) # swish == silu - time_emb = self.time_mlp_out(time_emb) - time_emb = F.silu(time_emb) + # time MLP (for adaRMS) with gradient checkpointing if enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_time_mlp(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) # swish == silu + x = self.time_mlp_out(x) + x = F.silu(x) + return x + + time_emb = torch.utils.checkpoint.checkpoint( + checkpointed_time_mlp, + 
time_emb, + use_reentrant=False, + preserve_rng_state=False + ) + else: + time_emb = self.time_mlp_in(time_emb) + time_emb = F.silu(time_emb) # swish == silu + time_emb = self.time_mlp_out(time_emb) + time_emb = F.silu(time_emb) + action_time_emb = action_emb adarms_cond = time_emb @@ -336,18 +448,52 @@ def forward(self, observation, actions, noise=None, time=None) -> Tensor: att_2d_masks_4d = att_2d_masks[:, None, :, :] att_2d_masks_4d = torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) - (_, suffix_out), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_masks_4d, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - adarms_cond=[None, adarms_cond] - ) + # Apply gradient checkpointing if enabled + if self.gradient_checkpointing_enabled and self.training: + # Use torch.utils.checkpoint.checkpoint for the expensive forward pass + def checkpointed_forward(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond] + ) + return suffix_out + + suffix_out = torch.utils.checkpoint.checkpoint( + checkpointed_forward, + prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond, + use_reentrant=False, # More memory efficient + preserve_rng_state=False # More memory efficient + ) + else: + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond] + ) + suffix_out = suffix_out[:, -self.config.action_horizon :] suffix_out = suffix_out.to(dtype=torch.float32) - v_t = self.action_out_proj(suffix_out) + # Apply gradient checkpointing to final action projection if 
enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_action_out_proj(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = torch.utils.checkpoint.checkpoint( + checkpointed_action_out_proj, + suffix_out, + use_reentrant=False, + preserve_rng_state=False + ) + else: + v_t = self.action_out_proj(suffix_out) losses = F.mse_loss(u_t, v_t, reduction="none") return losses diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 6235e32..f866bee 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -746,7 +746,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=64, + batch_size=256, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, @@ -756,7 +756,7 @@ def __post_init__(self) -> None: optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2", - num_train_steps=30_000, + num_train_steps=60_000, ), # # Fine-tuning Aloha configs. @@ -952,4 +952,4 @@ def get_config(config_name: str) -> TrainConfig: closest_str = f" Did you mean '{closest[0]}'? 
" if closest else "" raise ValueError(f"Config '{config_name}' not found.{closest_str}") - return _CONFIGS_DICT[config_name] \ No newline at end of file + return _CONFIGS_DICT[config_name] From bcce0ed5ccfb4ffbcf525145a0ff4d46265594a8 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Mon, 25 Aug 2025 20:58:32 -0700 Subject: [PATCH 08/32] bs 1024 (2 nodes) working --- src/openpi/models_pytorch/gemma_pytorch.py | 4 ++-- src/openpi/training/config.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 5737a80..3d70ca0 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -55,8 +55,8 @@ def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - vlm_config_hf.vision_config.torch_dtype = "bfloat16" - vlm_config_hf.vision_config._attn_implementation = "flash_attention_2" + vlm_config_hf.vision_config.torch_dtype = "float32" + # vlm_config_hf.vision_config._attn_implementation = "flash_attention_2" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index f866bee..e1ec12c 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -746,17 +746,17 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=256, + batch_size=1024, lr_schedule=_optimizer.CosineDecaySchedule( - warmup_steps=10_000, - peak_lr=5e-5, + warmup_steps=2_500, + peak_lr=1e-4, decay_steps=1_000_000, - decay_lr=5e-5, + decay_lr=1e-4, ), optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, 
weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2", - num_train_steps=60_000, + num_train_steps=6_500, ), # # Fine-tuning Aloha configs. From f554bc8de084e78413d993268a63274548a2cf9d Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 00:57:32 -0700 Subject: [PATCH 09/32] support finetuning --- examples/convert_jax_model_to_pytorch.py | 6 + scripts/train_pytorch.py | 619 +++++++++++++++++++++ scripts/train_single_example.py | 548 ++++++++++++++++++ src/openpi/models/pi0.py | 20 +- src/openpi/models_pytorch/gemma_pytorch.py | 124 ++++- src/openpi/models_pytorch/pi0_pytorch.py | 17 +- src/openpi/training/config.py | 24 +- src/openpi/training/data_loader.py | 1 + uv.lock | 356 +++++++----- 9 files changed, 1565 insertions(+), 150 deletions(-) create mode 100644 scripts/train_pytorch.py create mode 100644 scripts/train_single_example.py diff --git a/examples/convert_jax_model_to_pytorch.py b/examples/convert_jax_model_to_pytorch.py index 219402d..24facc1 100644 --- a/examples/convert_jax_model_to_pytorch.py +++ b/examples/convert_jax_model_to_pytorch.py @@ -536,6 +536,12 @@ def __init__(self): action_horizon=10, pi05=True, ) + elif "pi05_base" in checkpoint_dir: + pi0_config = Pi0Config( + action_dim=32, + action_horizon=50, + pi05=True, + ) else: pi0_config = openpi.models.pi0_config.Pi0Config( action_dim=8, diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py new file mode 100644 index 0000000..1a46fe1 --- /dev/null +++ b/scripts/train_pytorch.py @@ -0,0 +1,619 @@ +""" +PyTorch training entrypoint for PI0 with multi-GPU and multi-node (DDP) support. +This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs +entirely in PyTorch using the `PI0Pytorch` model and your existing config/data +pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. 
+Key features +- Uses the same TrainConfig/tyro CLI as the JAX script (see available configs in + `src/openpi/training/config.py`). +- Supports multi-GPU and multi-node training via DistributedDataParallel (DDP). +- Cosine LR with warmup (parameters read from the selected config). +- AdamW optimizer and gradient clipping. +- Comprehensive checkpoint saving and resume mechanism with configurable intervals. +- Checkpoints saved on rank 0 to `config.checkpoint_dir/{global_step}/` containing model, optimizer, and metadata. +- Memory optimizations: mixed precision training, gradient accumulation, and efficient data handling. +Requirements +- PyTorch >= 2.0, torch.distributed (NCCL for CUDA, Gloo for CPU). +- Multiple GPUs for DDP (optional). +- Network connectivity between nodes for multi-node training. +Usage +Single GPU: + python scripts/train_pytorch.py CONFIG_NAME --exp_name RUN_NAME --ckpt_save_interval INTERVAL + Example: + python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test +Multi-GPU (single node): + torchrun --standalone --nnodes=1 --nproc_per_node=NUM_GPUS scripts/train_pytorch.py CONFIG_NAME --exp_name RUN_NAME + Example: + torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test +Multi-Node Training: + # On master node (node 0): + torchrun --nnodes=NUM_NODES --nproc_per_node=GPUS_PER_NODE --rdzv_id=JOB_ID --rdzv_backend=c10d --rdzv_endpoint=MASTER_IP:PORT scripts/train_pytorch.py CONFIG_NAME --exp_name RUN_NAME + + # On worker nodes (node 1, 2, ...): + torchrun --nnodes=NUM_NODES --nproc_per_node=GPUS_PER_NODE --rdzv_id=JOB_ID --rdzv_backend=c10d --rdzv_endpoint=MASTER_IP:PORT scripts/train_pytorch.py CONFIG_NAME --exp_name RUN_NAME + + Example (2 nodes, 4 GPUs each): + # Master node (192.168.1.100): + torchrun --nnodes=2 --nproc_per_node=4 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=192.168.1.100:29400 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_multi_node + + # Worker node (192.168.1.101): + torchrun --nnodes=2 --nproc_per_node=4 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=192.168.1.100:29400 scripts/train_pytorch.py pi0_aloha_sim --exp_name
pytorch_multi_node +Multi-Node Setup Requirements: +1. Network connectivity: All nodes must be able to communicate on the specified port +2. Shared filesystem: All nodes must have access to the same dataset and checkpoint directories +3. Environment consistency: Same Python environment and dependencies on all nodes +4. Firewall configuration: Ensure the rendezvous port (e.g., 29400) is open between nodes +5. SSH access: Nodes should be able to SSH to each other (for torchrun coordination) +Environment Variables for Multi-Node: +- MASTER_ADDR: IP address of the master node (auto-set by torchrun) +- MASTER_PORT: Port for rendezvous (auto-set by torchrun) +- WORLD_SIZE: Total number of processes across all nodes +- RANK: Global rank of the process (0 to WORLD_SIZE-1) +- LOCAL_RANK: Local rank within the node (0 to nproc_per_node-1) +- NODE_RANK: Rank of the node (0 to nnodes-1) +Checkpoint Parameters: +- --ckpt_save_interval: Override the checkpoint save interval from config (e.g., --ckpt_save_interval 500) +- --resume: Resume training from the latest checkpoint in the checkpoint directory +- --overwrite: Overwrite existing checkpoint directory (cannot be used with --resume) +Memory Optimization Parameters: +- --gradient_accumulation_steps: Number of steps to accumulate gradients (default: 1) +- --mixed_precision: Enable mixed precision training (default: off; NOTE: main() also defines --no_mixed_precision with default True, so mixed precision currently resolves to disabled even when this flag is passed - TODO confirm intent) +- --max_memory_usage: Maximum GPU memory usage in GB (default: None, auto-detect) +Notes +- The global batch size must be divisible by world size (number of processes). +- The data pipeline and transforms are identical to the JAX version and are controlled + by the selected TrainConfig (e.g., `LeRobot*` configs for real datasets or `FakeDataConfig`). +- Supports Weights & Biases (wandb) logging for experiment tracking and visualization. +- Checkpoints include model state, optimizer state, and training metadata for complete resume capability.
+- For optimal multi-node performance, ensure high-bandwidth network connectivity (e.g., InfiniBand). +- Monitor GPU utilization and network bandwidth during multi-node training. +- Memory optimizations can significantly reduce GPU memory usage while maintaining training quality. +""" +import argparse +import dataclasses +import logging +import os +import platform +import time +from dataclasses import dataclass +from typing import Any, Dict, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader as TorchDataLoader +from torch.utils.data.distributed import DistributedSampler +import wandb +from tqdm import tqdm + +import openpi.training.config as _config +import openpi.training.data_loader as _data +import openpi.models.model as _model +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch +from openpi.models.pi0_config import Pi0Config + + +def init_logging(): + level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"} + + class CustomFormatter(logging.Formatter): + def format(self, record): + record.levelname = level_mapping.get(record.levelname, record.levelname) + return super().format(record) + + formatter = CustomFormatter( + fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)", + datefmt="%H:%M:%S", + ) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + if not logger.handlers: + ch = logging.StreamHandler() + ch.setFormatter(formatter) + logger.addHandler(ch) + else: + logger.handlers[0].setFormatter(formatter) + + +def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True): + """Initialize wandb logging.""" + if not enabled: + wandb.init(mode="disabled") + return + + ckpt_dir = config.checkpoint_dir + if not ckpt_dir.exists(): + raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.") + + if 
resuming: + run_id = (ckpt_dir / "wandb_id.txt").read_text().strip() + wandb.init(id=run_id, resume="must", project=config.project_name) + else: + wandb.init( + name=config.exp_name, + config=dataclasses.asdict(config), + project=config.project_name, + ) + (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id) + + +def setup_ddp(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + use_ddp = world_size > 1 + if use_ddp and not dist.is_initialized(): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method="env://") + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) + device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): + torch.cuda.set_device(device) + return use_ddp, local_rank, device + + +def cleanup_ddp(): + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +def set_seed(seed: int, local_rank: int): + torch.manual_seed(seed + local_rank) + np.random.seed(seed + local_rank) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed + local_rank) + + +def build_datasets(config: _config.TrainConfig): + # Reuse existing dataset + transforms pipeline + data_conf = config.data.create(config.assets_dirs, config.model) + dataset = _data.create_torch_dataset(data_conf, config.model.action_horizon, config.model) + print(f"data_conf: {data_conf}") + dataset = _data.transform_dataset(dataset, data_conf) + return dataset, data_conf + + +def collate_to_numpy(batch_list: list[Dict[str, Any]]) -> Dict[str, Any]: + # Recursively stack leaves with numpy + def stack_leaf(*xs): + return np.stack([np.asarray(x) for x in xs], axis=0) + + # Memory-efficient collation + result = torch.utils.data.default_collate(batch_list) if not isinstance(batch_list[0], dict) else _tree_map_multi(stack_leaf, batch_list) + + # Clear batch list from memory + del batch_list + + return result + + +def 
_tree_map_multi(func, batch_list): + # batch_list is a list of dicts with same structure; reduce by zipping leaves + def recurse(keys, items): + if isinstance(items[0], dict): + return {k: recurse(keys + [k], [it[k] for it in items]) for k in items[0].keys()} + return func(*items) + return recurse([], batch_list) + + +def batch_to_torch(batch: Dict[str, Any], device: torch.device) -> Tuple[list[torch.Tensor], list[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + # Maintain canonical image key order + image_keys = _model.IMAGE_KEYS + import jax + + # Memory-efficient conversion: convert to torch tensors and move to device in one step + batch = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(device), batch) + + # Convert to float32 for memory efficiency (avoid float64) + batch['state'] = batch['state'].to(dtype=torch.float32) + batch['actions'] = batch['actions'].to(dtype=torch.float32) + + # Clear numpy arrays from memory if they exist + del jax + + return batch + + +def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_interval=None, ema_model=None): + """Save a checkpoint with model state, optimizer state, EMA state, and metadata.""" + if not is_main: + return + + # Use ckpt_save_interval if provided, otherwise use config.save_interval + save_interval = ckpt_save_interval if ckpt_save_interval is not None else config.save_interval + + # Only save if it's time to save or if it's the final step + if (global_step % save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: + ckpt_dir = os.path.join(config.checkpoint_dir, f"{global_step}") + os.makedirs(ckpt_dir, exist_ok=True) + + # Save model state + state_dict = (model.module if isinstance(model, DDP) else model).state_dict() + torch.save(state_dict, os.path.join(ckpt_dir, "pytorch_model.pt")) + + # Save optimizer state + torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt")) + + # Save EMA state if available + if ema_model is 
not None: + torch.save(ema_model.state_dict(), os.path.join(ckpt_dir, "ema_model.pt")) + + # Save training metadata + metadata = { + "global_step": global_step, + "config": dataclasses.asdict(config), + "timestamp": time.time(), + } + torch.save(metadata, os.path.join(ckpt_dir, "metadata.pt")) + + logging.info(f"Saved checkpoint at step {global_step} -> {ckpt_dir}") + + # Log checkpoint to wandb + if config.wandb_enabled: + wandb.log({"checkpoint_step": global_step}, step=global_step) + + +def load_checkpoint(model, optimizer, config, device, ema_model=None): + """Load the latest checkpoint and return the global step.""" + checkpoint_steps = [] + for d in config.checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + if not checkpoint_steps: + raise FileNotFoundError(f"No checkpoints found in {config.checkpoint_dir}") + + latest_step = max(checkpoint_steps) + ckpt_dir = os.path.join(config.checkpoint_dir, f"{latest_step}") + + # Load model state + model_state_dict = torch.load(os.path.join(ckpt_dir, "pytorch_model.pt"), map_location=device) + (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) + + # Load optimizer state + optimizer_state_dict = torch.load(os.path.join(ckpt_dir, "optimizer.pt"), map_location=device) + optimizer.load_state_dict(optimizer_state_dict) + + # Load EMA state if available + if ema_model is not None and os.path.exists(os.path.join(ckpt_dir, "ema_model.pt")): + ema_state_dict = torch.load(os.path.join(ckpt_dir, "ema_model.pt"), map_location=device) + ema_model.load_state_dict(ema_state_dict) + logging.info(f"Loaded EMA state from checkpoint") + + # Load metadata + metadata = torch.load(os.path.join(ckpt_dir, "metadata.pt"), map_location=device) + + logging.info(f"Loaded checkpoint from step {latest_step} -> {ckpt_dir}") + return metadata["global_step"] + + +def get_latest_checkpoint_step(config): + """Get the latest checkpoint step number.""" + 
checkpoint_steps = [] + for d in config.checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + return max(checkpoint_steps) if checkpoint_steps else None + + +def setup_memory_optimizations(model, device, enable_gradient_checkpointing=False): + """Setup memory optimization techniques for the model.""" + if enable_gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'): + model.gradient_checkpointing_enable() + logging.info("Enabled gradient checkpointing for memory optimization") + + # Enable memory efficient attention if available + if hasattr(model, 'config') and hasattr(model.config, 'attention_mode'): + model.config.attention_mode = 'flash_attention_2' + logging.info("Enabled Flash Attention 2 for memory efficiency") + + # Set memory efficient settings + if torch.cuda.is_available(): + # Enable memory efficient algorithms + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + # Set memory fraction if needed + if device.index is not None: + torch.cuda.empty_cache() + logging.info(f"Cleared CUDA cache for device {device.index}") + + +def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, enable_gradient_checkpointing: bool = False): + use_ddp, local_rank, device = setup_ddp() + is_main = (not use_ddp) or (dist.get_rank() == 0) + set_seed(config.seed, local_rank) + + # Memory optimization: Set memory fraction if specified + if max_memory_usage is not None and torch.cuda.is_available(): + torch.cuda.set_per_process_memory_fraction(max_memory_usage / torch.cuda.get_device_properties(device).total_memory * 1e-9) + + # Initialize checkpoint directory and wandb + resuming = False + if config.resume: + # Check if checkpoint directory exists and has checkpoints + if config.checkpoint_dir.exists(): + latest_step = get_latest_checkpoint_step(config) + 
if latest_step is not None: + resuming = True + logging.info(f"Resuming from checkpoint directory: {config.checkpoint_dir} at step {latest_step}") + else: + raise FileNotFoundError(f"No checkpoints found in {config.checkpoint_dir} for resume") + else: + raise FileNotFoundError(f"Checkpoint directory {config.checkpoint_dir} does not exist for resume") + elif config.overwrite and config.checkpoint_dir.exists(): + import shutil + shutil.rmtree(config.checkpoint_dir) + logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") + + # Create checkpoint directory + config.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # Initialize wandb (only on main process) + if is_main: + init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) + + # Build dataset + sampler + loader + dataset, data_conf = build_datasets(config) + sampler = None + if use_ddp: + sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True, drop_last=True) + + # Reduce batch size for gradient accumulation + effective_batch_size = config.batch_size // (dist.get_world_size() if use_ddp else 1) + + # Memory-efficient data loading with reduced pin_memory for large datasets + pin_memory = True + if effective_batch_size > 16: # Reduce pin_memory for large batches + pin_memory = False + logging.info("Disabled pin_memory for large batch size to reduce memory usage") + + loader = TorchDataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=pin_memory, drop_last=True, collate_fn=collate_to_numpy) + + # Log sample images to wandb on first batch + if is_main and config.wandb_enabled and not resuming: + sample_batch = next(iter(loader)) + sample_batch = batch_to_torch(sample_batch, device) + + # Create sample images for wandb + images_to_log = [] + # Get batch size from the first image tensor + batch_size = 
next(iter(sample_batch['image'].values())).shape[0] + for i in range(min(5, batch_size)): + # Concatenate all camera views horizontally for this batch item + img_concatenated = torch.cat([img[i] for img in sample_batch['image'].values()], axis=1) + img_concatenated = img_concatenated.cpu().numpy() + images_to_log.append(wandb.Image(img_concatenated)) + + wandb.log({"camera_views": images_to_log}, step=0) + + # Clear sample batch from memory + del sample_batch, images_to_log + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + # Reset the loader iterator + loader = TorchDataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=pin_memory, drop_last=True, collate_fn=collate_to_numpy) + + # Build model + if not isinstance(config.model, Pi0Config): + # Convert dataclass to Pi0Config if needed + model_cfg = Pi0Config( + action_dim=config.model.action_dim, + action_horizon=config.model.action_horizon, + max_token_len=config.model.max_token_len, + paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"), + action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"), + pi05=getattr(config.model, "pi05", False), + ) + else: + model_cfg = config.model + + model = PI0Pytorch(model_cfg).to(device) + + # Apply memory optimizations + setup_memory_optimizations(model, device, enable_gradient_checkpointing) + + if use_ddp: + model = DDP(model, device_ids=[device.index] if device.type == "cuda" else None, find_unused_parameters=False) + + # Load weights from weight_loader if specified (for fine-tuning) + if isinstance(config.weight_loader, str): + weight_path = config.weight_loader + logging.info(f"Loading weights from: {weight_path}") + + model_path = os.path.join(weight_path, "model.safetensors") + from safetensors.torch import load_model + load_model((model.module if isinstance(model, DDP) else model), model_path) + logging.info(f"Loaded 
PyTorch weights from {weight_path}") + + # Optimizer + learning rate schedule from config + warmup_steps = config.lr_schedule.warmup_steps + peak_lr = config.lr_schedule.peak_lr + decay_steps = config.lr_schedule.decay_steps + end_lr = config.lr_schedule.decay_lr + + # Create optimizer with config parameters + optim = torch.optim.AdamW( + model.parameters(), + lr=peak_lr, + betas=(config.optimizer.b1, config.optimizer.b2), + eps=config.optimizer.eps, + weight_decay=config.optimizer.weight_decay + ) + + # Initialize EMA if specified in config + ema_model = None + if config.ema_decay is not None: + ema_model = PI0Pytorch(model_cfg).to(device) + ema_model.load_state_dict(model.state_dict()) + ema_model.eval() + logging.info(f"Initialized EMA with decay {config.ema_decay}") + + # Load checkpoint if resuming + global_step = 0 + if resuming: + global_step = load_checkpoint(model, optim, config, device, ema_model) + logging.info(f"Resumed training from step {global_step}") + + def lr_schedule(step: int): + if step < warmup_steps: + return peak_lr * (step + 1) / warmup_steps + # cosine decay + progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) + cos = 0.5 * (1 + np.cos(np.pi * progress)) + return end_lr + (peak_lr - end_lr) * cos + + # Enable mixed precision training for memory optimization + scaler = torch.amp.GradScaler(enabled=mixed_precision and torch.cuda.is_available()) + + model.train() + start_time = time.time() + infos = [] # Collect stats over log interval + if is_main: + logging.info(f"Running on: {platform.node()} | world_size={dist.get_world_size() if use_ddp else 1}") + logging.info(f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}") + logging.info(f"Memory optimizations: gradient_accumulation_steps={gradient_accumulation_steps}, mixed_precision={mixed_precision}, gradient_checkpointing={enable_gradient_checkpointing}") + logging.info(f"LR 
schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}") + logging.info(f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}") + if config.ema_decay is not None: + logging.info(f"EMA decay: {config.ema_decay}") + + # Training loop - iterate until we reach num_train_steps + pbar = tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) if is_main else None + + while global_step < config.num_train_steps: + if use_ddp: + sampler.set_epoch(global_step // len(loader)) + + for batch in loader: + # Check if we've reached the target number of steps + if global_step >= config.num_train_steps: + break + + # Convert dict batch directly to torch tensors (bypass Observation.from_dict for PyTorch) + batch = batch_to_torch(batch, device) + actions = batch["actions"] + + # Update LR + for pg in optim.param_groups: + pg["lr"] = lr_schedule(global_step) + + # Forward pass with mixed precision + observation = _model.Observation.from_dict(batch) + with torch.amp.autocast('cuda', enabled=mixed_precision and torch.cuda.is_available()): + losses = model(observation, actions) + loss = losses.mean() / gradient_accumulation_steps # Scale loss for gradient accumulation + + # Backward pass with gradient scaling + scaler.scale(loss).backward() + + # Gradient accumulation logic + if (global_step + 1) % gradient_accumulation_steps == 0: + # Unscale gradients for clipping + scaler.unscale_(optim) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) + + # Optimizer step + scaler.step(optim) + scaler.update() + optim.zero_grad(set_to_none=True) + + # Update EMA if enabled + if ema_model is not None: + with torch.no_grad(): + for param, ema_param in zip(model.parameters(), ema_model.parameters()): + ema_param.data.mul_(config.ema_decay).add_(param.data, alpha=1 - config.ema_decay) + 
+ # Collect stats (only on accumulation steps) + if (global_step + 1) % gradient_accumulation_steps == 0 and is_main: + infos.append({ + "loss": loss.item() * gradient_accumulation_steps, # Unscale for logging + "learning_rate": optim.param_groups[0]['lr'], + }) + + if is_main and (global_step % config.log_interval == 0) and (global_step + 1) % gradient_accumulation_steps == 0: + elapsed = time.time() - start_time + + # Average stats over log interval + avg_loss = sum(info["loss"] for info in infos) / len(infos) + avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) + + logging.info(f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s") + + # Log to wandb + if config.wandb_enabled: + wandb.log({ + "loss": avg_loss, + "learning_rate": avg_lr, + "step": global_step, + "time_per_step": elapsed / config.log_interval, + }, step=global_step) + + start_time = time.time() + infos = [] # Reset stats collection + + # Save checkpoint using the new mechanism + save_checkpoint(model, optim, global_step, config, is_main, ckpt_save_interval, ema_model) + + global_step += 1 + + # Update progress bar + if pbar is not None: + pbar.update(1) + pbar.set_postfix({ + 'loss': f'{loss.item() * gradient_accumulation_steps:.4f}', + 'lr': f'{optim.param_groups[0]["lr"]:.2e}', + 'step': global_step + }) + + # Memory cleanup after each batch + del batch, actions, observation, losses, loss + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Close progress bar + if pbar is not None: + pbar.close() + + # Finish wandb run + if is_main and config.wandb_enabled: + wandb.finish() + + cleanup_ddp() + + +def main(): + init_logging() + config = _config.cli() + + # Parse additional command line arguments for memory optimization + import argparse + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--ckpt_save_interval", type=int, default=None, + help="Interval for saving checkpoints (overrides config.save_interval)") + 
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, + help="Number of steps to accumulate gradients (default: 1)") + parser.add_argument("--mixed_precision", action="store_true", default=False, + help="Enable mixed precision training (default: True)") + parser.add_argument("--no_mixed_precision", action="store_true", default=True, + help="Disable mixed precision training") + parser.add_argument("--max_memory_usage", type=float, default=None, + help="Maximum GPU memory usage in GB (default: None, auto-detect)") + parser.add_argument("--enable_gradient_checkpointing", action="store_true", default=False, + help="Enable gradient checkpointing for memory optimization") + args, _ = parser.parse_known_args() + + # Handle mixed precision flag + mixed_precision = args.mixed_precision and not args.no_mixed_precision + + train_loop(config, + ckpt_save_interval=args.ckpt_save_interval, + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=mixed_precision, + max_memory_usage=args.max_memory_usage, + enable_gradient_checkpointing=args.enable_gradient_checkpointing) + + +if __name__ == "__main__": + main() diff --git a/scripts/train_single_example.py b/scripts/train_single_example.py new file mode 100644 index 0000000..2a23d48 --- /dev/null +++ b/scripts/train_single_example.py @@ -0,0 +1,548 @@ +""" +Train on a single example for debugging JAX vs PyTorch comparison. +This script creates a deterministic dataset with one example and trains on it +to help debug differences between JAX and PyTorch implementations. 
+""" + +import logging +import numpy as np +import torch +import jax +import jax.numpy as jnp +import flax.nnx as nnx +import flax + +from openpi.models import model as _model +from openpi.models.pi0_config import Pi0Config +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch + + +def setup_logging(): + """Setup logging for debugging.""" + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' + ) + + +def create_fixed_example(): + """Create a fixed example for debugging.""" + np.random.seed(42) + + batch_size = 1 + action_dim = 32 + action_horizon = 10 + image_size = 224 + max_token_len = 48 + + # Create fixed images + images = {} + for key in _model.IMAGE_KEYS: + img = np.zeros((batch_size, image_size, image_size, 3), dtype=np.float32) + + # Simple gradient pattern + for i in range(image_size): + for j in range(image_size): + val = (i + j) / (2 * image_size) * 2 - 1 + img[0, i, j, :] = [val, val * 0.5, val * 0.25] + + images[key] = img + + # Create fixed state and actions + state = np.random.randn(batch_size, action_dim).astype(np.float32) * 0.1 + actions = np.random.randn(batch_size, action_horizon, action_dim).astype(np.float32) * 0.1 + + # Create fixed language tokens + tokenized_prompt = np.random.randint(0, 1000, (batch_size, max_token_len), dtype=np.int32) + tokenized_prompt_mask = np.ones((batch_size, max_token_len), dtype=bool) + + # Create image masks + image_masks = {key: np.ones(batch_size, dtype=bool) for key in _model.IMAGE_KEYS} + + return { + "image": images, + "image_mask": image_masks, + "state": state, + "actions": actions, + "tokenized_prompt": tokenized_prompt, + "tokenized_prompt_mask": tokenized_prompt_mask, + } + + +def create_fixed_noise_and_time(batch_size, action_horizon, action_dim): + """Create fixed noise and time values for deterministic comparison.""" + np.random.seed(42) # Use same seed for consistency + + # Create fixed noise + noise = np.random.randn(batch_size, action_horizon, 
action_dim).astype(np.float32) * 0.1 + + # Create fixed time values (beta distribution like in the models) + time_beta = np.random.beta(1.5, 1.0, batch_size).astype(np.float32) + time = time_beta * 0.999 + 0.001 + + return noise, time + + +def test_pytorch_single_example(noise, time): + """Test PyTorch training on single example.""" + print("\n=== Testing PyTorch on Single Example ===") + + # Create model + config = Pi0Config(action_dim=32, action_horizon=10, pi05=True) + model = PI0Pytorch(config) + + # Load pre-trained weights + weight_path = "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2/model.safetensors" + print(f"Loading PyTorch weights from: {weight_path}") + + from safetensors.torch import load_model + load_model(model, weight_path) + + # Create fixed example + example = create_fixed_example() + + # Convert to PyTorch tensors + pytorch_example = {} + for key, value in example.items(): + if key == "image": + # Convert channels-last [B, H, W, C] to channels-first [B, C, H, W] for PyTorch + pytorch_example[key] = {} + for k, v in value.items(): + # v is [B, H, W, C], convert to [B, C, H, W] + v_tensor = torch.from_numpy(v) + v_tensor = v_tensor.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + pytorch_example[key][k] = v_tensor + elif key == "image_mask": + pytorch_example[key] = {k: torch.from_numpy(v) for k, v in value.items()} + else: + pytorch_example[key] = torch.from_numpy(value) + + # Convert noise and time to PyTorch tensors + noise_tensor = torch.from_numpy(noise) + time_tensor = torch.from_numpy(time) + + # Create observation + observation = _model.Observation.from_dict(pytorch_example) + actions = pytorch_example["actions"] + + print(f"Observation state shape: {observation.state.shape}") + print(f"Observation state dtype: {observation.state.dtype}") + print(f"Actions shape: {actions.shape}") + print(f"Actions dtype: {actions.dtype}") + print(f"Noise shape: {noise_tensor.shape}, dtype: {noise_tensor.dtype}") + 
print(f"Time shape: {time_tensor.shape}, dtype: {time_tensor.dtype}") + + # Test forward pass with fixed noise and time + model.eval() + with torch.no_grad(): + #try: + losses = model(observation, actions, noise=noise_tensor, time=time_tensor) + print(f"PyTorch forward pass successful!") + print(f"Losses shape: {losses.shape}") + print(f"Losses dtype: {losses.dtype}") + # mean_loss = losses.to(torch.float32).mean().item() + # print(f"Mean loss: {mean_loss:.6f}") + return True, losses + # except Exception as e: + # print(f"PyTorch forward pass failed: {e}") + # return False, None + + +def test_jax_single_example(noise, time, debug_single_layer=False): + """Test JAX training on single example.""" + print("\n=== Testing JAX on Single Example ===") + + # Create model + config = Pi0Config(action_dim=32, action_horizon=10, pi05=True) + if debug_single_layer: + print("šŸ”§ Debug mode: Using only 1 encoder layer") + + # Create a custom model with modified siglip depth for debugging + if debug_single_layer: + # Import the Pi0 model class + from openpi.models.pi0 import Pi0 + import openpi.models.gemma as _gemma + import openpi.models.siglip as _siglip + import flax.nnx.bridge as nnx_bridge + + # Create the model manually with custom siglip variant + rng = jax.random.key(42) + rngs = flax.nnx.Rngs(rng) + + paligemma_config = _gemma.get_config(config.paligemma_variant) + action_expert_config = _gemma.get_config(config.action_expert_variant) + + # Create LLM + llm = nnx_bridge.ToNNX( + _gemma.Module( + configs=[paligemma_config, action_expert_config], + embed_dtype=config.dtype, + adarms=config.pi05, + ) + ) + llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False]) + + # Create custom siglip model with depth=1 + # We'll use the same variant but override the depth parameter + siglip_params = _siglip.decode_variant("So400m/14") + siglip_params["depth"] = 1 # Override depth to 1 for debugging + + img = nnx_bridge.ToNNX( + 
_siglip.Module( + num_classes=paligemma_config.width, + variant=None, # Don't use variant, use explicit params + pool_type="none", + scan=False, # Disable scan for single layer + dtype_mm=config.dtype, + **siglip_params, # Pass the modified parameters + ) + ) + img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs) + + # Create the full model + model = Pi0(config, rngs) + # Replace the siglip model with our custom one + model.PaliGemma.img = img + + print("šŸ”§ Created single-layer SigLIP model (depth=1) for debugging...") + else: + rng = jax.random.key(42) + model = config.create(rng) + + # Load pre-trained weights + weight_path = "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" + print(f"Loading JAX weights from: {weight_path}") + + # try: + # Use the same approach as in policy_config.py + params = _model.restore_params(weight_path, dtype=jnp.bfloat16) + + # Filter params to only include the first encoder layer for debugging + if debug_single_layer: + filtered_params = {} + + # The parameters are nested, so we need to traverse the structure + def filter_nested_params(params_dict, key_path=""): + result = {} + for key, value in params_dict.items(): + current_path = f"{key_path}.{key}" if key_path else key + + if isinstance(value, dict): + # Recursive case - traverse deeper + filtered_sub = filter_nested_params(value, current_path) + if filtered_sub: # Only include if there are sub-parameters + result[key] = filtered_sub + else: + # Leaf case - check if this parameter should be included + if 'Transformer' in current_path: + # Only keep the first encoder block (encoderblock_0) and encoder_norm + if 'encoderblock_0' in current_path or 'encoder_norm' in current_path: + result[key] = value + else: + # Keep all non-Transformer params + result[key] = value + return result + + filtered_params = filter_nested_params(params) + params_to_use = filtered_params + print("āœ… JAX weights loaded successfully (first 
layer only)!") + print("āš ļø Note: Using So400m variant with depth=1 (modified from depth=27)") + print("āš ļø Only the first layer weights will be used, others will be randomly initialized") + + # Debug: Show what parameters we have + print(f"šŸ“‹ Available parameters for single-layer model:") + transformer_params = [] + for key in sorted(params_to_use.keys()): + if 'Transformer' in key: + transformer_params.append(key) + print(f" {key}: {params_to_use[key].shape}") + + if not transformer_params: + print(" No Transformer parameters found! Let's see all keys:") + for key in sorted(params_to_use.keys())[:20]: # Show first 20 keys + print(f" {key}") + + # Let's also check if PaliGemma has nested structure + if 'PaliGemma' in params_to_use: + print(" Checking PaliGemma structure:") + paligemma_params = params_to_use['PaliGemma'] + if hasattr(paligemma_params, 'keys'): + for subkey in sorted(paligemma_params.keys()): + print(f" PaliGemma.{subkey}") + if hasattr(paligemma_params[subkey], 'keys'): + for subsubkey in sorted(paligemma_params[subkey].keys()): + print(f" PaliGemma.{subkey}.{subsubkey}") + if hasattr(paligemma_params[subkey][subsubkey], 'keys'): + for subsubsubkey in sorted(paligemma_params[subkey][subsubkey].keys()): + if 'Transformer' in subsubsubkey: + print(f" PaliGemma.{subkey}.{subsubkey}.{subsubsubkey}") + + # The issue is that with scan=False, the model expects different parameter names + # We need to map from encoderblock_0 to encoderblock in the nested structure + def adapt_nested_params(params_dict, key_path=""): + result = {} + for key, value in params_dict.items(): + current_path = f"{key_path}.{key}" if key_path else key + + if isinstance(value, dict): + # Recursive case - traverse deeper + result[key] = adapt_nested_params(value, current_path) + else: + # Leaf case - adapt the key if needed + new_key = key + if 'Transformer' in current_path and 'encoderblock_0' in key: + # Map encoderblock_0 to encoderblock for non-scan mode + new_key = 
key.replace('encoderblock_0', 'encoderblock') + result[new_key] = value + return result + + adapted_params = adapt_nested_params(params_to_use) + params_to_use = adapted_params + print("šŸ”„ Adapted parameter names for non-scan mode") + print(f" Example mapping: encoderblock_0 -> encoderblock") + else: + params_to_use = params + print("āœ… JAX weights loaded successfully!") + + # Apply the params to the model using NNX state management + import flax.nnx as nnx + graphdef, model_state = nnx.split(model) + + # Debug: Let me check what the model actually expects first + print(f"šŸ” Checking what the model expects...") + try: + print(f"šŸ“‹ Model parameter structure:") + model_transformer_params = [] + for key in sorted(model_state.keys()): + if 'Transformer' in key: + model_transformer_params.append(key) + print(f" {key}: shape {getattr(model_state[key], 'shape', 'no shape')}") + + if not model_transformer_params: + print(" No Transformer parameters found in model! Let's see all keys:") + for key in sorted(model_state.keys())[:20]: # Show first 20 keys + print(f" {key}") + + # Let's also check if PaliGemma has nested structure in model + if 'PaliGemma' in model_state: + print(" Checking PaliGemma structure in model:") + paligemma_state = model_state['PaliGemma'] + if hasattr(paligemma_state, 'keys'): + for subkey in sorted(paligemma_state.keys()): + print(f" PaliGemma.{subkey}") + if hasattr(paligemma_state[subkey], 'keys'): + for subsubkey in sorted(paligemma_state[subkey].keys()): + print(f" PaliGemma.{subkey}.{subsubkey}") + if hasattr(paligemma_state[subkey][subsubkey], 'keys'): + for subsubsubkey in sorted(paligemma_state[subkey][subsubkey].keys()): + if 'Transformer' in subsubsubkey: + print(f" PaliGemma.{subkey}.{subsubkey}.{subsubsubkey}") + except Exception as e: + print(f" Could not inspect model parameters: {e}") + + # Now try to load parameters + try: + model_state.replace_by_pure_dict(params_to_use) + model = nnx.merge(graphdef, model_state) + print("āœ… 
Parameters loaded successfully!") + except Exception as e: + print(f"āŒ Parameter loading failed: {e}") + print("šŸ”„ Continuing with random initialization...") + model = nnx.merge(graphdef, model_state) + # except Exception as e: + # print(f"āŒ Failed to load JAX weights: {e}") + # print("Continuing with random initialization...") + + # Create fixed example + example = create_fixed_example() + + # Convert to JAX arrays + jax_example = {} + for key, value in example.items(): + if key == "image": + jax_example[key] = {k: jnp.array(v) for k, v in value.items()} + elif key == "image_mask": + jax_example[key] = {k: jnp.array(v) for k, v in value.items()} + else: + jax_example[key] = jnp.array(value) + + # Convert noise and time to JAX arrays + noise_jax = jnp.array(noise) + time_jax = jnp.array(time) + + # Create observation + observation = _model.Observation.from_dict(jax_example) + actions = jax_example["actions"] + + print(f"Observation state shape: {observation.state.shape}") + print(f"Observation state dtype: {observation.state.dtype}") + print(f"Actions shape: {actions.shape}") + print(f"Actions dtype: {actions.dtype}") + print(f"Noise shape: {noise_jax.shape}, dtype: {noise_jax.dtype}") + print(f"Time shape: {time_jax.shape}, dtype: {time_jax.dtype}") + + # Test forward pass with fixed noise and time + # try: + # Use the modified compute_loss method that accepts external noise and time + losses = model.compute_loss(rng, observation, actions, train=False, noise=noise_jax, time=time_jax) + print(f"JAX forward pass successful!") + print(f"Losses shape: {losses.shape}") + print(f"Losses dtype: {losses.dtype}") + mean_loss = jnp.mean(losses).item() + print(f"Mean loss: {mean_loss:.6f}") + return True, losses + # except Exception as e: + # print(f"JAX forward pass failed: {e}") + # return False, None + + +def compare_losses(pytorch_loss, jax_loss): + """Compare losses and compute relative differences.""" + if pytorch_loss is None or jax_loss is None: + return + + 
print("\n" + "=" * 70) + print("šŸ“Š LOSS COMPARISON") + print("=" * 70) + + # # Handle tensor inputs by computing mean if needed + # if hasattr(pytorch_loss, 'mean'): + # pytorch_mean = pytorch_loss.to(torch.float32).mean().item() + # pytorch_std = pytorch_loss.to(torch.float32).std().item() + # print(f"PyTorch loss tensor - Mean: {pytorch_mean:.8f}, Std: {pytorch_std:.8f}") + # print(f"PyTorch loss shape: {pytorch_loss.shape}") + # else: + # pytorch_mean = float(pytorch_loss) + # pytorch_std = 0.0 + # print(f"PyTorch loss scalar: {pytorch_mean:.8f}") + + # if hasattr(jax_loss, 'mean'): + # jax_mean = jax_loss.mean().item() + # jax_std = jax_loss.std().item() + # print(f"JAX loss tensor - Mean: {jax_mean:.8f}, Std: {jax_std:.8f}") + # print(f"JAX loss shape: {jax_loss.shape}") + # else: + # jax_mean = float(jax_loss) + # jax_std = 0.0 + # print(f"JAX loss scalar: {jax_mean:.8f}") + + + + # Additional tensor analysis if both are tensors + pytorch_loss = pytorch_loss.to(torch.float32) + jax_loss = jax_loss.astype(jnp.float32) + if hasattr(pytorch_loss, 'shape') and hasattr(jax_loss, 'shape'): + print(f"\nšŸ“ Tensor Analysis:") + + # Check if shapes match + if pytorch_loss.shape == jax_loss.shape: + print(f"āœ… Tensor shapes match: {pytorch_loss.shape}") + + # Element-wise comparison + if hasattr(pytorch_loss, 'flatten') and hasattr(jax_loss, 'flatten'): + # Convert to numpy for element-wise analysis + try: + pytorch_flat = pytorch_loss.detach().cpu().numpy().flatten() + jax_flat = jax_loss.flatten() + + # Element-wise differences + element_diff = np.abs(pytorch_flat - jax_flat) + print(f"element_diff[0]: {element_diff[0:2048*816:2048]}") + max_element_diff = np.max(element_diff) + mean_element_diff = np.mean(element_diff) + + print(f" Max element-wise difference: {max_element_diff:.8f}") + print(f" Mean element-wise difference: {mean_element_diff:.8f}") + + # Element-wise relative differences + # Avoid division by zero by adding small epsilon + epsilon = 1e-12 + 
pytorch_flat_safe = pytorch_flat + epsilon + jax_flat_safe = jax_flat + epsilon + + # Compute relative differences for each element + rel_diff_pytorch_elements = (element_diff / np.abs(pytorch_flat_safe)) * 100 + rel_diff_jax_elements = (element_diff / np.abs(jax_flat_safe)) * 100 + + # Compute mean of relative differences + mean_rel_diff_pytorch = np.mean(rel_diff_pytorch_elements) + mean_rel_diff_jax = np.mean(rel_diff_jax_elements) + + print(f" Mean relative difference (w.r.t. PyTorch elements): {mean_rel_diff_pytorch:.4f}%") + print(f" Mean relative difference (w.r.t. JAX elements): {mean_rel_diff_jax:.4f}%") + + # Count elements with significant differences + significant_threshold = 1e-4 + significant_count = np.sum(element_diff > significant_threshold) + total_elements = len(element_diff) + significant_percentage = (significant_count / total_elements) * 100 + + print(f" Elements with diff > {significant_threshold}: {significant_count}/{total_elements} ({significant_percentage:.2f}%)") + + # Additional relative difference analysis + significant_rel_threshold = 1.0 # 1% + significant_rel_count_pytorch = np.sum(rel_diff_pytorch_elements > significant_rel_threshold) + significant_rel_count_jax = np.sum(rel_diff_jax_elements > significant_rel_threshold) + + print(f" Elements with rel diff > {significant_rel_threshold}% (w.r.t. PyTorch): {significant_rel_count_pytorch}/{total_elements} ({(significant_rel_count_pytorch/total_elements)*100:.2f}%)") + print(f" Elements with rel diff > {significant_rel_threshold}% (w.r.t. 
JAX): {significant_rel_count_jax}/{total_elements} ({(significant_rel_count_jax/total_elements)*100:.2f}%)") + + except Exception as e: + print(f" āš ļø Could not perform element-wise analysis: {e}") + else: + print(f"āŒ Tensor shapes don't match: PyTorch {pytorch_loss.shape} vs JAX {jax_loss.shape}") + + +def main(): + """Main function to test both implementations.""" + setup_logging() + + print("šŸš€ Testing Single Example Training for JAX vs PyTorch Comparison") + print("=" * 70) + print("šŸ“ Loading pre-trained weights for both models...") + print("šŸŽÆ Using fixed noise and time values for deterministic comparison...") + print("šŸ”§ Debug mode: JAX model will use only 1 encoder layer for faster debugging...") + + # Generate fixed noise and time + noise, time = create_fixed_noise_and_time( + batch_size=1, + action_horizon=10, + action_dim=32 + ) + + # Test PyTorch + pytorch_success, pytorch_losses = test_pytorch_single_example(noise, time) + torch.cuda.empty_cache() + + # Test JAX + jax_success, jax_losses = test_jax_single_example(noise, time, debug_single_layer=False) + + # Compare losses + if pytorch_success and jax_success: + compare_losses(pytorch_losses, jax_losses) + + # Summary + print("\n" + "=" * 70) + print("šŸ“Š SUMMARY") + print("=" * 70) + + if pytorch_success and jax_success: + print("āœ… Both JAX and PyTorch implementations work on the single example!") + print("šŸ” Loss comparison completed above.") + elif pytorch_success: + print("āŒ PyTorch works but JAX failed. Check JAX implementation.") + elif jax_success: + print("āŒ JAX works but PyTorch failed. Check PyTorch implementation.") + else: + print("āŒ Both implementations failed. Check the error messages above.") + + print("\nšŸ’” Next steps:") + print("1. Run this script to verify both implementations work") + print("2. Analyze the loss comparison results above") + print("3. If losses differ significantly, investigate the differences") + print("4. 
Check if the noise and time handling is consistent between implementations") + print("5. Use the same example in full training runs") + + +if __name__ == "__main__": + main() diff --git a/src/openpi/models/pi0.py b/src/openpi/models/pi0.py index ae7c459..c84f5c1 100644 --- a/src/openpi/models/pi0.py +++ b/src/openpi/models/pi0.py @@ -187,14 +187,24 @@ def embed_suffix( @override def compute_loss( - self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False + self, + rng: at.KeyArrayLike, + observation: _model.Observation, + actions: _model.Actions, + *, + train: bool = False, + noise: at.Float[at.Array, "*b ah ad"] | None = None, + time: at.Float[at.Array, "*b"] | None = None ) -> at.Float[at.Array, "*b ah"]: preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3) - observation = _model.preprocess_observation(preprocess_rng, observation, train=train) + #observation = _model.preprocess_observation(preprocess_rng, observation, train=train) batch_shape = actions.shape[:-2] - noise = jax.random.normal(noise_rng, actions.shape) - time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001 + # Use provided noise and time if available, otherwise generate them + if noise is None: + noise = jax.random.normal(noise_rng, actions.shape) + if time is None: + time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001 time_expanded = time[..., None, None] x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions @@ -211,7 +221,7 @@ def compute_loss( ) v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :]) - return jnp.mean(jnp.square(v_t - u_t), axis=-1) + return jnp.square(v_t - u_t) @override def sample_actions( diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 6c6de37..9272a34 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -2,10 +2,38 @@ import torch 
from torch import nn from transformers import GemmaForCausalLM, PaliGemmaForConditionalGeneration +from transformers.models.gemma import modeling_gemma from transformers.models.auto import CONFIG_MAPPING +# TODO: compare this rope vs gemma rope +def apply_rope(x, positions, max_wavelength=10_000): + """ + Applies RoPE positions [B, L] to x [B, L, H, D]. + """ + d_half = x.shape[-1] // 2 + device = x.device + dtype = x.dtype + x = x.to(torch.float32) + + freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) + timescale = max_wavelength**freq_exponents + radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) + + radians = radians[..., None, :] + + sin = torch.sin(radians) # .to(dtype=dtype) + cos = torch.cos(radians) # .to(dtype=dtype) + + x1, x2 = x.split(d_half, dim=-1) + res = torch.empty_like(x) + res[..., :d_half] = x1 * cos - x2 * sin + res[..., d_half:] = x2 * cos + x1 * sin + + return res.to(dtype) + + class PaliGemmaWithExpertModel(nn.Module): def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): super().__init__() @@ -80,7 +108,7 @@ def forward( use_cache: bool | None = None, adarms_cond: list[torch.Tensor] | None = None, ): - if inputs_embeds[0] is not None: + if inputs_embeds[1] is None: prefix_output = self.paligemma.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, @@ -91,11 +119,8 @@ def forward( ) prefix_past_key_values = prefix_output.past_key_values prefix_output = prefix_output.last_hidden_state - else: - prefix_output = None - prefix_past_key_values = None - - if inputs_embeds[1] is not None: + suffix_output = None + if inputs_embeds[0] is None: suffix_output = self.gemma_expert.model.forward( inputs_embeds=inputs_embeds[1], attention_mask=attention_mask, @@ -105,7 +130,92 @@ def forward( adarms_cond=adarms_cond[1] if adarms_cond is not None else None, ) suffix_output = suffix_output.last_hidden_state + 
prefix_output = None + prefix_past_key_values = None else: - suffix_output = None + models = [self.paligemma.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers + for layer_idx in range(num_layers): + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) + # hidden_states = hidden_states.to(dtype=torch.bfloat16) + # if gate is not None: + # gate = gate.to(dtype=torch.bfloat16) + gates.append(gate) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # B,L,H,D with L sequence length, H number of heads, D head dim + # concatenate on the number of embeddings/tokens + query_states = torch.cat(query_states, dim=1) + key_states = torch.cat(key_states, dim=1) + value_states = torch.cat(value_states, dim=1) + + query_states = apply_rope(query_states, position_ids) + key_states = apply_rope(key_states, position_ids) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # dummy_tensor = torch.zeros(query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype) + # cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + # query_states, key_states = modeling_gemma.apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) + + batch_size = query_states.shape[0] + scaling = 
self.paligemma.language_model.layers[layer_idx].self_attn.scaling + att_output, _ = modeling_gemma.eager_attention_forward( + self.paligemma.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, attention_mask, scaling + ) + #att_output = att_output.to(dtype=torch.bfloat16) + att_output = att_output.reshape(batch_size, -1, 1 * 8 * layer.self_attn.head_dim) + + + # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) + outputs_embeds = [] + start = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end = start + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start:end]) + + # first residual + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) + after_first_residual = out_emb.clone() + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + out_emb = out_emb.to(dtype=torch.bfloat16) + + out_emb = layer.mlp(out_emb) + # second residual + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) + outputs_embeds.append(out_emb) + start = end + inputs_embeds = outputs_embeds + + # final norm + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None return [prefix_output, suffix_output], prefix_past_key_values \ No newline at end of file diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 9a6c0ea..89bc0dd 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -230,8 +230,15 @@ def embed_suffix(self, state, noisy_actions, timestep): return 
embs, pad_masks, att_masks, adarms_cond - def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None) -> Tensor: + def forward(self, observation, actions, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + # observation = _model.preprocess_observation_pytorch(observation, train=True) + images = list(observation.images.values()) + img_masks = list(observation.image_masks.values()) + lang_tokens = observation.tokenized_prompt + lang_masks = observation.tokenized_prompt_mask + state = observation.state + if noise is None: noise = self.sample_noise(actions.shape, actions.device) @@ -244,6 +251,7 @@ def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, no prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) @@ -253,6 +261,7 @@ def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, no # Add head dimension to attention mask: [B, seq_len, seq_len] -> [B, 1, seq_len, seq_len] att_2d_masks_4d = att_2d_masks[:, None, :, :] + att_2d_masks_4d = torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) (_, suffix_out), _ = self.paligemma_with_expert.forward( attention_mask=att_2d_masks_4d, @@ -260,7 +269,6 @@ def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, no past_key_values=None, inputs_embeds=[prefix_embs, suffix_embs], use_cache=False, - fill_kv_cache=False, adarms_cond=[None, adarms_cond] ) suffix_out = suffix_out[:, -self.config.action_horizon :] @@ -268,12 +276,15 @@ def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, no 
v_t = self.action_out_proj(suffix_out) - losses = F.mse_loss(u_t, v_t, reduction="none") + #losses = F.mse_loss(u_t, v_t, reduction="none") + + losses = torch.square(v_t - u_t) return losses @torch.no_grad() def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor: """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + # observation = _model.preprocess_observation(observation, train=False) bsize = observation.state.shape[0] if noise is None: actions_shape = (bsize, self.config.action_horizon, self.config.action_dim) diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index a4caa3c..723d212 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -724,7 +724,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=256, + batch_size=1, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, @@ -734,10 +734,30 @@ def __post_init__(self) -> None: optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, weight_loader=weight_loaders.CheckpointWeightLoader( - "gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/params" + "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" ), num_train_steps=30_000, ), + TrainConfig( + name="pi05_libero_pytorch", + model=pi0_config.Pi0Config(pi05=True, action_horizon=10, discrete_state_input=False), + data=LeRobotLiberoDataConfig( + repo_id="physical-intelligence/libero", + base_config=DataConfig(prompt_from_task=True), + extra_delta_transform=False, + ), + batch_size=1, + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=10_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), + ema_decay=0.999, + weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2", + 
num_train_steps=30_000, + ), # # Fine-tuning Aloha configs. # diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index 1f97f27..8355336 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -228,6 +228,7 @@ def create_data_loader( ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: """Create a data loader for training.""" data_config = config.data.create(config.assets_dirs, config.model) + print(f"data_config: {data_config}") if data_config.rlds_data_dir is not None: return create_rlds_data_loader( diff --git a/uv.lock b/uv.lock index 1ad8e94..13bd8fc 100644 --- a/uv.lock +++ b/uv.lock @@ -1771,24 +1771,33 @@ wheels = [ [package.optional-dependencies] with-cuda = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cublas-cu12", version = "12.9.0.13", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = 
"nvidia-cuda-cupti-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-cupti-cu12", version = "12.9.19", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "nvidia-cuda-nvcc-cu12" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-runtime-cu12", version = "12.9.37", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, { name = "nvidia-cudnn-cu12", version = "9.10.1.4", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and 
sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cudnn-cu12", version = "9.10.2.21", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cufft-cu12", version = "11.3.3.83", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cufft-cu12", version = "11.4.0.6", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cusolver-cu12", version = "11.7.3.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cusolver-cu12", version = "11.7.4.40", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 
'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", version = "12.5.9.5", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-nccl-cu12", version = "2.26.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nccl-cu12", version = "2.26.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, { name = "nvidia-nccl-cu12", version = "2.26.5", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nccl-cu12", version = "2.27.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", version = "12.9.41", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 
'darwin'" }, ] @@ -2608,22 +2617,31 @@ name = "nvidia-cublas-cu12" version = "12.6.4.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb", size = 393138322 }, - { url = "https://files.pythonhosted.org/packages/97/0d/f1f0cadbf69d5b9ef2e4f744c9466cb0a850741d08350736dfdb4aa89569/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668", size = 390794615 }, { url = "https://files.pythonhosted.org/packages/84/f7/985e9bdbe3e0ac9298fcc8cfa51a392862a46a0ffaccbbd56939b62a9c83/nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8", size = 434535301 }, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = 
"https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124 }, + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921 }, +] + [[package]] name = "nvidia-cublas-cu12" version = "12.9.0.13" @@ -2646,24 +2664,31 @@ name = "nvidia-cuda-cupti-cu12" version = "12.6.80" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 
'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/8b/2f6230cb715646c3a9425636e513227ce5c93c4d65823a734f4bb86d43c3/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:166ee35a3ff1587f2490364f90eeeb8da06cd867bd5b701bf7f9a02b78bc63fc", size = 8236764 }, - { url = "https://files.pythonhosted.org/packages/25/0f/acb326ac8fd26e13c799e0b4f3b2751543e1834f04d62e729485872198d4/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.whl", hash = "sha256:358b4a1d35370353d52e12f0a7d1769fc01ff74a191689d3870b2123156184c4", size = 8236756 }, - { url = "https://files.pythonhosted.org/packages/49/60/7b6497946d74bcf1de852a21824d63baad12cd417db4195fc1bfe59db953/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6768bad6cab4f19e8292125e5f1ac8aa7d1718704012a0e3272a6f61c4bce132", size = 8917980 }, - { url = "https://files.pythonhosted.org/packages/a5/24/120ee57b218d9952c379d1e026c4479c9ece9997a4fb46303611ee48f038/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73", size = 8917972 }, { url = "https://files.pythonhosted.org/packages/1c/81/7796f096afaf726796b1b648f3bc80cafc61fe7f77f44a483c89e6c5ef34/nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a", size = 5724175 }, ] +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318 }, + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621 }, +] + [[package]] name = "nvidia-cuda-cupti-cu12" version = "12.9.19" @@ -2693,10 +2718,10 @@ wheels = [ [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.6.77" +version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/75/2e/46030320b5a80661e88039f59060d1790298b4718944a65a7f2aeda3d9e9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53", size = 23650380 }, + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029 }, ] [[package]] @@ -2704,24 +2729,31 @@ name = "nvidia-cuda-runtime-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 
'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/ea/590b2ac00d772a8abd1c387a92b46486d2679ca6622fd25c18ff76265663/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6116fad3e049e04791c0256a9778c16237837c08b27ed8c8401e2e45de8d60cd", size = 908052 }, - { url = "https://files.pythonhosted.org/packages/b7/3d/159023799677126e20c8fd580cca09eeb28d5c5a624adc7f793b9aa8bbfa/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d461264ecb429c84c8879a7153499ddc7b19b5f8d84c204307491989a365588e", size = 908040 }, - { url = "https://files.pythonhosted.org/packages/e1/23/e717c5ac26d26cf39a27fbc076240fad2e3b817e5889d671b67f4f9f49c5/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ba3b56a4f896141e25e19ab287cd71e52a6a0f4b29d0d31609f60e3b4d5219b7", size = 897690 }, - { url = "https://files.pythonhosted.org/packages/f0/62/65c05e161eeddbafeca24dc461f47de550d9fa8a7e04eb213e32b55cfd99/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8", size = 897678 }, { url = "https://files.pythonhosted.org/packages/fa/76/4c80fa138333cc975743fd0687a745fccb30d167f906f13c1c7f9a85e5ea/nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f", size = 891773 }, ] +[[package]] +name = 
"nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265 }, + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765 }, +] + [[package]] name = "nvidia-cuda-runtime-cu12" version = "12.9.37" @@ -2744,22 +2776,17 @@ name = "nvidia-cudnn-cu12" version = "9.5.1.17" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' 
and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/99/93/a201a12d3ec1caa8c6ac34c1c2f9eeb696b886f0c36ff23c638b46603bd0/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9fd4584468533c61873e5fda8ca41bac3a38bcb2d12350830c69b0a96a7e4def", size = 570523509 }, - { url = "https://files.pythonhosted.org/packages/2a/78/4535c9c7f859a64781e43c969a3a7e84c54634e319a996d43ef32ce46f83/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2", size = 570988386 }, { url = "https://files.pythonhosted.org/packages/b6/b2/3f60d15f037fa5419d9d7f788b100ef33ea913ae5315c87ca6d6fa606c35/nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8", size = 565440743 }, ] @@ -2783,32 +2810,59 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/ec/79464a7371a028d1f443b8516b55cb2f70bb91bd3b2f2a831d707c003ccf/nvidia_cudnn_cu12-9.10.1.4-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:df73c4dab84df2c54f0a40e6427cde26e8d80feeffef02d749ee42d7da3c8204", size = 706752133 }, ] +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == 
'3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878 }, + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467 }, +] + [[package]] name = "nvidia-cufft-cu12" version = "11.3.0.4" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = 
"nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/1f/37/c50d2b2f2c07e146776389e3080f4faf70bcc4fa6e19d65bb54ca174ebc3/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d16079550df460376455cba121db6564089176d9bac9e4f360493ca4741b22a6", size = 200164144 }, - { url = "https://files.pythonhosted.org/packages/ce/f5/188566814b7339e893f8d210d3a5332352b1409815908dad6a363dcceac1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8510990de9f96c803a051822618d42bf6cb8f069ff3f48d93a8486efdacb48fb", size = 200164135 }, - { url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632 }, - { url = "https://files.pythonhosted.org/packages/60/de/99ec247a07ea40c969d904fc14f3a356b3e2a704121675b75c366b694ee1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca", size = 200221622 }, { url = "https://files.pythonhosted.org/packages/b4/38/36fd800cec8f6e89b7c1576edaaf8076e69ec631644cdbc1b5f2e2b5a9df/nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464", size = 199356881 }, ] +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } 
+resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211 }, + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695 }, +] + [[package]] name = "nvidia-cufft-cu12" version = "11.4.0.6" @@ -2831,19 +2885,18 @@ wheels = [ [[package]] name = "nvidia-cufile-cu12" -version = "1.11.1.6" +version = "1.13.1.3" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/66/cc9876340ac68ae71b15c743ddb13f8b30d5244af344ec8322b449e35426/nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159", size = 1142103 }, + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834 }, 
] [[package]] name = "nvidia-curand-cu12" -version = "10.3.7.77" +version = "10.3.9.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/73/1b/44a01c4e70933637c93e6e1a8063d1e998b50213a6b65ac5a9169c47e98e/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf", size = 56279010 }, - { url = "https://files.pythonhosted.org/packages/4a/aa/2c7ff0b5ee02eaef890c0ce7d4f74bc30901871c5e45dee1ae6d0083cd80/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:99f1a32f1ac2bd134897fc7a203f779303261268a65762a623bf30cc9fe79117", size = 56279000 }, + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976 }, ] [[package]] @@ -2851,29 +2904,41 @@ name = "nvidia-cusolver-cu12" version = "11.7.1.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", 
"python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/93/17/dbe1aa865e4fdc7b6d4d0dd308fdd5aaab60f939abfc0ea1954eac4fb113/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0ce237ef60acde1efc457335a2ddadfd7610b892d94efee7b776c64bb1cac9e0", size = 157833628 }, - { url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790 }, - { url = 
"https://files.pythonhosted.org/packages/9f/81/baba53585da791d043c10084cf9553e074548408e04ae884cfe9193bd484/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6cf28f17f64107a0c4d7802be5ff5537b2130bfc112f25d5a30df227058ca0e6", size = 158229780 }, - { url = "https://files.pythonhosted.org/packages/7c/5f/07d0ba3b7f19be5a5ec32a8679fc9384cfd9fc6c869825e93be9f28d6690/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e", size = 157833630 }, { url = "https://files.pythonhosted.org/packages/d4/53/fff50a0808df7113d77e3bbc7c2b7eaed6f57d5eb80fbe93ead2aea1e09a/nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7", size = 149287877 }, ] +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841 }, + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905 }, +] + [[package]] name = "nvidia-cusolver-cu12" version = "11.7.4.40" @@ -2901,27 +2966,37 @@ name = "nvidia-cusparse-cu12" version = "12.5.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = 
"https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/eb/6681efd0aa7df96b4f8067b3ce7246833dd36830bb4cec8896182773db7d/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d25b62fb18751758fe3c93a4a08eff08effedfe4edf1c6bb5afd0890fe88f887", size = 216451147 }, - { url = "https://files.pythonhosted.org/packages/d3/56/3af21e43014eb40134dea004e8d0f1ef19d9596a39e4d497d5a7de01669f/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7aa32fa5470cf754f72d1116c7cbc300b4e638d3ae5304cfa4a638a5b87161b1", size = 216451135 }, - { url = "https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367 }, - { url = "https://files.pythonhosted.org/packages/43/ac/64c4316ba163e8217a99680c7605f779accffc6a4bcd0c778c12948d3707/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f", size = 216561357 }, { url = "https://files.pythonhosted.org/packages/45/ef/876ad8e4260e1128e6d4aac803d9d51baf3791ebdb4a9b8d9b8db032b4b0/nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20", size = 213712630 }, ] +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 
'linux'", +] +dependencies = [ + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129 }, + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466 }, +] + [[package]] name = "nvidia-cusparse-cu12" version = "12.5.9.5" @@ -2944,10 +3019,10 @@ wheels = [ [[package]] name = "nvidia-cusparselt-cu12" -version = "0.6.3" +version = "0.7.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796 }, + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691 }, ] [[package]] @@ -2964,20 +3039,13 @@ name = "nvidia-nccl-cu12" version = "2.26.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 
'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] -wheels = [ - { url = "https://files.pythonhosted.org/packages/69/5b/ca2f213f637305633814ae8c36b153220e40a07ea001966dcd87391f3acb/nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c196e95e832ad30fbbb50381eb3cbd1fadd5675e587a548563993609af19522", size = 291671495 }, - { url = "https://files.pythonhosted.org/packages/67/ca/f42388aed0fddd64ade7493dbba36e1f534d4e6fdbdd355c6a90030ae028/nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6", size = 201319755 }, -] [[package]] name = "nvidia-nccl-cu12" @@ -2996,27 +3064,50 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/fb/ec4ac065d9b0d56f72eaf1d9b0df601e33da28197b32ca351dc05b342611/nvidia_nccl_cu12-2.26.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea5ed3e053c735f16809bee7111deac62ac35b10128a8c102960a0462ce16cbe", size = 318069637 }, ] +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform 
== 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/7b/8354b784cf73b0ba51e566b4baba3ddd44fe8288a3d39ef1e06cd5417226/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9ddf1a245abc36c550870f26d537a9b6087fb2e2e3d6e0ef03374c6fd19d984f", size = 322397768 }, + { url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134 }, +] + [[package]] name = "nvidia-nvjitlink-cu12" version = "12.6.85" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = 
"sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971 }, - { url = "https://files.pythonhosted.org/packages/31/db/dc71113d441f208cdfe7ae10d4983884e13f464a6252450693365e166dcf/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41", size = 19270338 }, { url = "https://files.pythonhosted.org/packages/89/76/93c1467b1387387440a4d25102d86b7794535449b689f8e2dc22c1c8ff7f/nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c", size = 161908572 }, ] +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836 }, + { url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204 }, +] + [[package]] name = "nvidia-nvjitlink-cu12" version = "12.9.41" @@ -3036,11 +3127,10 @@ wheels = [ [[package]] name = "nvidia-nvtx-cu12" -version = "12.6.77" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b90bed3df379fa79afbd21be8e04a0314336b8ae16768b58f2d34cb1d04cd7d2", size = 89276 }, - { url = "https://files.pythonhosted.org/packages/9e/4e/0d0c945463719429b7bd21dece907ad0bde437a2ff12b9b12fee94722ab0/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1", size = 89265 }, + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954 }, ] [[package]] @@ -3180,7 +3270,7 @@ requires-dist = [ { name = "polars", specifier = ">=1.30.0" }, { name = "rich", specifier = ">=14.0.0" }, { name = "sentencepiece", specifier = ">=0.2.0" }, - { name = "torch", specifier = ">=2.7.0" }, + { name = "torch", specifier = ">=2.7.1" }, { name = "tqdm-loggable", specifier = ">=0.2" }, { name = "transformers", specifier = "==4.53.2" }, { name = "treescope", specifier = ">=0.1.7" }, @@ -4852,26 +4942,26 @@ wheels = [ [[package]] name = "torch" -version = "2.7.0" +version = "2.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = 
"https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", version = "9.10.2.21", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", version = "11.3.3.83", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and 
sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", version = "11.7.3.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", version = "2.26.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", version = "2.27.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, @@ -4879,22 +4969,22 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/40/da/7378d16cc636697f2a94f791cb496939b60fb8580ddbbef22367db2c2274/torch-2.7.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2b7813e904757b125faf1a9a3154e1d50381d539ced34da1992f52440567c156", size = 99159397 }, - { url = "https://files.pythonhosted.org/packages/0e/6b/87fcddd34df9f53880fa1f0c23af7b6b96c935856473faf3914323588c40/torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:fd5cfbb4c3bbadd57ad1b27d56a28008f8d8753733411a140fcfb84d7f933a25", size 
= 865183681 }, - { url = "https://files.pythonhosted.org/packages/13/85/6c1092d4b06c3db1ed23d4106488750917156af0b24ab0a2d9951830b0e9/torch-2.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:58df8d5c2eeb81305760282b5069ea4442791a6bbf0c74d9069b7b3304ff8a37", size = 212520100 }, - { url = "https://files.pythonhosted.org/packages/aa/3f/85b56f7e2abcfa558c5fbf7b11eb02d78a4a63e6aeee2bbae3bb552abea5/torch-2.7.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:0a8d43caa342b9986101ec5feb5bbf1d86570b5caa01e9cb426378311258fdde", size = 68569377 }, - { url = "https://files.pythonhosted.org/packages/aa/5e/ac759f4c0ab7c01feffa777bd68b43d2ac61560a9770eeac074b450f81d4/torch-2.7.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:36a6368c7ace41ad1c0f69f18056020b6a5ca47bedaca9a2f3b578f5a104c26c", size = 99013250 }, - { url = "https://files.pythonhosted.org/packages/9c/58/2d245b6f1ef61cf11dfc4aceeaacbb40fea706ccebac3f863890c720ab73/torch-2.7.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:15aab3e31c16feb12ae0a88dba3434a458874636f360c567caa6a91f6bfba481", size = 865042157 }, - { url = "https://files.pythonhosted.org/packages/44/80/b353c024e6b624cd9ce1d66dcb9d24e0294680f95b369f19280e241a0159/torch-2.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f56d4b2510934e072bab3ab8987e00e60e1262fb238176168f5e0c43a1320c6d", size = 212482262 }, - { url = "https://files.pythonhosted.org/packages/ee/8d/b2939e5254be932db1a34b2bd099070c509e8887e0c5a90c498a917e4032/torch-2.7.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:30b7688a87239a7de83f269333651d8e582afffce6f591fff08c046f7787296e", size = 68574294 }, - { url = "https://files.pythonhosted.org/packages/14/24/720ea9a66c29151b315ea6ba6f404650834af57a26b2a04af23ec246b2d5/torch-2.7.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:868ccdc11798535b5727509480cd1d86d74220cfdc42842c4617338c1109a205", size = 99015553 }, - { url = 
"https://files.pythonhosted.org/packages/4b/27/285a8cf12bd7cd71f9f211a968516b07dcffed3ef0be585c6e823675ab91/torch-2.7.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b52347118116cf3dff2ab5a3c3dd97c719eb924ac658ca2a7335652076df708", size = 865046389 }, - { url = "https://files.pythonhosted.org/packages/74/c8/2ab2b6eadc45554af8768ae99668c5a8a8552e2012c7238ded7e9e4395e1/torch-2.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:434cf3b378340efc87c758f250e884f34460624c0523fe5c9b518d205c91dd1b", size = 212490304 }, - { url = "https://files.pythonhosted.org/packages/28/fd/74ba6fde80e2b9eef4237fe668ffae302c76f0e4221759949a632ca13afa/torch-2.7.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:edad98dddd82220465b106506bb91ee5ce32bd075cddbcf2b443dfaa2cbd83bf", size = 68856166 }, - { url = "https://files.pythonhosted.org/packages/cb/b4/8df3f9fe6bdf59e56a0e538592c308d18638eb5f5dc4b08d02abb173c9f0/torch-2.7.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a885fc25afefb6e6eb18a7d1e8bfa01cc153e92271d980a49243b250d5ab6d9", size = 99091348 }, - { url = "https://files.pythonhosted.org/packages/9d/f5/0bd30e9da04c3036614aa1b935a9f7e505a9e4f1f731b15e165faf8a4c74/torch-2.7.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:176300ff5bc11a5f5b0784e40bde9e10a35c4ae9609beed96b4aeb46a27f5fae", size = 865104023 }, - { url = "https://files.pythonhosted.org/packages/d1/b7/2235d0c3012c596df1c8d39a3f4afc1ee1b6e318d469eda4c8bb68566448/torch-2.7.0-cp313-cp313t-win_amd64.whl", hash = "sha256:d0ca446a93f474985d81dc866fcc8dccefb9460a29a456f79d99c29a78a66993", size = 212750916 }, - { url = "https://files.pythonhosted.org/packages/90/48/7e6477cf40d48cc0a61fa0d41ee9582b9a316b12772fcac17bc1a40178e7/torch-2.7.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:27f5007bdf45f7bb7af7f11d1828d5c2487e030690afb3d89a651fd7036a390e", size = 68575074 }, + { url = 
"https://files.pythonhosted.org/packages/8f/c4/3e7a3887eba14e815e614db70b3b529112d1513d9dae6f4d43e373360b7f/torch-2.8.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:220a06fd7af8b653c35d359dfe1aaf32f65aa85befa342629f716acb134b9710", size = 102073391 }, + { url = "https://files.pythonhosted.org/packages/5a/63/4fdc45a0304536e75a5e1b1bbfb1b56dd0e2743c48ee83ca729f7ce44162/torch-2.8.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c12fa219f51a933d5f80eeb3a7a5d0cbe9168c0a14bbb4055f1979431660879b", size = 888063640 }, + { url = "https://files.pythonhosted.org/packages/84/57/2f64161769610cf6b1c5ed782bd8a780e18a3c9d48931319f2887fa9d0b1/torch-2.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:8c7ef765e27551b2fbfc0f41bcf270e1292d9bf79f8e0724848b1682be6e80aa", size = 241366752 }, + { url = "https://files.pythonhosted.org/packages/a4/5e/05a5c46085d9b97e928f3f037081d3d2b87fb4b4195030fc099aaec5effc/torch-2.8.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:5ae0524688fb6707c57a530c2325e13bb0090b745ba7b4a2cd6a3ce262572916", size = 73621174 }, + { url = "https://files.pythonhosted.org/packages/49/0c/2fd4df0d83a495bb5e54dca4474c4ec5f9c62db185421563deeb5dabf609/torch-2.8.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e2fab4153768d433f8ed9279c8133a114a034a61e77a3a104dcdf54388838705", size = 101906089 }, + { url = "https://files.pythonhosted.org/packages/99/a8/6acf48d48838fb8fe480597d98a0668c2beb02ee4755cc136de92a0a956f/torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2aca0939fb7e4d842561febbd4ffda67a8e958ff725c1c27e244e85e982173c", size = 887913624 }, + { url = "https://files.pythonhosted.org/packages/af/8a/5c87f08e3abd825c7dfecef5a0f1d9aa5df5dd0e3fd1fa2f490a8e512402/torch-2.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:2f4ac52f0130275d7517b03a33d2493bab3693c83dcfadf4f81688ea82147d2e", size = 241326087 }, + { url = 
"https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478 }, + { url = "https://files.pythonhosted.org/packages/10/4e/469ced5a0603245d6a19a556e9053300033f9c5baccf43a3d25ba73e189e/torch-2.8.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b2f96814e0345f5a5aed9bf9734efa913678ed19caf6dc2cddb7930672d6128", size = 101936856 }, + { url = "https://files.pythonhosted.org/packages/16/82/3948e54c01b2109238357c6f86242e6ecbf0c63a1af46906772902f82057/torch-2.8.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:65616ca8ec6f43245e1f5f296603e33923f4c30f93d65e103d9e50c25b35150b", size = 887922844 }, + { url = "https://files.pythonhosted.org/packages/e3/54/941ea0a860f2717d86a811adf0c2cd01b3983bdd460d0803053c4e0b8649/torch-2.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:659df54119ae03e83a800addc125856effda88b016dfc54d9f65215c3975be16", size = 241330968 }, + { url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128 }, + { url = "https://files.pythonhosted.org/packages/15/0e/8a800e093b7f7430dbaefa80075aee9158ec22e4c4fc3c1a66e4fb96cb4f/torch-2.8.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:83c13411a26fac3d101fe8035a6b0476ae606deb8688e904e796a3534c197def", size = 102020139 }, + { url = "https://files.pythonhosted.org/packages/4a/15/5e488ca0bc6162c86a33b58642bc577c84ded17c7b72d97e49b5833e2d73/torch-2.8.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:8f0a9d617a66509ded240add3754e462430a6c1fc5589f86c17b433dd808f97a", size = 887990692 }, + { url = 
"https://files.pythonhosted.org/packages/b4/a8/6a04e4b54472fc5dba7ca2341ab219e529f3c07b6941059fbf18dccac31f/torch-2.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a7242b86f42be98ac674b88a4988643b9bc6145437ec8f048fea23f72feb5eca", size = 241603453 }, + { url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395 }, ] [[package]] @@ -4914,7 +5004,7 @@ wheels = [ [[package]] name = "torchvision" -version = "0.22.0" +version = "0.23.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, @@ -4922,22 +5012,22 @@ dependencies = [ { name = "torch" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/43/28bc858b022f6337326d75f4027d2073aad5432328f01ee1236d847f1b82/torchvision-0.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:191ea28321fc262d8aa1a7fe79c41ff2848864bf382f9f6ea45c41dde8313792", size = 1947828 }, - { url = "https://files.pythonhosted.org/packages/7e/71/ce9a303b94e64fe25d534593522ffc76848c4e64c11e4cbe9f6b8d537210/torchvision-0.22.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6c5620e10ffe388eb6f4744962106ed7cf1508d26e6fdfa0c10522d3249aea24", size = 2514016 }, - { url = "https://files.pythonhosted.org/packages/09/42/6908bff012a1dcc4fc515e52339652d7f488e208986542765c02ea775c2f/torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ce292701c77c64dd3935e3e31c722c3b8b176a75f76dc09b804342efc1db5494", size = 7447546 }, - { url = "https://files.pythonhosted.org/packages/e4/cf/8f9305cc0ea26badbbb3558ecae54c04a245429f03168f7fad502f8a5b25/torchvision-0.22.0-cp311-cp311-win_amd64.whl", hash = "sha256:e4017b5685dbab4250df58084f07d95e677b2f3ed6c2e507a1afb8eb23b580ca", size = 1716472 }, - { url = 
"https://files.pythonhosted.org/packages/cb/ea/887d1d61cf4431a46280972de665f350af1898ce5006cd046326e5d0a2f2/torchvision-0.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:31c3165418fe21c3d81fe3459e51077c2f948801b8933ed18169f54652796a0f", size = 1947826 }, - { url = "https://files.pythonhosted.org/packages/72/ef/21f8b6122e13ae045b8e49658029c695fd774cd21083b3fa5c3f9c5d3e35/torchvision-0.22.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8f116bc82e0c076e70ba7776e611ed392b9666aa443662e687808b08993d26af", size = 2514571 }, - { url = "https://files.pythonhosted.org/packages/7c/48/5f7617f6c60d135f86277c53f9d5682dfa4e66f4697f505f1530e8b69fb1/torchvision-0.22.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ce4dc334ebd508de2c534817c9388e928bc2500cf981906ae8d6e2ca3bf4727a", size = 7446522 }, - { url = "https://files.pythonhosted.org/packages/99/94/a015e93955f5d3a68689cc7c385a3cfcd2d62b84655d18b61f32fb04eb67/torchvision-0.22.0-cp312-cp312-win_amd64.whl", hash = "sha256:24b8c9255c209ca419cc7174906da2791c8b557b75c23496663ec7d73b55bebf", size = 1716664 }, - { url = "https://files.pythonhosted.org/packages/e1/2a/9b34685599dcb341d12fc2730055155623db7a619d2415a8d31f17050952/torchvision-0.22.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ece17995857dd328485c9c027c0b20ffc52db232e30c84ff6c95ab77201112c5", size = 1947823 }, - { url = "https://files.pythonhosted.org/packages/77/77/88f64879483d66daf84f1d1c4d5c31ebb08e640411139042a258d5f7dbfe/torchvision-0.22.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:471c6dd75bb984c6ebe4f60322894a290bf3d4b195e769d80754f3689cd7f238", size = 2471592 }, - { url = "https://files.pythonhosted.org/packages/f7/82/2f813eaae7c1fae1f9d9e7829578f5a91f39ef48d6c1c588a8900533dd3d/torchvision-0.22.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:2b839ac0610a38f56bef115ee5b9eaca5f9c2da3c3569a68cc62dbcc179c157f", size = 7446333 }, - { url = 
"https://files.pythonhosted.org/packages/58/19/ca7a4f8907a56351dfe6ae0a708f4e6b3569b5c61d282e3e7f61cf42a4ce/torchvision-0.22.0-cp313-cp313-win_amd64.whl", hash = "sha256:4ada1c08b2f761443cd65b7c7b4aec9e2fc28f75b0d4e1b1ebc9d3953ebccc4d", size = 1716693 }, - { url = "https://files.pythonhosted.org/packages/6f/a7/f43e9c8d13118b4ffbaebea664c9338ab20fa115a908125afd2238ff16e7/torchvision-0.22.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:cdc96daa4658b47ce9384154c86ed1e70cba9d972a19f5de6e33f8f94a626790", size = 2137621 }, - { url = "https://files.pythonhosted.org/packages/6a/9a/2b59f5758ba7e3f23bc84e16947493bbce97392ec6d18efba7bdf0a3b10e/torchvision-0.22.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:753d3c84eeadd5979a33b3b73a25ecd0aa4af44d6b45ed2c70d44f5e0ac68312", size = 2476555 }, - { url = "https://files.pythonhosted.org/packages/7d/40/a7bc2ab9b1e56d10a7fd9ae83191bb425fa308caa23d148f1c568006e02c/torchvision-0.22.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:b30e3ed29e4a61f7499bca50f57d8ebd23dfc52b14608efa17a534a55ee59a03", size = 7617924 }, - { url = "https://files.pythonhosted.org/packages/c1/7b/30d423bdb2546250d719d7821aaf9058cc093d165565b245b159c788a9dd/torchvision-0.22.0-cp313-cp313t-win_amd64.whl", hash = "sha256:e5d680162694fac4c8a374954e261ddfb4eb0ce103287b0f693e4e9c579ef957", size = 1638621 }, + { url = "https://files.pythonhosted.org/packages/f0/d7/15d3d7bd8d0239211b21673d1bac7bc345a4ad904a8e25bb3fd8a9cf1fbc/torchvision-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:49aa20e21f0c2bd458c71d7b449776cbd5f16693dd5807195a820612b8a229b7", size = 1856884 }, + { url = "https://files.pythonhosted.org/packages/dd/14/7b44fe766b7d11e064c539d92a172fa9689a53b69029e24f2f1f51e7dc56/torchvision-0.23.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:01dc33ee24c79148aee7cdbcf34ae8a3c9da1674a591e781577b716d233b1fa6", size = 2395543 }, + { url = 
"https://files.pythonhosted.org/packages/79/9c/fcb09aff941c8147d9e6aa6c8f67412a05622b0c750bcf796be4c85a58d4/torchvision-0.23.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:35c27941831b653f5101edfe62c03d196c13f32139310519e8228f35eae0e96a", size = 8628388 }, + { url = "https://files.pythonhosted.org/packages/93/40/3415d890eb357b25a8e0a215d32365a88ecc75a283f75c4e919024b22d97/torchvision-0.23.0-cp311-cp311-win_amd64.whl", hash = "sha256:09bfde260e7963a15b80c9e442faa9f021c7e7f877ac0a36ca6561b367185013", size = 1600741 }, + { url = "https://files.pythonhosted.org/packages/df/1d/0ea0b34bde92a86d42620f29baa6dcbb5c2fc85990316df5cb8f7abb8ea2/torchvision-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e0e2c04a91403e8dd3af9756c6a024a1d9c0ed9c0d592a8314ded8f4fe30d440", size = 1856885 }, + { url = "https://files.pythonhosted.org/packages/e2/00/2f6454decc0cd67158c7890364e446aad4b91797087a57a78e72e1a8f8bc/torchvision-0.23.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6dd7c4d329a0e03157803031bc856220c6155ef08c26d4f5bbac938acecf0948", size = 2396614 }, + { url = "https://files.pythonhosted.org/packages/e4/b5/3e580dcbc16f39a324f3dd71b90edbf02a42548ad44d2b4893cc92b1194b/torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4e7d31c43bc7cbecbb1a5652ac0106b436aa66e26437585fc2c4b2cf04d6014c", size = 8627108 }, + { url = "https://files.pythonhosted.org/packages/82/c1/c2fe6d61e110a8d0de2f94276899a2324a8f1e6aee559eb6b4629ab27466/torchvision-0.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:a2e45272abe7b8bf0d06c405e78521b5757be1bd0ed7e5cd78120f7fdd4cbf35", size = 1600723 }, + { url = "https://files.pythonhosted.org/packages/91/37/45a5b9407a7900f71d61b2b2f62db4b7c632debca397f205fdcacb502780/torchvision-0.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1c37e325e09a184b730c3ef51424f383ec5745378dc0eca244520aca29722600", size = 1856886 }, + { url = 
"https://files.pythonhosted.org/packages/ac/da/a06c60fc84fc849377cf035d3b3e9a1c896d52dbad493b963c0f1cdd74d0/torchvision-0.23.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2f7fd6c15f3697e80627b77934f77705f3bc0e98278b989b2655de01f6903e1d", size = 2353112 }, + { url = "https://files.pythonhosted.org/packages/a0/27/5ce65ba5c9d3b7d2ccdd79892ab86a2f87ac2ca6638f04bb0280321f1a9c/torchvision-0.23.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:a76fafe113b2977be3a21bf78f115438c1f88631d7a87203acb3dd6ae55889e6", size = 8627658 }, + { url = "https://files.pythonhosted.org/packages/1f/e4/028a27b60aa578a2fa99d9d7334ff1871bb17008693ea055a2fdee96da0d/torchvision-0.23.0-cp313-cp313-win_amd64.whl", hash = "sha256:07d069cb29691ff566e3b7f11f20d91044f079e1dbdc9d72e0655899a9b06938", size = 1600749 }, + { url = "https://files.pythonhosted.org/packages/05/35/72f91ad9ac7c19a849dedf083d347dc1123f0adeb401f53974f84f1d04c8/torchvision-0.23.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:2df618e1143805a7673aaf82cb5720dd9112d4e771983156aaf2ffff692eebf9", size = 2047192 }, + { url = "https://files.pythonhosted.org/packages/1d/9d/406cea60a9eb9882145bcd62a184ee61e823e8e1d550cdc3c3ea866a9445/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a3299d2b1d5a7aed2d3b6ffb69c672ca8830671967eb1cee1497bacd82fe47b", size = 2359295 }, + { url = "https://files.pythonhosted.org/packages/2b/f4/34662f71a70fa1e59de99772142f22257ca750de05ccb400b8d2e3809c1d/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:76bc4c0b63d5114aa81281390f8472a12a6a35ce9906e67ea6044e5af4cab60c", size = 8800474 }, + { url = "https://files.pythonhosted.org/packages/6e/f5/b5a2d841a8d228b5dbda6d524704408e19e7ca6b7bb0f24490e081da1fa1/torchvision-0.23.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b9e2dabf0da9c8aa9ea241afb63a8f3e98489e706b22ac3f30416a1be377153b", size = 1527667 }, ] [[package]] @@ -5039,16 +5129,16 @@ wheels = [ [[package]] name = "triton" -version = 
"3.3.0" +version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "setuptools", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/c5/4874a81131cc9e934d88377fbc9d24319ae1fb540f3333b4e9c696ebc607/triton-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3161a2bf073d6b22c4e2f33f951f3e5e3001462b2570e6df9cd57565bdec2984", size = 156528461 }, - { url = "https://files.pythonhosted.org/packages/11/53/ce18470914ab6cfbec9384ee565d23c4d1c55f0548160b1c7b33000b11fd/triton-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b68c778f6c4218403a6bd01be7484f6dc9e20fe2083d22dd8aef33e3b87a10a3", size = 156504509 }, - { url = "https://files.pythonhosted.org/packages/7d/74/4bf2702b65e93accaa20397b74da46fb7a0356452c1bb94dbabaf0582930/triton-3.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47bc87ad66fa4ef17968299acacecaab71ce40a238890acc6ad197c3abe2b8f1", size = 156516468 }, - { url = "https://files.pythonhosted.org/packages/0a/93/f28a696fa750b9b608baa236f8225dd3290e5aff27433b06143adc025961/triton-3.3.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce4700fc14032af1e049005ae94ba908e71cd6c2df682239aed08e49bc71b742", size = 156580729 }, + { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138 }, + { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068 }, + { url = 
"https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223 }, + { url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780 }, ] [[package]] From b9b7f6a96128787de7931f0617a3b27f9ab89940 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 09:37:48 -0700 Subject: [PATCH 10/32] add preprocess --- scripts/train_pytorch.py | 152 +++++++++++++++++++---- src/openpi/models/model.py | 147 ++++++++++++++++++++++ src/openpi/models/pi0.py | 2 +- src/openpi/models_pytorch/pi0_pytorch.py | 3 +- src/openpi/shared/image_tools.py | 79 ++++++++++++ src/openpi/training/config.py | 4 +- 6 files changed, 361 insertions(+), 26 deletions(-) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 1a46fe1..78bd389 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -144,6 +144,11 @@ def setup_ddp(): if use_ddp and not dist.is_initialized(): backend = "nccl" if torch.cuda.is_available() else "gloo" dist.init_process_group(backend=backend, init_method="env://") + + # Set up debugging environment variables for DDP issues + if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None: + os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO" + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): @@ -214,6 +219,16 @@ def batch_to_torch(batch: Dict[str, Any], device: torch.device) -> Tuple[list[to return batch +def get_model_state_dict(model): + """Get state dict from model, handling DDP 
wrapper.""" + return model.module.state_dict() if isinstance(model, DDP) else model.state_dict() + + +def get_model_parameters(model): + """Get parameters from model, handling DDP wrapper.""" + return model.module.parameters() if isinstance(model, DDP) else model.parameters() + + def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_interval=None, ema_model=None): """Save a checkpoint with model state, optimizer state, EMA state, and metadata.""" if not is_main: @@ -228,7 +243,7 @@ def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_in os.makedirs(ckpt_dir, exist_ok=True) # Save model state - state_dict = (model.module if isinstance(model, DDP) else model).state_dict() + state_dict = get_model_state_dict(model) torch.save(state_dict, os.path.join(ckpt_dir, "pytorch_model.pt")) # Save optimizer state @@ -259,30 +274,30 @@ def load_checkpoint(model, optimizer, config, device, ema_model=None): for d in config.checkpoint_dir.iterdir(): if d.is_dir() and d.name.isdigit(): checkpoint_steps.append(int(d.name)) - + if not checkpoint_steps: raise FileNotFoundError(f"No checkpoints found in {config.checkpoint_dir}") - + latest_step = max(checkpoint_steps) ckpt_dir = os.path.join(config.checkpoint_dir, f"{latest_step}") - + # Load model state model_state_dict = torch.load(os.path.join(ckpt_dir, "pytorch_model.pt"), map_location=device) (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) - + # Load optimizer state optimizer_state_dict = torch.load(os.path.join(ckpt_dir, "optimizer.pt"), map_location=device) optimizer.load_state_dict(optimizer_state_dict) - + # Load EMA state if available if ema_model is not None and os.path.exists(os.path.join(ckpt_dir, "ema_model.pt")): ema_state_dict = torch.load(os.path.join(ckpt_dir, "ema_model.pt"), map_location=device) ema_model.load_state_dict(ema_state_dict) logging.info(f"Loaded EMA state from checkpoint") - + # Load metadata metadata = 
torch.load(os.path.join(ckpt_dir, "metadata.pt"), map_location=device) - + logging.info(f"Loaded checkpoint from step {latest_step} -> {ckpt_dir}") return metadata["global_step"] @@ -297,6 +312,57 @@ def get_latest_checkpoint_step(config): return max(checkpoint_steps) if checkpoint_steps else None +def debug_unused_parameters(model, device): + """Debug function to identify unused parameters in the model.""" + if isinstance(model, DDP): + model = model.module + + logging.info("Checking for potentially unused parameters...") + + # Get all parameter names and their indices + param_info = {} + idx = 0 + for name, param in model.named_parameters(): + if param.requires_grad: + param_info[idx] = name + idx += 1 + + logging.info(f"Total trainable parameters: {len(param_info)}") + + # Check which parameters have gradients after a forward pass + # This is a diagnostic function that can be called if needed + return param_info + + +def check_model_parameters(model, device): + """Check for unused parameters and provide debugging information.""" + if isinstance(model, DDP): + model = model.module + + total_params = 0 + used_params = 0 + + for name, param in model.named_parameters(): + total_params += param.numel() + if param.requires_grad: + used_params += param.numel() + + logging.info(f"Model parameters: {total_params:,} total, {used_params:,} trainable") + + # Check for parameters that might be unused + unused_params = [] + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is None: + unused_params.append(name) + + if unused_params: + logging.warning(f"Found {len(unused_params)} parameters that might be unused:") + for name in unused_params[:10]: # Show first 10 + logging.warning(f" - {name}") + if len(unused_params) > 10: + logging.warning(f" ... 
and {len(unused_params) - 10} more") + + def setup_memory_optimizations(model, device, enable_gradient_checkpointing=False): """Setup memory optimization techniques for the model.""" if enable_gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'): @@ -410,12 +476,17 @@ def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, grad model_cfg = config.model model = PI0Pytorch(model_cfg).to(device) - + # Apply memory optimizations setup_memory_optimizations(model, device, enable_gradient_checkpointing) - + + # Check model parameters for debugging + if is_main: + check_model_parameters(model, device) + if use_ddp: - model = DDP(model, device_ids=[device.index] if device.type == "cuda" else None, find_unused_parameters=False) + # Enable unused parameter detection to handle cases where some parameters don't participate in loss + model = DDP(model, device_ids=[device.index] if device.type == "cuda" else None, find_unused_parameters=True) # Load weights from weight_loader if specified (for fine-tuning) if isinstance(config.weight_loader, str): @@ -445,10 +516,20 @@ def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, grad # Initialize EMA if specified in config ema_model = None if config.ema_decay is not None: - ema_model = PI0Pytorch(model_cfg).to(device) - ema_model.load_state_dict(model.state_dict()) - ema_model.eval() - logging.info(f"Initialized EMA with decay {config.ema_decay}") + try: + ema_model = PI0Pytorch(model_cfg).to(device) + + # Get the correct state dict from the main model + main_model_state_dict = get_model_state_dict(model) + + # Load the state dict into EMA model + ema_model.load_state_dict(main_model_state_dict) + ema_model.eval() + logging.info(f"Initialized EMA with decay {config.ema_decay}") + except Exception as e: + logging.error(f"Failed to initialize EMA model: {e}") + logging.error("Continuing without EMA...") + ema_model = None # Load checkpoint if resuming global_step = 0 @@ 
-501,9 +582,24 @@ def lr_schedule(step: int): # Forward pass with mixed precision observation = _model.Observation.from_dict(batch) - with torch.amp.autocast('cuda', enabled=mixed_precision and torch.cuda.is_available()): - losses = model(observation, actions) - loss = losses.mean() / gradient_accumulation_steps # Scale loss for gradient accumulation + try: + with torch.amp.autocast('cuda', enabled=mixed_precision and torch.cuda.is_available()): + losses = model(observation, actions) + # Ensure losses is a tensor and handle different return types + if isinstance(losses, (list, tuple)): + losses = torch.stack(losses) + elif not isinstance(losses, torch.Tensor): + losses = torch.tensor(losses, device=device, dtype=torch.float32) + + loss = losses.mean() / gradient_accumulation_steps # Scale loss for gradient accumulation + except RuntimeError as e: + if "Expected to have finished reduction" in str(e) or "did not receive grad" in str(e): + logging.error(f"DDP error on rank {dist.get_rank() if use_ddp else 0}: {e}") + logging.error("This usually indicates unused parameters in the model.") + logging.error("Try setting TORCH_DISTRIBUTED_DEBUG=DETAIL for more information.") + raise + else: + raise # Backward pass with gradient scaling scaler.scale(loss).backward() @@ -521,9 +617,15 @@ def lr_schedule(step: int): # Update EMA if enabled if ema_model is not None: - with torch.no_grad(): - for param, ema_param in zip(model.parameters(), ema_model.parameters()): - ema_param.data.mul_(config.ema_decay).add_(param.data, alpha=1 - config.ema_decay) + try: + with torch.no_grad(): + # Get parameters from the correct model structure + main_model_params = get_model_parameters(model) + for param, ema_param in zip(main_model_params, ema_model.parameters()): + ema_param.data.mul_(config.ema_decay).add_(param.data, alpha=1 - config.ema_decay) + except Exception as e: + logging.warning(f"Failed to update EMA model: {e}") + # Continue training without EMA update # Collect stats (only on 
accumulation steps) if (global_step + 1) % gradient_accumulation_steps == 0 and is_main: @@ -602,10 +704,16 @@ def main(): help="Maximum GPU memory usage in GB (default: None, auto-detect)") parser.add_argument("--enable_gradient_checkpointing", action="store_true", default=False, help="Enable gradient checkpointing for memory optimization") + parser.add_argument("--ddp_debug_level", type=str, default="INFO", choices=["INFO", "DETAIL", "OFF"], + help="DDP debugging level (default: INFO)") args, _ = parser.parse_known_args() - + # Handle mixed precision flag mixed_precision = args.mixed_precision and not args.no_mixed_precision + + # Set DDP debug level + if args.ddp_debug_level != "OFF": + os.environ["TORCH_DISTRIBUTED_DEBUG"] = args.ddp_debug_level train_loop(config, ckpt_save_interval=args.ckpt_save_interval, diff --git a/src/openpi/models/model.py b/src/openpi/models/model.py index f6ee554..f76ed93 100644 --- a/src/openpi/models/model.py +++ b/src/openpi/models/model.py @@ -208,6 +208,153 @@ def preprocess_observation( ) +def preprocess_observation_pytorch( + observation: Observation, + *, + train: bool = False, + image_keys: Sequence[str] = IMAGE_KEYS, + image_resolution: tuple[int, int] = IMAGE_RESOLUTION, +) -> Observation: + """PyTorch version of preprocess_observation. Preprocesses observations with PyTorch tensors by performing + image resizing (if necessary) and filling in a default image mask (if necessary). + + Note: Image augmentation is not implemented for PyTorch tensors as augmax is JAX-specific. 
+ """ + + if not set(image_keys).issubset(observation.images): + raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") + + batch_shape = observation.state.shape[:-1] + + out_images = {} + for key in image_keys: + image = observation.images[key] + + # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats + # Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + image = image.permute(0, 2, 3, 1) + + if image.shape[1:3] != image_resolution: + logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") + image = image_tools.resize_with_pad_torch(image, *image_resolution) + + if train: + # Convert from [-1, 1] to [0, 1] for PyTorch augmentations + image = image / 2.0 + 0.5 + + # Apply PyTorch-based augmentations + if "wrist" not in key: + # Geometric augmentations for non-wrist cameras + height, width = image.shape[1:3] + + # Random crop and resize + crop_height = int(height * 0.95) + crop_width = int(width * 0.95) + + # Random crop + max_h = height - crop_height + max_w = width - crop_width + if max_h > 0 and max_w > 0: + start_h = torch.randint(0, max_h + 1, (1,)).item() + start_w = torch.randint(0, max_w + 1, (1,)).item() + image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :] + + # Resize back to original size + image = torch.nn.functional.interpolate( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + size=(height, width), + mode='bilinear', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Random rotation (small angles) + angle = torch.rand(1).item() * 10 - 5 # Random angle between -5 and 5 degrees + if abs(angle) > 0.1: # Only rotate if angle is significant + # Convert to radians + angle_rad = torch.tensor(angle * torch.pi / 180.0, 
device=image.device) + + # Create rotation matrix + cos_a = torch.cos(angle_rad) + sin_a = torch.sin(angle_rad) + + # Apply rotation using grid_sample + grid_x = torch.linspace(-1, 1, width, device=image.device) + grid_y = torch.linspace(-1, 1, height, device=image.device) + + # Create meshgrid + grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij') + + # Expand to batch dimension + grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) + grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) + + # Apply rotation transformation + grid_x_rot = grid_x * cos_a - grid_y * sin_a + grid_y_rot = grid_x * sin_a + grid_y * cos_a + + # Stack and reshape for grid_sample + grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) + + image = torch.nn.functional.grid_sample( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + grid, + mode='bilinear', + padding_mode='zeros', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Color augmentations for all cameras + # Random brightness + brightness_factor = 0.7 + torch.rand(1).item() * 0.6 # Random factor between 0.7 and 1.3 + image = image * brightness_factor + + # Random contrast + contrast_factor = 0.6 + torch.rand(1).item() * 0.8 # Random factor between 0.6 and 1.4 + mean = image.mean(dim=[1, 2, 3], keepdim=True) + image = (image - mean) * contrast_factor + mean + + # Random saturation (convert to HSV, modify S, convert back) + # For simplicity, we'll just apply a random scaling to the color channels + saturation_factor = 0.5 + torch.rand(1).item() * 1.0 # Random factor between 0.5 and 1.5 + gray = image.mean(dim=-1, keepdim=True) + image = gray + (image - gray) * saturation_factor + + # Clamp values to [0, 1] + image = torch.clamp(image, 0, 1) + + # Back to [-1, 1] + image = image * 2.0 - 1.0 + + # Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + 
out_images[key] = image + + # obtain mask + out_masks = {} + for key in out_images: + if key not in observation.image_masks: + # do not mask by default + out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) + else: + out_masks[key] = observation.image_masks[key] + + return Observation( + images=out_images, + image_masks=out_masks, + state=observation.state, + tokenized_prompt=observation.tokenized_prompt, + tokenized_prompt_mask=observation.tokenized_prompt_mask, + token_ar_mask=observation.token_ar_mask, + token_loss_mask=observation.token_loss_mask, + ) + + @dataclasses.dataclass(frozen=True) class BaseModelConfig(abc.ABC): """Configuration shared by all models. Specific models should inherit from this class, and implement the `create` diff --git a/src/openpi/models/pi0.py b/src/openpi/models/pi0.py index c84f5c1..7160983 100644 --- a/src/openpi/models/pi0.py +++ b/src/openpi/models/pi0.py @@ -197,7 +197,7 @@ def compute_loss( time: at.Float[at.Array, "*b"] | None = None ) -> at.Float[at.Array, "*b ah"]: preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3) - #observation = _model.preprocess_observation(preprocess_rng, observation, train=train) + observation = _model.preprocess_observation(preprocess_rng, observation, train=train) batch_shape = actions.shape[:-2] # Use provided noise and time if available, otherwise generate them diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 89bc0dd..6c08f69 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -7,6 +7,7 @@ import openpi.models.gemma as _gemma from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel +import openpi.models.model as _model def get_safe_dtype(target_dtype, device_type): @@ -232,7 +233,7 @@ def embed_suffix(self, state, noisy_actions, timestep): def forward(self, observation, actions, noise=None, time=None) -> Tensor: """Do a full 
training forward pass and compute the loss (batch_size x num_steps x num_motors)""" - # observation = _model.preprocess_observation_pytorch(observation, train=True) + observation = _model.preprocess_observation_pytorch(observation, train=True) images = list(observation.images.values()) img_masks = list(observation.image_masks.values()) lang_tokens = observation.tokenized_prompt diff --git a/src/openpi/shared/image_tools.py b/src/openpi/shared/image_tools.py index 4d63e1c..78f0e17 100644 --- a/src/openpi/shared/image_tools.py +++ b/src/openpi/shared/image_tools.py @@ -2,6 +2,8 @@ import jax import jax.numpy as jnp +import torch +import torch.nn.functional as F import openpi.shared.array_typing as at @@ -48,3 +50,80 @@ def resize_with_pad( if not has_batch_dim: padded_images = padded_images[0] return padded_images + + +def resize_with_pad_torch( + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. + + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) 
+ + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + # Convert to channels-first for torch operations + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode='constant', + value=constant_value + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + if batch_size == 1 and images.shape[0] == 1: + padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added + + return padded_images 
\ No newline at end of file diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 723d212..6235e32 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -724,7 +724,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=1, + batch_size=64, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, @@ -746,7 +746,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=1, + batch_size=64, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, From d2aba1c9eb0b1ee449416a15ae23fdec464b8228 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 12:28:55 -0700 Subject: [PATCH 11/32] fixes --- scripts/train_single_example.py | 48 +++++++++++++++------- src/openpi/models_pytorch/gemma_pytorch.py | 32 +++++++-------- src/openpi/models_pytorch/pi0_pytorch.py | 5 +-- 3 files changed, 52 insertions(+), 33 deletions(-) diff --git a/scripts/train_single_example.py b/scripts/train_single_example.py index 2a23d48..9418342 100644 --- a/scripts/train_single_example.py +++ b/scripts/train_single_example.py @@ -11,6 +11,7 @@ import jax.numpy as jnp import flax.nnx as nnx import flax +from unittest.mock import patch from openpi.models import model as _model from openpi.models.pi0_config import Pi0Config @@ -83,6 +84,16 @@ def create_fixed_noise_and_time(batch_size, action_horizon, action_dim): return noise, time +def mock_preprocess_observation(rng, observation, **kwargs): + """Mock function that returns observation unchanged to disable preprocessing.""" + return observation + + +def mock_preprocess_observation_pytorch(observation, **kwargs): + """Mock function that returns observation unchanged to disable preprocessing.""" + return observation + + def test_pytorch_single_example(noise, time): """Test PyTorch training on single 
example.""" print("\n=== Testing PyTorch on Single Example ===") @@ -136,13 +147,15 @@ def test_pytorch_single_example(noise, time): model.eval() with torch.no_grad(): #try: - losses = model(observation, actions, noise=noise_tensor, time=time_tensor) - print(f"PyTorch forward pass successful!") - print(f"Losses shape: {losses.shape}") - print(f"Losses dtype: {losses.dtype}") - # mean_loss = losses.to(torch.float32).mean().item() - # print(f"Mean loss: {mean_loss:.6f}") - return True, losses + # Use mock to disable preprocessing + with patch('openpi.models.model.preprocess_observation_pytorch', side_effect=mock_preprocess_observation_pytorch): + losses = model(observation, actions, noise=noise_tensor, time=time_tensor) + print(f"PyTorch forward pass successful!") + print(f"Losses shape: {losses.shape}") + print(f"Losses dtype: {losses.dtype}") + mean_loss = losses.to(torch.float32).mean().item() + print(f"Mean loss: {mean_loss:.6f}") + return True, losses # except Exception as e: # print(f"PyTorch forward pass failed: {e}") # return False, None @@ -384,13 +397,15 @@ def adapt_nested_params(params_dict, key_path=""): # Test forward pass with fixed noise and time # try: # Use the modified compute_loss method that accepts external noise and time - losses = model.compute_loss(rng, observation, actions, train=False, noise=noise_jax, time=time_jax) - print(f"JAX forward pass successful!") - print(f"Losses shape: {losses.shape}") - print(f"Losses dtype: {losses.dtype}") - mean_loss = jnp.mean(losses).item() - print(f"Mean loss: {mean_loss:.6f}") - return True, losses + # Use mock to disable preprocessing + with patch('openpi.models.model.preprocess_observation', side_effect=mock_preprocess_observation): + losses = model.compute_loss(rng, observation, actions, train=False, noise=noise_jax, time=time_jax) + print(f"JAX forward pass successful!") + print(f"Losses shape: {losses.shape}") + print(f"Losses dtype: {losses.dtype}") + mean_loss = jnp.mean(losses).item() + 
print(f"Mean loss: {mean_loss:.6f}") + return True, losses # except Exception as e: # print(f"JAX forward pass failed: {e}") # return False, None @@ -405,6 +420,9 @@ def compare_losses(pytorch_loss, jax_loss): print("šŸ“Š LOSS COMPARISON") print("=" * 70) + print(f"PyTorch loss: {pytorch_loss}") + print(f"JAX loss: {jax_loss}") + # # Handle tensor inputs by computing mean if needed # if hasattr(pytorch_loss, 'mean'): # pytorch_mean = pytorch_loss.to(torch.float32).mean().item() @@ -502,6 +520,7 @@ def main(): print("šŸ“ Loading pre-trained weights for both models...") print("šŸŽÆ Using fixed noise and time values for deterministic comparison...") print("šŸ”§ Debug mode: JAX model will use only 1 encoder layer for faster debugging...") + print("🚫 Preprocessing disabled: Image augmentations and resizing are bypassed for fair comparison...") # Generate fixed noise and time noise, time = create_fixed_noise_and_time( @@ -542,6 +561,7 @@ def main(): print("3. If losses differ significantly, investigate the differences") print("4. Check if the noise and time handling is consistent between implementations") print("5. Use the same example in full training runs") + print("6. 
Note: Preprocessing (image augmentations) is disabled for this comparison") if __name__ == "__main__": diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 9272a34..5a27cfa 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -106,7 +106,7 @@ def forward( past_key_values: list[torch.FloatTensor] | Cache | None = None, inputs_embeds: list[torch.FloatTensor] = None, use_cache: bool | None = None, - adarms_cond: list[torch.Tensor] | None = None, + adarms_cond: list[torch.Tensor] = [None, None], ): if inputs_embeds[1] is None: prefix_output = self.paligemma.language_model.forward( @@ -120,7 +120,7 @@ def forward( prefix_past_key_values = prefix_output.past_key_values prefix_output = prefix_output.last_hidden_state suffix_output = None - if inputs_embeds[0] is None: + elif inputs_embeds[0] is None: suffix_output = self.gemma_expert.model.forward( inputs_embeds=inputs_embeds[1], attention_mask=attention_mask, @@ -150,9 +150,9 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) - query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) - key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) - value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states.append(query_state) key_states.append(key_state) @@ -160,20 +160,20 @@ def forward( # B,L,H,D with L sequence length, H number of heads, D head dim # concatenate on the number of embeddings/tokens - query_states = torch.cat(query_states, dim=1) - key_states = torch.cat(key_states, dim=1) - value_states = torch.cat(value_states, dim=1) + 
query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) - query_states = apply_rope(query_states, position_ids) - key_states = apply_rope(key_states, position_ids) + # query_states = apply_rope(query_states, position_ids) + # key_states = apply_rope(key_states, position_ids) - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) + # query_states = query_states.transpose(1, 2) + # key_states = key_states.transpose(1, 2) + # value_states = value_states.transpose(1, 2) - # dummy_tensor = torch.zeros(query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype) - # cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) - # query_states, key_states = modeling_gemma.apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) + dummy_tensor = torch.zeros(query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype) + cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) batch_size = query_states.shape[0] scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 6c08f69..13f42b2 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -105,6 +105,7 @@ def __init__(self, config): torch.set_float32_matmul_precision('high') self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") + self.forward = torch.compile(self.forward, mode="reduce-overhead") def sample_noise(self, shape, device): noise = torch.normal( @@ -277,9 
+278,7 @@ def forward(self, observation, actions, noise=None, time=None) -> Tensor: v_t = self.action_out_proj(suffix_out) - #losses = F.mse_loss(u_t, v_t, reduction="none") - - losses = torch.square(v_t - u_t) + losses = F.mse_loss(u_t, v_t, reduction="none") return losses @torch.no_grad() From 6fa3f8d3db9dd9dae48e04e49924301da5c5fa1f Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 13:48:07 -0700 Subject: [PATCH 12/32] fix resume --- examples/convert_jax_model_to_pytorch.py | 4 +- scripts/train_pytorch.py | 94 +++++++----- src/openpi/models/model.py | 177 ++++++++++++++++++++++- src/openpi/models_pytorch/pi0_pytorch.py | 9 +- 4 files changed, 236 insertions(+), 48 deletions(-) diff --git a/examples/convert_jax_model_to_pytorch.py b/examples/convert_jax_model_to_pytorch.py index 24facc1..5031a85 100644 --- a/examples/convert_jax_model_to_pytorch.py +++ b/examples/convert_jax_model_to_pytorch.py @@ -8,10 +8,10 @@ Usage: # Just inspect keys: - python convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only # Convert to PyTorch: - python convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output Example: # pi0_droid diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 78bd389..cdaeb17 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -21,10 +21,12 @@ python scripts/train_pytorch.py --exp_name --ckpt_save_interval Example: python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test + python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint Multi-GPU (single node): torchrun --standalone --nnodes=1 --nproc_per_node= scripts/train_pytorch.py --exp_name Example: torchrun --standalone 
--nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test + torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume Multi-Node Training: # On master node (node 0): torchrun --nnodes= --nproc_per_node= --rdzv_id= --rdzv_backend=c10d --rdzv_endpoint=: scripts/train_pytorch.py --exp_name @@ -65,6 +67,9 @@ by the selected TrainConfig (e.g., `LeRobot*` configs for real datasets or `FakeDataConfig`). - Supports Weights & Biases (wandb) logging for experiment tracking and visualization. - Checkpoints include model state, optimizer state, and training metadata for complete resume capability. +- Checkpoints are saved in experiment-specific directories: // +- Resume functionality automatically finds the latest checkpoint for the specified experiment name. +- Checkpoint loading handles both PyTorch and JAX/Flax checkpoints for compatibility. - For optimal multi-node performance, ensure high-bandwidth network connectivity (e.g., InfiniBand). - Monitor GPU utilization and network bandwidth during multi-node training. - Memory optimizations can significantly reduce GPU memory usage while maintaining training quality. 
@@ -239,27 +244,28 @@ def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_in # Only save if it's time to save or if it's the final step if (global_step % save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: - ckpt_dir = os.path.join(config.checkpoint_dir, f"{global_step}") - os.makedirs(ckpt_dir, exist_ok=True) + # Ensure checkpoint_dir is a Path object and create the step-specific directory + ckpt_dir = config.checkpoint_dir / f"{global_step}" + ckpt_dir.mkdir(parents=True, exist_ok=True) # Save model state state_dict = get_model_state_dict(model) - torch.save(state_dict, os.path.join(ckpt_dir, "pytorch_model.pt")) + torch.save(state_dict, ckpt_dir / "pytorch_model.pt") # Save optimizer state - torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt")) + torch.save(optimizer.state_dict(), ckpt_dir / "optimizer.pt") # Save EMA state if available if ema_model is not None: - torch.save(ema_model.state_dict(), os.path.join(ckpt_dir, "ema_model.pt")) + torch.save(ema_model.state_dict(), ckpt_dir / "ema_model.pt") - # Save training metadata + # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues) metadata = { "global_step": global_step, "config": dataclasses.asdict(config), "timestamp": time.time(), } - torch.save(metadata, os.path.join(ckpt_dir, "metadata.pt")) + torch.save(metadata, ckpt_dir / "metadata.pt") logging.info(f"Saved checkpoint at step {global_step} -> {ckpt_dir}") @@ -268,44 +274,49 @@ def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_in wandb.log({"checkpoint_step": global_step}, step=global_step) -def load_checkpoint(model, optimizer, config, device, ema_model=None): +def load_checkpoint(model, optimizer, checkpoint_dir, device, ema_model=None): """Load the latest checkpoint and return the global step.""" checkpoint_steps = [] - for d in config.checkpoint_dir.iterdir(): + for d in checkpoint_dir.iterdir(): if 
d.is_dir() and d.name.isdigit(): checkpoint_steps.append(int(d.name)) if not checkpoint_steps: - raise FileNotFoundError(f"No checkpoints found in {config.checkpoint_dir}") + raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") latest_step = max(checkpoint_steps) - ckpt_dir = os.path.join(config.checkpoint_dir, f"{latest_step}") + ckpt_dir = checkpoint_dir / f"{latest_step}" # Load model state - model_state_dict = torch.load(os.path.join(ckpt_dir, "pytorch_model.pt"), map_location=device) + model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device) (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) # Load optimizer state - optimizer_state_dict = torch.load(os.path.join(ckpt_dir, "optimizer.pt"), map_location=device) + optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device) optimizer.load_state_dict(optimizer_state_dict) # Load EMA state if available - if ema_model is not None and os.path.exists(os.path.join(ckpt_dir, "ema_model.pt")): - ema_state_dict = torch.load(os.path.join(ckpt_dir, "ema_model.pt"), map_location=device) + if ema_model is not None and (ckpt_dir / "ema_model.pt").exists(): + ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", map_location=device) ema_model.load_state_dict(ema_state_dict) logging.info(f"Loaded EMA state from checkpoint") - # Load metadata - metadata = torch.load(os.path.join(ckpt_dir, "metadata.pt"), map_location=device) - - logging.info(f"Loaded checkpoint from step {latest_step} -> {ckpt_dir}") - return metadata["global_step"] - - -def get_latest_checkpoint_step(config): - """Get the latest checkpoint step number.""" + # Load metadata (weights_only=False needed for older checkpoints that might contain JAX/Flax objects) + try: + metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) + global_step = metadata.get("global_step", latest_step) + logging.info(f"Loaded checkpoint from step 
{latest_step} -> {ckpt_dir}") + return global_step + except Exception as e: + logging.warning(f"Failed to load metadata from checkpoint: {e}") + logging.warning("Using checkpoint step number as global step") + return latest_step + + +def get_latest_checkpoint_step(checkpoint_dir): + """Get the latest checkpoint step number from a checkpoint directory.""" checkpoint_steps = [] - for d in config.checkpoint_dir.iterdir(): + for d in checkpoint_dir.iterdir(): if d.is_dir() and d.name.isdigit(): checkpoint_steps.append(int(d.name)) @@ -386,7 +397,7 @@ def setup_memory_optimizations(model, device, enable_gradient_checkpointing=Fals logging.info(f"Cleared CUDA cache for device {device.index}") -def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, enable_gradient_checkpointing: bool = False): +def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, enable_gradient_checkpointing: bool = False): use_ddp, local_rank, device = setup_ddp() is_main = (not use_ddp) or (dist.get_rank() == 0) set_seed(config.seed, local_rank) @@ -397,24 +408,32 @@ def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, grad # Initialize checkpoint directory and wandb resuming = False - if config.resume: - # Check if checkpoint directory exists and has checkpoints - if config.checkpoint_dir.exists(): - latest_step = get_latest_checkpoint_step(config) + if resume: + # Find checkpoint directory based on experiment name + exp_checkpoint_dir = config.checkpoint_dir + if exp_checkpoint_dir.exists(): + latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) if latest_step is not None: resuming = True - logging.info(f"Resuming from checkpoint directory: {config.checkpoint_dir} at step {latest_step}") + 
logging.info(f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}") else: - raise FileNotFoundError(f"No checkpoints found in {config.checkpoint_dir} for resume") + raise FileNotFoundError(f"No checkpoints found in {exp_checkpoint_dir} for resume") else: - raise FileNotFoundError(f"Checkpoint directory {config.checkpoint_dir} does not exist for resume") + raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") elif config.overwrite and config.checkpoint_dir.exists(): import shutil shutil.rmtree(config.checkpoint_dir) logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") - # Create checkpoint directory - config.checkpoint_dir.mkdir(parents=True, exist_ok=True) + # Create checkpoint directory with experiment name + if not resuming: + # For new runs, create experiment-specific checkpoint directory + exp_checkpoint_dir = config.checkpoint_dir + exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) + logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") + else: + # For resume, checkpoint_dir is already set to the experiment directory + logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") # Initialize wandb (only on main process) if is_main: @@ -534,7 +553,7 @@ def train_loop(config: _config.TrainConfig, ckpt_save_interval: int = None, grad # Load checkpoint if resuming global_step = 0 if resuming: - global_step = load_checkpoint(model, optim, config, device, ema_model) + global_step = load_checkpoint(model, optim, config.checkpoint_dir, device, ema_model) logging.info(f"Resumed training from step {global_step}") def lr_schedule(step: int): @@ -692,6 +711,8 @@ def main(): # Parse additional command line arguments for memory optimization import argparse parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--resume", action="store_true", default=False, + help="Resume training from the 
latest checkpoint for the experiment (handles both PyTorch and JAX checkpoints)") parser.add_argument("--ckpt_save_interval", type=int, default=None, help="Interval for saving checkpoints (overrides config.save_interval)") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, @@ -716,6 +737,7 @@ def main(): os.environ["TORCH_DISTRIBUTED_DEBUG"] = args.ddp_debug_level train_loop(config, + resume=args.resume, ckpt_save_interval=args.ckpt_save_interval, gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=mixed_precision, diff --git a/src/openpi/models/model.py b/src/openpi/models/model.py index f76ed93..2c07776 100644 --- a/src/openpi/models/model.py +++ b/src/openpi/models/model.py @@ -208,6 +208,162 @@ def preprocess_observation( ) +def preprocess_observation_pytorch_torch_compile( + observation, + *, + train: bool = False, + image_keys: Sequence[str] = IMAGE_KEYS, + image_resolution: tuple[int, int] = IMAGE_RESOLUTION, +): + """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations. + + This function avoids complex type annotations that can cause torch.compile issues. 
+ """ + if not set(image_keys).issubset(observation.images): + raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") + + batch_shape = observation.state.shape[:-1] + + out_images = {} + for key in image_keys: + image = observation.images[key] + + # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats + # Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + image = image.permute(0, 2, 3, 1) + + if image.shape[1:3] != image_resolution: + logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") + image = image_tools.resize_with_pad_torch(image, *image_resolution) + + if train: + # Convert from [-1, 1] to [0, 1] for PyTorch augmentations + image = image / 2.0 + 0.5 + + # Apply PyTorch-based augmentations + if "wrist" not in key: + # Geometric augmentations for non-wrist cameras + height, width = image.shape[1:3] + + # Random crop and resize + crop_height = int(height * 0.95) + crop_width = int(width * 0.95) + + # Random crop + max_h = height - crop_height + max_w = width - crop_width + if max_h > 0 and max_w > 0: + # Use tensor operations instead of .item() for torch.compile compatibility + start_h = torch.randint(0, max_h + 1, (1,), device=image.device) + start_w = torch.randint(0, max_w + 1, (1,), device=image.device) + image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :] + + # Resize back to original size + image = torch.nn.functional.interpolate( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + size=(height, width), + mode='bilinear', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Random rotation (small angles) + # Use tensor operations instead of .item() for torch.compile compatibility + angle = torch.rand(1, device=image.device) * 10 - 
5 # Random angle between -5 and 5 degrees + if torch.abs(angle) > 0.1: # Only rotate if angle is significant + # Convert to radians + angle_rad = angle * torch.pi / 180.0 + + # Create rotation matrix + cos_a = torch.cos(angle_rad) + sin_a = torch.sin(angle_rad) + + # Apply rotation using grid_sample + grid_x = torch.linspace(-1, 1, width, device=image.device) + grid_y = torch.linspace(-1, 1, height, device=image.device) + + # Create meshgrid + grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij') + + # Expand to batch dimension + grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) + grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) + + # Apply rotation transformation + grid_x_rot = grid_x * cos_a - grid_y * sin_a + grid_y_rot = grid_x * sin_a + grid_y * cos_a + + # Stack and reshape for grid_sample + grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) + + image = torch.nn.functional.grid_sample( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + grid, + mode='bilinear', + padding_mode='zeros', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Color augmentations for all cameras + # Random brightness + # Use tensor operations instead of .item() for torch.compile compatibility + brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3 + image = image * brightness_factor + + # Random contrast + # Use tensor operations instead of .item() for torch.compile compatibility + contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4 + mean = image.mean(dim=[1, 2, 3], keepdim=True) + image = (image - mean) * contrast_factor + mean + + # Random saturation (convert to HSV, modify S, convert back) + # For simplicity, we'll just apply a random scaling to the color channels + # Use tensor operations instead of .item() for torch.compile compatibility + saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # 
Random factor between 0.5 and 1.5 + gray = image.mean(dim=-1, keepdim=True) + image = gray + (image - gray) * saturation_factor + + # Clamp values to [0, 1] + image = torch.clamp(image, 0, 1) + + # Back to [-1, 1] + image = image * 2.0 - 1.0 + + # Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + out_images[key] = image + + # obtain mask + out_masks = {} + for key in out_images: + if key not in observation.image_masks: + # do not mask by default + out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) + else: + out_masks[key] = observation.image_masks[key] + + # Create a simple object with the required attributes instead of using the complex Observation class + class SimpleProcessedObservation: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + return SimpleProcessedObservation( + images=out_images, + image_masks=out_masks, + state=observation.state, + tokenized_prompt=observation.tokenized_prompt, + tokenized_prompt_mask=observation.tokenized_prompt_mask, + token_ar_mask=observation.token_ar_mask, + token_loss_mask=observation.token_loss_mask, + ) + + def preprocess_observation_pytorch( observation: Observation, *, @@ -259,8 +415,9 @@ def preprocess_observation_pytorch( max_h = height - crop_height max_w = width - crop_width if max_h > 0 and max_w > 0: - start_h = torch.randint(0, max_h + 1, (1,)).item() - start_w = torch.randint(0, max_w + 1, (1,)).item() + # Use tensor operations instead of .item() for torch.compile compatibility + start_h = torch.randint(0, max_h + 1, (1,), device=image.device) + start_w = torch.randint(0, max_w + 1, (1,), device=image.device) image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :] # Resize back to original size @@ -272,10 +429,11 @@ def preprocess_observation_pytorch( ).permute(0, 2, 3, 1) # [b, c, h, 
w] -> [b, h, w, c] # Random rotation (small angles) - angle = torch.rand(1).item() * 10 - 5 # Random angle between -5 and 5 degrees - if abs(angle) > 0.1: # Only rotate if angle is significant + # Use tensor operations instead of .item() for torch.compile compatibility + angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees + if torch.abs(angle) > 0.1: # Only rotate if angle is significant # Convert to radians - angle_rad = torch.tensor(angle * torch.pi / 180.0, device=image.device) + angle_rad = angle * torch.pi / 180.0 # Create rotation matrix cos_a = torch.cos(angle_rad) @@ -309,17 +467,20 @@ def preprocess_observation_pytorch( # Color augmentations for all cameras # Random brightness - brightness_factor = 0.7 + torch.rand(1).item() * 0.6 # Random factor between 0.7 and 1.3 + # Use tensor operations instead of .item() for torch.compile compatibility + brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3 image = image * brightness_factor # Random contrast - contrast_factor = 0.6 + torch.rand(1).item() * 0.8 # Random factor between 0.6 and 1.4 + # Use tensor operations instead of .item() for torch.compile compatibility + contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4 mean = image.mean(dim=[1, 2, 3], keepdim=True) image = (image - mean) * contrast_factor + mean # Random saturation (convert to HSV, modify S, convert back) # For simplicity, we'll just apply a random scaling to the color channels - saturation_factor = 0.5 + torch.rand(1).item() * 1.0 # Random factor between 0.5 and 1.5 + # Use tensor operations instead of .item() for torch.compile compatibility + saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # Random factor between 0.5 and 1.5 gray = image.mean(dim=-1, keepdim=True) image = gray + (image - gray) * saturation_factor diff --git a/src/openpi/models_pytorch/pi0_pytorch.py 
b/src/openpi/models_pytorch/pi0_pytorch.py index 13f42b2..73828e7 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -105,7 +105,7 @@ def __init__(self, config): torch.set_float32_matmul_precision('high') self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") - self.forward = torch.compile(self.forward, mode="reduce-overhead") + #self.forward = torch.compile(self.forward, mode="reduce-overhead") def sample_noise(self, shape, device): noise = torch.normal( @@ -234,7 +234,12 @@ def embed_suffix(self, state, noisy_actions, timestep): def forward(self, observation, actions, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" - observation = _model.preprocess_observation_pytorch(observation, train=True) + # # Use torch.compile-compatible preprocessing if we're in a compiled context + # if torch._dynamo.is_compiling(): + # observation = _model.preprocess_observation_pytorch_torch_compile(observation, train=True) + # else: + # observation = _model.preprocess_observation_pytorch(observation, train=True) + observation = _model.preprocess_observation_pytorch_torch_compile(observation, train=True) images = list(observation.images.values()) img_masks = list(observation.image_masks.values()) lang_tokens = observation.tokenized_prompt From 3c9084a9638419c42e876e1cd09d0915ff40aa9a Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 14:31:07 -0700 Subject: [PATCH 13/32] add gradient checkpointing --- scripts/train_pytorch.py | 8 +++ src/openpi/models_pytorch/pi0_pytorch.py | 66 ++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index cdaeb17..822d8de 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -499,6 +499,14 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte # Apply memory optimizations 
setup_memory_optimizations(model, device, enable_gradient_checkpointing) + # Log gradient checkpointing status if enabled + if enable_gradient_checkpointing and is_main: + if hasattr(model, 'get_gradient_checkpointing_status'): + status = model.get_gradient_checkpointing_status() + logging.info(f"Gradient checkpointing status: {status}") + else: + logging.info("Gradient checkpointing enabled but status check not available") + # Check model parameters for debugging if is_main: check_model_parameters(model, device) diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 73828e7..5f13a66 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -1,4 +1,5 @@ import math +import logging import torch from torch import Tensor @@ -106,6 +107,71 @@ def __init__(self, config): torch.set_float32_matmul_precision('high') self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") #self.forward = torch.compile(self.forward, mode="reduce-overhead") + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + + # Enable gradient checkpointing in the underlying models + if hasattr(self.paligemma_with_expert, 'paligemma'): + if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing in PaliGemma language model") + + if hasattr(self.paligemma_with_expert, 'gemma_expert'): + if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing in Gemma expert model") + + logging.info("Enabled gradient checkpointing for PI0Pytorch model") + + 
def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + + # Disable gradient checkpointing in the underlying models + if hasattr(self.paligemma_with_expert, 'paligemma'): + if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + + if hasattr(self.paligemma_with_expert, 'gemma_expert'): + if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + + logging.info("Disabled gradient checkpointing for PI0Pytorch model") + + def is_gradient_checkpointing_enabled(self): + """Check if gradient checkpointing is enabled.""" + return self.gradient_checkpointing_enabled + + def get_gradient_checkpointing_status(self): + """Get detailed gradient checkpointing status of underlying models.""" + status = { + 'main_model': self.gradient_checkpointing_enabled, + 'paligemma_language_model': False, + 'gemma_expert_model': False + } + + if hasattr(self.paligemma_with_expert, 'paligemma'): + if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): + status['paligemma_language_model'] = getattr( + self.paligemma_with_expert.paligemma.language_model, + 'gradient_checkpointing', + False + ) + + if hasattr(self.paligemma_with_expert, 'gemma_expert'): + if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): + status['gemma_expert_model'] = getattr( + self.paligemma_with_expert.gemma_expert.model, + 'gradient_checkpointing', + False + ) + + return status def sample_noise(self, shape, device): noise = torch.normal( From e09ee9819ef63a4fa880d3402e450719718e2c92 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 22 Aug 2025 14:35:09 -0700 Subject: [PATCH 14/32] fix gradient checkpointing --- scripts/train_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/train_pytorch.py 
b/scripts/train_pytorch.py index 822d8de..c8475a7 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -731,7 +731,7 @@ def main(): help="Disable mixed precision training") parser.add_argument("--max_memory_usage", type=float, default=None, help="Maximum GPU memory usage in GB (default: None, auto-detect)") - parser.add_argument("--enable_gradient_checkpointing", action="store_true", default=False, + parser.add_argument("--gradckpt", action="store_true", default=False, help="Enable gradient checkpointing for memory optimization") parser.add_argument("--ddp_debug_level", type=str, default="INFO", choices=["INFO", "DETAIL", "OFF"], help="DDP debugging level (default: INFO)") @@ -750,7 +750,7 @@ def main(): gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=mixed_precision, max_memory_usage=args.max_memory_usage, - enable_gradient_checkpointing=args.enable_gradient_checkpointing) + enable_gradient_checkpointing=args.gradckpt) if __name__ == "__main__": From 615149e12f90b12a5dfb60dd092f037fe6aa810c Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Sun, 24 Aug 2025 23:27:21 -0700 Subject: [PATCH 15/32] batch size 16 working --- scripts/train_pytorch.py | 336 +++++++++++++++++++-- src/openpi/models_pytorch/gemma_pytorch.py | 106 +++++-- src/openpi/models_pytorch/pi0_pytorch.py | 226 +++++++++++--- src/openpi/training/config.py | 6 +- 4 files changed, 577 insertions(+), 97 deletions(-) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index c8475a7..438f28e 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -56,11 +56,13 @@ Checkpoint Parameters: - --ckpt_save_interval: Override the checkpoint save interval from config (e.g., --save_interval 500) - --resume: Resume training from the latest checkpoint in the checkpoint directory +- --cleanup_checkpoints: Clean up corrupted checkpoints during resume (keeps last 3 valid ones) - --overwrite: Overwrite existing checkpoint directory (cannot be used with 
--resume) Memory Optimization Parameters: - --gradient_accumulation_steps: Number of steps to accumulate gradients (default: 1) - --mixed_precision: Enable mixed precision training (default: True) - --max_memory_usage: Maximum GPU memory usage in GB (default: None, auto-detect) +- --gradckpt: Enable gradient checkpointing for memory optimization Notes - The global batch size must be divisible by world size (number of processes). - The data pipeline and transforms are identical to the JAX version and are controlled @@ -80,6 +82,7 @@ import os import platform import time +import gc from dataclasses import dataclass from typing import Any, Dict, Tuple @@ -287,25 +290,49 @@ def load_checkpoint(model, optimizer, checkpoint_dir, device, ema_model=None): latest_step = max(checkpoint_steps) ckpt_dir = checkpoint_dir / f"{latest_step}" - # Load model state - model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device) - (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) + # Load model state with error handling + try: + model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device, weights_only=False) + (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) + logging.info(f"Successfully loaded model state from step {latest_step}") + except Exception as e: + logging.error(f"Failed to load model state from step {latest_step}: {e}") + raise RuntimeError(f"Model checkpoint corrupted at step {latest_step}. 
Cannot resume training.") - # Load optimizer state - optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device) - optimizer.load_state_dict(optimizer_state_dict) + # Load optimizer state with error handling and fallback + optimizer_loaded = False + try: + optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=False) + optimizer.load_state_dict(optimizer_state_dict) + optimizer_loaded = True + logging.info(f"Successfully loaded optimizer state from step {latest_step}") + except Exception as e: + logging.warning(f"Failed to load optimizer state from step {latest_step}: {e}") + logging.warning("Optimizer state corrupted. Will continue with fresh optimizer state.") + # Reset optimizer to fresh state + for param_group in optimizer.param_groups: + param_group['lr'] = param_group.get('lr', 1e-4) # Use default LR or current LR + optimizer.zero_grad() + optimizer_loaded = False # Load EMA state if available + ema_loaded = False if ema_model is not None and (ckpt_dir / "ema_model.pt").exists(): - ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", map_location=device) - ema_model.load_state_dict(ema_state_dict) - logging.info(f"Loaded EMA state from checkpoint") + try: + ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", map_location=device, weights_only=False) + ema_model.load_state_dict(ema_state_dict) + ema_loaded = True + logging.info(f"Successfully loaded EMA state from step {latest_step}") + except Exception as e: + logging.warning(f"Failed to load EMA state from step {latest_step}: {e}") + logging.warning("EMA state corrupted. 
Will continue without EMA.") + ema_loaded = False # Load metadata (weights_only=False needed for older checkpoints that might contain JAX/Flax objects) try: metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) global_step = metadata.get("global_step", latest_step) - logging.info(f"Loaded checkpoint from step {latest_step} -> {ckpt_dir}") + logging.info(f"Successfully loaded metadata from step {latest_step}") return global_step except Exception as e: logging.warning(f"Failed to load metadata from checkpoint: {e}") @@ -323,6 +350,110 @@ def get_latest_checkpoint_step(checkpoint_dir): return max(checkpoint_steps) if checkpoint_steps else None +def validate_checkpoint_integrity(checkpoint_dir, step): + """Validate that a checkpoint at the given step is complete and uncorrupted.""" + ckpt_dir = checkpoint_dir / f"{step}" + + required_files = ["pytorch_model.pt", "optimizer.pt", "metadata.pt"] + optional_files = ["ema_model.pt"] + + # Check if all required files exist + for file_name in required_files: + file_path = ckpt_dir / file_name + if not file_path.exists(): + logging.warning(f"Required checkpoint file missing: {file_path}") + return False + + # Try to validate file integrity by attempting to load them + try: + # Test model file + device = torch.device("cpu") # Use CPU for validation to avoid GPU memory issues + model_state = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device, weights_only=False) + if not isinstance(model_state, dict): + logging.warning(f"Model checkpoint file corrupted at step {step}") + return False + + # Test optimizer file + optimizer_state = torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=False) + if not isinstance(optimizer_state, dict): + logging.warning(f"Optimizer checkpoint file corrupted at step {step}") + return False + + # Test metadata file + metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) + if not isinstance(metadata, 
dict) or "global_step" not in metadata: + logging.warning(f"Metadata checkpoint file corrupted at step {step}") + return False + + logging.info(f"Checkpoint at step {step} validated successfully") + return True + + except Exception as e: + logging.warning(f"Checkpoint validation failed at step {step}: {e}") + return False + + +def find_latest_valid_checkpoint(checkpoint_dir): + """Find the latest checkpoint that passes integrity validation.""" + checkpoint_steps = [] + for d in checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + if not checkpoint_steps: + return None + + # Sort steps in descending order to check latest first + checkpoint_steps.sort(reverse=True) + + for step in checkpoint_steps: + if validate_checkpoint_integrity(checkpoint_dir, step): + return step + + logging.error("No valid checkpoints found in directory") + return None + + +def cleanup_corrupted_checkpoints(checkpoint_dir, keep_last_n=3): + """Clean up corrupted checkpoints, keeping only the last N valid ones.""" + checkpoint_steps = [] + for d in checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + if not checkpoint_steps: + return + + # Sort steps in descending order + checkpoint_steps.sort(reverse=True) + + valid_checkpoints = [] + corrupted_checkpoints = [] + + # Validate all checkpoints + for step in checkpoint_steps: + if validate_checkpoint_integrity(checkpoint_dir, step): + valid_checkpoints.append(step) + else: + corrupted_checkpoints.append(step) + + # Keep only the last N valid checkpoints + checkpoints_to_keep = valid_checkpoints[:keep_last_n] + checkpoints_to_remove = valid_checkpoints[keep_last_n:] + corrupted_checkpoints + + # Remove old valid checkpoints and all corrupted ones + for step in checkpoints_to_remove: + checkpoint_path = checkpoint_dir / f"{step}" + try: + import shutil + shutil.rmtree(checkpoint_path) + logging.info(f"Removed checkpoint at step {step}") + 
except Exception as e: + logging.warning(f"Failed to remove checkpoint at step {step}: {e}") + + logging.info(f"Checkpoint cleanup complete. Kept {len(checkpoints_to_keep)} valid checkpoints: {checkpoints_to_keep}") + + def debug_unused_parameters(model, device): """Debug function to identify unused parameters in the model.""" if isinstance(model, DDP): @@ -368,14 +499,41 @@ def check_model_parameters(model, device): if unused_params: logging.warning(f"Found {len(unused_params)} parameters that might be unused:") - for name in unused_params[:10]: # Show first 10 + for name in unused_params: # Show first 10 logging.warning(f" - {name}") - if len(unused_params) > 10: - logging.warning(f" ... and {len(unused_params) - 10} more") + # if len(unused_params) > 10: + # logging.warning(f" ... and {len(unused_params) - 10} more") + + +def log_memory_usage(device, step, phase="unknown"): + """Log detailed memory usage information.""" + if not torch.cuda.is_available(): + return + + memory_allocated = torch.cuda.memory_allocated(device) / 1e9 + memory_reserved = torch.cuda.memory_reserved(device) / 1e9 + memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device) + memory_free = memory_free / 1e9 + + # Get more detailed memory info + memory_stats = torch.cuda.memory_stats(device) + max_memory_allocated = memory_stats.get('allocated_bytes.all.peak', 0) / 1e9 + max_memory_reserved = memory_stats.get('reserved_bytes.all.peak', 0) / 1e9 + + # Get DDP info if available + ddp_info = "" + if dist.is_initialized(): + ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}" + + logging.info(f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}") def setup_memory_optimizations(model, device, enable_gradient_checkpointing=False): """Setup memory optimization 
techniques for the model.""" + # Set memory optimization environment variables + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + os.environ["CUDA_LAUNCH_BLOCKING"] = "0" + if enable_gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'): model.gradient_checkpointing_enable() logging.info("Enabled gradient checkpointing for memory optimization") @@ -388,16 +546,18 @@ def setup_memory_optimizations(model, device, enable_gradient_checkpointing=Fals # Set memory efficient settings if torch.cuda.is_available(): # Enable memory efficient algorithms - torch.backends.cudnn.benchmark = True - torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = False # Disable for memory efficiency + torch.backends.cudnn.deterministic = True # Enable for memory efficiency # Set memory fraction if needed if device.index is not None: torch.cuda.empty_cache() logging.info(f"Cleared CUDA cache for device {device.index}") + + logging.info("Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to reduce memory fragmentation") -def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, enable_gradient_checkpointing: bool = False): +def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, enable_gradient_checkpointing: bool = False, cleanup_checkpoints: bool = False): use_ddp, local_rank, device = setup_ddp() is_main = (not use_ddp) or (dist.get_rank() == 0) set_seed(config.seed, local_rank) @@ -412,12 +572,18 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte # Find checkpoint directory based on experiment name exp_checkpoint_dir = config.checkpoint_dir if exp_checkpoint_dir.exists(): - latest_step = 
get_latest_checkpoint_step(exp_checkpoint_dir) + # Use validation to find the latest working checkpoint + latest_step = find_latest_valid_checkpoint(exp_checkpoint_dir) if latest_step is not None: resuming = True logging.info(f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}") + + # Clean up corrupted checkpoints if requested + if cleanup_checkpoints and is_main: + logging.info("Cleaning up corrupted checkpoints...") + cleanup_corrupted_checkpoints(exp_checkpoint_dir, keep_last_n=3) else: - raise FileNotFoundError(f"No checkpoints found in {exp_checkpoint_dir} for resume") + raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") else: raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") elif config.overwrite and config.checkpoint_dir.exists(): @@ -449,10 +615,8 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte effective_batch_size = config.batch_size // (dist.get_world_size() if use_ddp else 1) # Memory-efficient data loading with reduced pin_memory for large datasets - pin_memory = True - if effective_batch_size > 16: # Reduce pin_memory for large batches - pin_memory = False - logging.info("Disabled pin_memory for large batch size to reduce memory usage") + pin_memory = False # Disable pin_memory to reduce memory usage + logging.info("Disabled pin_memory to reduce memory usage") loader = TorchDataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=pin_memory, drop_last=True, collate_fn=collate_to_numpy) @@ -480,6 +644,8 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte # Reset the loader iterator loader = TorchDataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=pin_memory, drop_last=True, 
collate_fn=collate_to_numpy) + # Test gradient checkpointing with a small forward pass (moved to after model creation) + # Build model if not isinstance(config.model, Pi0Config): # Convert dataclass to Pi0Config if needed @@ -499,17 +665,77 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte # Apply memory optimizations setup_memory_optimizations(model, device, enable_gradient_checkpointing) + # Log initial memory usage after model creation + if is_main and torch.cuda.is_available(): + log_memory_usage(device, 0, "after_model_creation") + # Log gradient checkpointing status if enabled if enable_gradient_checkpointing and is_main: if hasattr(model, 'get_gradient_checkpointing_status'): status = model.get_gradient_checkpointing_status() logging.info(f"Gradient checkpointing status: {status}") + + # Verify that gradient checkpointing is actually enabled + if hasattr(model, 'is_gradient_checkpointing_enabled'): + is_enabled = model.is_gradient_checkpointing_enabled() + logging.info(f"Gradient checkpointing is enabled: {is_enabled}") + + # Check if we're in training mode + logging.info(f"Model training mode: {model.training}") + + # Verify the underlying models have gradient checkpointing enabled + if hasattr(model, 'paligemma_with_expert'): + if hasattr(model.paligemma_with_expert, 'paligemma'): + if hasattr(model.paligemma_with_expert.paligemma, 'language_model'): + paligemma_gc = getattr(model.paligemma_with_expert.paligemma.language_model, 'gradient_checkpointing', False) + logging.info(f"PaliGemma language model gradient checkpointing: {paligemma_gc}") + + if hasattr(model.paligemma_with_expert.paligemma, 'vision_tower'): + vision_gc = getattr(model.paligemma_with_expert.paligemma.vision_tower, 'gradient_checkpointing', False) + logging.info(f"PaliGemma vision tower gradient checkpointing: {vision_gc}") + + if hasattr(model.paligemma_with_expert, 'gemma_expert'): + if hasattr(model.paligemma_with_expert.gemma_expert, 'model'): + 
gemma_gc = getattr(model.paligemma_with_expert.gemma_expert.model, 'gradient_checkpointing', False) + logging.info(f"Gemma expert model gradient checkpointing: {gemma_gc}") else: logging.info("Gradient checkpointing enabled but status check not available") - # Check model parameters for debugging - if is_main: - check_model_parameters(model, device) + # Test gradient checkpointing with a small forward pass + if is_main and enable_gradient_checkpointing: + logging.info("Testing gradient checkpointing with a small forward pass...") + try: + # Create a small test batch + test_batch = next(iter(loader)) + test_batch = batch_to_torch(test_batch, device) + test_actions = test_batch["actions"] + + # Record memory before forward pass + if torch.cuda.is_available(): + memory_before = torch.cuda.memory_allocated(device) / 1e9 + logging.info(f"Memory before test forward pass: {memory_before:.2f}GB") + + # Do a test forward pass + with torch.no_grad(): + test_observation = _model.Observation.from_dict(test_batch) + test_losses = model(test_observation, test_actions) + + # Record memory after forward pass + if torch.cuda.is_available(): + memory_after = torch.cuda.memory_allocated(device) / 1e9 + logging.info(f"Memory after test forward pass: {memory_after:.2f}GB") + logging.info(f"Memory difference: {memory_after - memory_before:.2f}GB") + + # Clear test data + del test_batch, test_actions, test_observation, test_losses + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + logging.info("Gradient checkpointing test completed successfully") + except Exception as e: + logging.warning(f"Gradient checkpointing test failed: {e}") + logging.warning("Continuing with training...") if use_ddp: # Enable unused parameter detection to handle cases where some parameters don't participate in loss @@ -574,6 +800,17 @@ def lr_schedule(step: int): # Enable mixed precision training for memory optimization scaler = torch.amp.GradScaler(enabled=mixed_precision and 
torch.cuda.is_available()) + + # Set memory efficient settings + if torch.cuda.is_available(): + # Enable memory efficient algorithms + torch.backends.cudnn.benchmark = False # Disable for memory efficiency + torch.backends.cudnn.deterministic = True # Enable for memory efficiency + + # Set memory fraction if needed + if device.index is not None: + torch.cuda.empty_cache() + logging.info(f"Cleared CUDA cache for device {device.index}") model.train() start_time = time.time() @@ -590,6 +827,9 @@ def lr_schedule(step: int): # Training loop - iterate until we reach num_train_steps pbar = tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) if is_main else None + # Check model parameters after first few steps when gradients are available + parameters_checked = False + while global_step < config.num_train_steps: if use_ddp: sampler.set_epoch(global_step // len(loader)) @@ -619,6 +859,14 @@ def lr_schedule(step: int): losses = torch.tensor(losses, device=device, dtype=torch.float32) loss = losses.mean() / gradient_accumulation_steps # Scale loss for gradient accumulation + + # Debug gradient checkpointing on first few steps + if global_step < 5 and is_main: + if hasattr(model, 'is_gradient_checkpointing_enabled'): + gc_enabled = model.is_gradient_checkpointing_enabled() + logging.info(f"Step {global_step}: Gradient checkpointing enabled: {gc_enabled}") + if torch.cuda.is_available(): + log_memory_usage(device, global_step, "after_forward") except RuntimeError as e: if "Expected to have finished reduction" in str(e) or "did not receive grad" in str(e): logging.error(f"DDP error on rank {dist.get_rank() if use_ddp else 0}: {e}") @@ -630,6 +878,16 @@ def lr_schedule(step: int): # Backward pass with gradient scaling scaler.scale(loss).backward() + + # Aggressive memory cleanup after backward pass + if torch.cuda.is_available(): + # Clear intermediate activations that might still be in memory + torch.cuda.empty_cache() + gc.collect() + 
+ # Log memory usage after backward pass for debugging + if global_step < 5 and is_main: + log_memory_usage(device, global_step, "after_backward") # Gradient accumulation logic if (global_step + 1) % gradient_accumulation_steps == 0: @@ -641,6 +899,12 @@ def lr_schedule(step: int): scaler.step(optim) scaler.update() optim.zero_grad(set_to_none=True) + + # Clear gradients more aggressively + for param in model.parameters(): + if param.grad is not None: + param.grad.detach_() + param.grad = None # Update EMA if enabled if ema_model is not None: @@ -654,6 +918,11 @@ def lr_schedule(step: int): logging.warning(f"Failed to update EMA model: {e}") # Continue training without EMA update + # # Check model parameters after first few steps when gradients are available + # if not parameters_checked and global_step >= 16510 and is_main: + # check_model_parameters(model, device) + # parameters_checked = True + # Collect stats (only on accumulation steps) if (global_step + 1) % gradient_accumulation_steps == 0 and is_main: infos.append({ @@ -671,7 +940,7 @@ def lr_schedule(step: int): logging.info(f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s") # Log to wandb - if config.wandb_enabled: + if config.wandb_enabled and len(infos) > 1: wandb.log({ "loss": avg_loss, "learning_rate": avg_lr, @@ -698,8 +967,18 @@ def lr_schedule(step: int): # Memory cleanup after each batch del batch, actions, observation, losses, loss + + # More aggressive memory cleanup if torch.cuda.is_available(): torch.cuda.empty_cache() + # Force garbage collection + gc.collect() + + # Log memory usage for debugging gradient checkpointing + if is_main and global_step % 100 == 0: + memory_allocated = torch.cuda.memory_allocated(device) / 1e9 + memory_reserved = torch.cuda.memory_reserved(device) / 1e9 + logging.info(f"Step {global_step}: GPU memory allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB") # Close progress bar if pbar is not None: @@ -721,6 +1000,8 @@ 
def main(): parser = argparse.ArgumentParser(add_help=False) parser.add_argument("--resume", action="store_true", default=False, help="Resume training from the latest checkpoint for the experiment (handles both PyTorch and JAX checkpoints)") + parser.add_argument("--cleanup_checkpoints", action="store_true", default=False, + help="Clean up corrupted checkpoints during resume (keeps last 3 valid checkpoints)") parser.add_argument("--ckpt_save_interval", type=int, default=None, help="Interval for saving checkpoints (overrides config.save_interval)") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, @@ -750,7 +1031,8 @@ def main(): gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=mixed_precision, max_memory_usage=args.max_memory_usage, - enable_gradient_checkpointing=args.gradckpt) + enable_gradient_checkpointing=True, + cleanup_checkpoints=args.cleanup_checkpoints) if __name__ == "__main__": diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 5a27cfa..5737a80 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -55,7 +55,8 @@ def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - vlm_config_hf.vision_config.torch_dtype = "float32" + vlm_config_hf.vision_config.torch_dtype = "bfloat16" + vlm_config_hf.vision_config._attn_implementation = "flash_attention_2" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, @@ -135,7 +136,38 @@ def forward( else: models = [self.paligemma.language_model, self.gemma_expert.model] num_layers = self.paligemma.config.text_config.num_hidden_layers - for layer_idx in range(num_layers): + + # Check if gradient checkpointing is enabled for any of the 
models + use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, 'gradient_checkpointing') and + self.gemma_expert.model.gradient_checkpointing and + self.training + ) or ( + hasattr(self, 'gradient_checkpointing') and + self.gradient_checkpointing and + self.training + ) + + # Force enable gradient checkpointing if we're in training mode and the model supports it + if self.training and hasattr(self.gemma_expert.model, 'gradient_checkpointing'): + if not self.gemma_expert.model.gradient_checkpointing: + print("Forcing gradient checkpointing to be enabled for Gemma expert model") + self.gemma_expert.model.gradient_checkpointing = True + use_gradient_checkpointing = True + + # Debug gradient checkpointing status + if hasattr(self, '_debug_gc_printed') and not self._debug_gc_printed: + print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}") + print(f"Model training mode: {self.training}") + print(f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}") + if hasattr(self.gemma_expert.model, 'gradient_checkpointing'): + print(f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}") + self._debug_gc_printed = True + + # Define the complete layer computation function for gradient checkpointing + def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond): + models = [self.paligemma.language_model, self.gemma_expert.model] + query_states = [] key_states = [] value_states = [] @@ -143,9 +175,6 @@ def forward( for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) - # hidden_states = hidden_states.to(dtype=torch.bfloat16) - # if gate is not None: - # gate = gate.to(dtype=torch.bfloat16) gates.append(gate) input_shape = hidden_states.shape[:-1] @@ -158,42 +187,36 @@ def forward( 
key_states.append(key_state) value_states.append(value_state) - # B,L,H,D with L sequence length, H number of heads, D head dim - # concatenate on the number of embeddings/tokens + # Concatenate and process attention query_states = torch.cat(query_states, dim=2) key_states = torch.cat(key_states, dim=2) value_states = torch.cat(value_states, dim=2) - # query_states = apply_rope(query_states, position_ids) - # key_states = apply_rope(key_states, position_ids) - - # query_states = query_states.transpose(1, 2) - # key_states = key_states.transpose(1, 2) - # value_states = value_states.transpose(1, 2) - dummy_tensor = torch.zeros(query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype) cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) query_states, key_states = modeling_gemma.apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) batch_size = query_states.shape[0] scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling + + # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( self.paligemma.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, attention_mask, scaling ) - #att_output = att_output.to(dtype=torch.bfloat16) - att_output = att_output.reshape(batch_size, -1, 1 * 8 * layer.self_attn.head_dim) - + # Get head_dim from the current layer, not from the model + head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) - # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) + # Process layer outputs outputs_embeds = [] - start = 0 + start_pos = 0 for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] - end = start + hidden_states.shape[1] + end_pos = start_pos + hidden_states.shape[1] if att_output.dtype != 
layer.self_attn.o_proj.weight.dtype: att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) - out_emb = layer.self_attn.o_proj(att_output[:, start:end]) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) @@ -205,14 +228,43 @@ def forward( # second residual out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) outputs_embeds.append(out_emb) - start = end - inputs_embeds = outputs_embeds + start_pos = end_pos + + return outputs_embeds + + # Process all layers with gradient checkpointing if enabled + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_layer_complete, + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, + use_reentrant=False, + preserve_rng_state=False + ) + else: + inputs_embeds = compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond) + + # Old code removed - now using compute_layer_complete function above # final norm - outputs_embeds = [] - for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) - outputs_embeds.append(out_emb) + # Define final norm computation function for gradient checkpointing + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, adarms_cond, + use_reentrant=False, + preserve_rng_state=False + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) prefix_output = outputs_embeds[0] suffix_output = 
outputs_embeds[1] diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 5f13a66..8dff228 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -119,14 +119,19 @@ def gradient_checkpointing_enable(self): if hasattr(self.paligemma_with_expert, 'paligemma'): if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True - logging.info("Enabled gradient checkpointing in PaliGemma language model") + print("Enabled gradient checkpointing in PaliGemma language model") + + # Enable gradient checkpointing in the vision model (SigLIP) + if hasattr(self.paligemma_with_expert.paligemma, 'vision_tower'): + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + print("Enabled gradient checkpointing in PaliGemma vision tower (SigLIP)") if hasattr(self.paligemma_with_expert, 'gemma_expert'): if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True - logging.info("Enabled gradient checkpointing in Gemma expert model") + print("Enabled gradient checkpointing in Gemma expert model") - logging.info("Enabled gradient checkpointing for PI0Pytorch model") + print("Enabled gradient checkpointing for PI0Pytorch model") def gradient_checkpointing_disable(self): """Disable gradient checkpointing.""" @@ -136,6 +141,10 @@ def gradient_checkpointing_disable(self): if hasattr(self.paligemma_with_expert, 'paligemma'): if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + + # Disable gradient checkpointing in the vision model (SigLIP) + if hasattr(self.paligemma_with_expert.paligemma, 'vision_tower'): + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False if 
hasattr(self.paligemma_with_expert, 'gemma_expert'): if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): @@ -152,6 +161,7 @@ def get_gradient_checkpointing_status(self): status = { 'main_model': self.gradient_checkpointing_enabled, 'paligemma_language_model': False, + 'paligemma_vision_model': False, 'gemma_expert_model': False } @@ -162,6 +172,13 @@ def get_gradient_checkpointing_status(self): 'gradient_checkpointing', False ) + + if hasattr(self.paligemma_with_expert.paligemma, 'vision_tower'): + status['paligemma_vision_model'] = getattr( + self.paligemma_with_expert.paligemma.vision_tower, + 'gradient_checkpointing', + False + ) if hasattr(self.paligemma_with_expert, 'gemma_expert'): if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): @@ -198,25 +215,64 @@ def embed_prefix( pad_masks = [] att_masks = [] - for ( - img, - img_mask, - ) in zip(images, img_masks): - img_emb = self.paligemma_with_expert.embed_image(img) - - bsize, num_img_embs = img_emb.shape[:2] - img_mask = img_mask[:, None].expand(bsize, num_img_embs) - - embs.append(img_emb) - pad_masks.append(img_mask) + # Apply gradient checkpointing to image embedding if enabled + if self.gradient_checkpointing_enabled and self.training: + for ( + img, + img_mask, + ) in zip(images, img_masks): + # Use checkpoint for image embedding + def checkpointed_image_embed(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = torch.utils.checkpoint.checkpoint( + checkpointed_image_embed, + img, + use_reentrant=False, + preserve_rng_state=False + ) - # Create attention masks so that image tokens attend to each other - att_masks += [0] * num_img_embs + bsize, num_img_embs = img_emb.shape[:2] + img_mask = img_mask[:, None].expand(bsize, num_img_embs) - lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + embs.append(img_emb) + pad_masks.append(img_mask) - lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) + # Create attention masks 
so that image tokens attend to each other + att_masks += [0] * num_img_embs + else: + for ( + img, + img_mask, + ) in zip(images, img_masks): + img_emb = self.paligemma_with_expert.embed_image(img) + + bsize, num_img_embs = img_emb.shape[:2] + img_mask = img_mask[:, None].expand(bsize, num_img_embs) + + embs.append(img_emb) + pad_masks.append(img_mask) + + # Create attention masks so that image tokens attend to each other + att_masks += [0] * num_img_embs + + # Apply gradient checkpointing to language embedding if enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_lang_embed(lang_tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = torch.utils.checkpoint.checkpoint( + checkpointed_lang_embed, + lang_tokens, + use_reentrant=False, + preserve_rng_state=False + ) + else: + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + lang_emb = lang_emb * math.sqrt(lang_emb_dim) embs.append(lang_emb) pad_masks.append(lang_masks) @@ -243,8 +299,20 @@ def embed_suffix(self, state, noisy_actions, timestep): att_masks = [] if not self.pi05: - # Embed state - state_emb = self.state_proj(state) + # Embed state with gradient checkpointing if enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_state_proj(state): + return self.state_proj(state) + + state_emb = torch.utils.checkpoint.checkpoint( + checkpointed_state_proj, + state, + use_reentrant=False, + preserve_rng_state=False + ) + else: + state_emb = self.state_proj(state) + embs.append(state_emb[:, None, :]) bsize = state_emb.shape[0] dtype = state_emb.dtype @@ -262,22 +330,66 @@ def embed_suffix(self, state, noisy_actions, timestep): ) time_emb = time_emb.type(dtype=timestep.dtype) - # Fuse timestep + action information using an MLP - action_emb = 
self.action_in_proj(noisy_actions) + # Fuse timestep + action information using an MLP with gradient checkpointing if enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_action_proj(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = torch.utils.checkpoint.checkpoint( + checkpointed_action_proj, + noisy_actions, + use_reentrant=False, + preserve_rng_state=False + ) + else: + action_emb = self.action_in_proj(noisy_actions) if not self.pi05: time_emb = time_emb[:, None, :].expand_as(action_emb) action_time_emb = torch.cat([action_emb, time_emb], dim=2) - action_time_emb = self.action_time_mlp_in(action_time_emb) - action_time_emb = F.silu(action_time_emb) # swish == silu - action_time_emb = self.action_time_mlp_out(action_time_emb) + + # Apply gradient checkpointing to MLP layers if enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_mlp(action_time_emb): + x = self.action_time_mlp_in(action_time_emb) + x = F.silu(x) # swish == silu + x = self.action_time_mlp_out(x) + return x + + action_time_emb = torch.utils.checkpoint.checkpoint( + checkpointed_mlp, + action_time_emb, + use_reentrant=False, + preserve_rng_state=False + ) + else: + action_time_emb = self.action_time_mlp_in(action_time_emb) + action_time_emb = F.silu(action_time_emb) # swish == silu + action_time_emb = self.action_time_mlp_out(action_time_emb) + adarms_cond = None else: - # time MLP (for adaRMS) - time_emb = self.time_mlp_in(time_emb) - time_emb = F.silu(time_emb) # swish == silu - time_emb = self.time_mlp_out(time_emb) - time_emb = F.silu(time_emb) + # time MLP (for adaRMS) with gradient checkpointing if enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_time_mlp(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) # swish == silu + x = self.time_mlp_out(x) + x = F.silu(x) + return x + + time_emb = torch.utils.checkpoint.checkpoint( + checkpointed_time_mlp, + 
time_emb, + use_reentrant=False, + preserve_rng_state=False + ) + else: + time_emb = self.time_mlp_in(time_emb) + time_emb = F.silu(time_emb) # swish == silu + time_emb = self.time_mlp_out(time_emb) + time_emb = F.silu(time_emb) + action_time_emb = action_emb adarms_cond = time_emb @@ -336,18 +448,52 @@ def forward(self, observation, actions, noise=None, time=None) -> Tensor: att_2d_masks_4d = att_2d_masks[:, None, :, :] att_2d_masks_4d = torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) - (_, suffix_out), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_masks_4d, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - adarms_cond=[None, adarms_cond] - ) + # Apply gradient checkpointing if enabled + if self.gradient_checkpointing_enabled and self.training: + # Use torch.utils.checkpoint.checkpoint for the expensive forward pass + def checkpointed_forward(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond] + ) + return suffix_out + + suffix_out = torch.utils.checkpoint.checkpoint( + checkpointed_forward, + prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond, + use_reentrant=False, # More memory efficient + preserve_rng_state=False # More memory efficient + ) + else: + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond] + ) + suffix_out = suffix_out[:, -self.config.action_horizon :] suffix_out = suffix_out.to(dtype=torch.float32) - v_t = self.action_out_proj(suffix_out) + # Apply gradient checkpointing to final action projection if 
enabled + if self.gradient_checkpointing_enabled and self.training: + def checkpointed_action_out_proj(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = torch.utils.checkpoint.checkpoint( + checkpointed_action_out_proj, + suffix_out, + use_reentrant=False, + preserve_rng_state=False + ) + else: + v_t = self.action_out_proj(suffix_out) losses = F.mse_loss(u_t, v_t, reduction="none") return losses diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 6235e32..f866bee 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -746,7 +746,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=64, + batch_size=256, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, @@ -756,7 +756,7 @@ def __post_init__(self) -> None: optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2", - num_train_steps=30_000, + num_train_steps=60_000, ), # # Fine-tuning Aloha configs. @@ -952,4 +952,4 @@ def get_config(config_name: str) -> TrainConfig: closest_str = f" Did you mean '{closest[0]}'? 
" if closest else "" raise ValueError(f"Config '{config_name}' not found.{closest_str}") - return _CONFIGS_DICT[config_name] \ No newline at end of file + return _CONFIGS_DICT[config_name] From 4fc7766aa9b3052c58f53ffd3ce027b5b15f37a4 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Mon, 25 Aug 2025 20:58:32 -0700 Subject: [PATCH 16/32] bs 1024 (2 nodes) working --- src/openpi/models_pytorch/gemma_pytorch.py | 4 ++-- src/openpi/training/config.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 5737a80..3d70ca0 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -55,8 +55,8 @@ def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - vlm_config_hf.vision_config.torch_dtype = "bfloat16" - vlm_config_hf.vision_config._attn_implementation = "flash_attention_2" + vlm_config_hf.vision_config.torch_dtype = "float32" + # vlm_config_hf.vision_config._attn_implementation = "flash_attention_2" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index f866bee..e1ec12c 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -746,17 +746,17 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=256, + batch_size=1024, lr_schedule=_optimizer.CosineDecaySchedule( - warmup_steps=10_000, - peak_lr=5e-5, + warmup_steps=2_500, + peak_lr=1e-4, decay_steps=1_000_000, - decay_lr=5e-5, + decay_lr=1e-4, ), optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, 
weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2", - num_train_steps=60_000, + num_train_steps=6_500, ), # # Fine-tuning Aloha configs. From 44dba744da9f4f3f89c6a908faa0619104fe598c Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Tue, 26 Aug 2025 01:40:09 -0700 Subject: [PATCH 17/32] clean up pytorch finetuning --- examples/inference_example.py | 6 +- scripts/train_pytorch.py | 626 ++++----------------- scripts/train_single_example.py | 361 ++++-------- src/openpi/models/model.py | 154 +---- src/openpi/models_pytorch/gemma_pytorch.py | 28 - src/openpi/models_pytorch/pi0_pytorch.py | 329 +++-------- src/openpi/training/config.py | 8 +- src/openpi/training/data_loader.py | 3 +- 8 files changed, 322 insertions(+), 1193 deletions(-) diff --git a/examples/inference_example.py b/examples/inference_example.py index 74eaa4d..72e9d76 100644 --- a/examples/inference_example.py +++ b/examples/inference_example.py @@ -5,13 +5,13 @@ This demonstrates the basic usage patterns for both implementations. 
pi0_droid -python examples/inference_example.py --model_name pi0_droid --jax_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_droid --pytorch_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch +python examples/inference_example.py --model_name pi0_droid --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch pi0_aloha_sim -python examples/inference_example.py --model_name pi0_aloha_sim --jax_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --pytorch_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch +python examples/inference_example.py --model_name pi0_aloha_sim --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch pi05_droid -python examples/inference_example.py --model_name pi05_droid --jax_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid --pytorch_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid_pytorch +python examples/inference_example.py --model_name pi05_droid --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid_pytorch """ import argparse diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 438f28e..e6d4a4e 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -1,24 +1,12 @@ """ -PyTorch training entrypoint for PI0 with multi-GPU and multi-node (DDP) support. +PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support. 
This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs entirely in PyTorch using the `PI0Pytorch` model and your existing config/data pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. -Key features -- Uses the same TrainConfig/tyro CLI as the JAX script (see available configs in - `src/openpi/training/config.py`). -- Supports multi-GPU and multi-node training via DistributedDataParallel (DDP). -- Cosine LR with warmup (parameters read from the selected config). -- AdamW optimizer and gradient clipping. -- Comprehensive checkpoint saving and resume mechanism with configurable intervals. -- Checkpoints saved on rank 0 to `config.checkpoint_dir//` containing model, optimizer, and metadata. -- Memory optimizations: mixed precision training, gradient accumulation, and efficient data handling. -Requirements -- PyTorch >= 2.0, torch.distributed (NCCL for CUDA, Gloo for CPU). -- Multiple GPUs for DDP (optional). -- Network connectivity between nodes for multi-node training. 
+ Usage Single GPU: - python scripts/train_pytorch.py --exp_name --ckpt_save_interval + python scripts/train_pytorch.py --exp_name --save_interval Example: python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint @@ -28,78 +16,39 @@ torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume Multi-Node Training: - # On master node (node 0): - torchrun --nnodes= --nproc_per_node= --rdzv_id= --rdzv_backend=c10d --rdzv_endpoint=: scripts/train_pytorch.py --exp_name - - # On worker nodes (node 1, 2, ...): - torchrun --nnodes= --nproc_per_node= --rdzv_id= --rdzv_backend=c10d --rdzv_endpoint=: scripts/train_pytorch.py --exp_name - - Example (2 nodes, 4 GPUs each): - # Master node (192.168.1.100): - torchrun --nnodes=2 --nproc_per_node=4 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=192.168.1.100:29400 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_multi_node - - # Worker node (192.168.1.101): - torchrun --nnodes=2 --nproc_per_node=4 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=192.168.1.100:29400 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_multi_node -Multi-Node Setup Requirements: -1. Network connectivity: All nodes must be able to communicate on the specified port -2. Shared filesystem: All nodes must have access to the same dataset and checkpoint directories -3. Environment consistency: Same Python environment and dependencies on all nodes -4. Firewall configuration: Ensure the rendezvous port (e.g., 29400) is open between nodes -5. 
SSH access: Nodes should be able to SSH to each other (for torchrun coordination) -Environment Variables for Multi-Node: -- MASTER_ADDR: IP address of the master node (auto-set by torchrun) -- MASTER_PORT: Port for rendezvous (auto-set by torchrun) -- WORLD_SIZE: Total number of processes across all nodes -- RANK: Global rank of the process (0 to WORLD_SIZE-1) -- LOCAL_RANK: Local rank within the node (0 to nproc_per_node-1) -- NODE_RANK: Rank of the node (0 to nnodes-1) -Checkpoint Parameters: -- --ckpt_save_interval: Override the checkpoint save interval from config (e.g., --save_interval 500) -- --resume: Resume training from the latest checkpoint in the checkpoint directory -- --cleanup_checkpoints: Clean up corrupted checkpoints during resume (keeps last 3 valid ones) -- --overwrite: Overwrite existing checkpoint directory (cannot be used with --resume) -Memory Optimization Parameters: -- --gradient_accumulation_steps: Number of steps to accumulate gradients (default: 1) -- --mixed_precision: Enable mixed precision training (default: True) -- --max_memory_usage: Maximum GPU memory usage in GB (default: None, auto-detect) -- --gradckpt: Enable gradient checkpointing for memory optimization -Notes -- The global batch size must be divisible by world size (number of processes). -- The data pipeline and transforms are identical to the JAX version and are controlled - by the selected TrainConfig (e.g., `LeRobot*` configs for real datasets or `FakeDataConfig`). -- Supports Weights & Biases (wandb) logging for experiment tracking and visualization. -- Checkpoints include model state, optimizer state, and training metadata for complete resume capability. -- Checkpoints are saved in experiment-specific directories: // -- Resume functionality automatically finds the latest checkpoint for the specified experiment name. -- Checkpoint loading handles both PyTorch and JAX/Flax checkpoints for compatibility. 
-- For optimal multi-node performance, ensure high-bandwidth network connectivity (e.g., InfiniBand). -- Monitor GPU utilization and network bandwidth during multi-node training. -- Memory optimizations can significantly reduce GPU memory usage while maintaining training quality. + torchrun \ + --nnodes= --nproc_per_node= --node_rank= \ + --master_addr= --master_port= \ + scripts/train_pytorch.py --exp_name= --save_interval + """ + import argparse import dataclasses +import gc import logging import os import platform +import shutil import time -import gc -from dataclasses import dataclass -from typing import Any, Dict, Tuple +from typing import Any, Dict +import jax import numpy as np import torch import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.data import DataLoader as TorchDataLoader -from torch.utils.data.distributed import DistributedSampler +import torch.nn.parallel +import torch.utils.data +import torch.utils.data.distributed +import tqdm import wandb -from tqdm import tqdm +import safetensors.torch import openpi.training.config as _config import openpi.training.data_loader as _data import openpi.models.model as _model -from openpi.models_pytorch.pi0_pytorch import PI0Pytorch -from openpi.models.pi0_config import Pi0Config +import openpi.models_pytorch.pi0_pytorch +import openpi.models.pi0_config def init_logging(): @@ -149,9 +98,9 @@ def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = T def setup_ddp(): world_size = int(os.environ.get("WORLD_SIZE", "1")) use_ddp = world_size > 1 - if use_ddp and not dist.is_initialized(): + if use_ddp and not torch.distributed.is_initialized(): backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method="env://") + torch.distributed.init_process_group(backend=backend, init_method="env://") # Set up debugging environment variables for DDP issues if 
os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None: @@ -165,9 +114,9 @@ def setup_ddp(): def cleanup_ddp(): - if dist.is_initialized(): - dist.barrier() - dist.destroy_process_group() + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() def set_seed(seed: int, local_rank: int): @@ -194,9 +143,6 @@ def stack_leaf(*xs): # Memory-efficient collation result = torch.utils.data.default_collate(batch_list) if not isinstance(batch_list[0], dict) else _tree_map_multi(stack_leaf, batch_list) - # Clear batch list from memory - del batch_list - return result @@ -209,11 +155,7 @@ def recurse(keys, items): return recurse([], batch_list) -def batch_to_torch(batch: Dict[str, Any], device: torch.device) -> Tuple[list[torch.Tensor], list[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: - # Maintain canonical image key order - image_keys = _model.IMAGE_KEYS - import jax - +def batch_to_torch(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]: # Memory-efficient conversion: convert to torch tensors and move to device in one step batch = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(device), batch) @@ -221,32 +163,26 @@ def batch_to_torch(batch: Dict[str, Any], device: torch.device) -> Tuple[list[to batch['state'] = batch['state'].to(dtype=torch.float32) batch['actions'] = batch['actions'].to(dtype=torch.float32) - # Clear numpy arrays from memory if they exist - del jax - return batch def get_model_state_dict(model): """Get state dict from model, handling DDP wrapper.""" - return model.module.state_dict() if isinstance(model, DDP) else model.state_dict() + return model.module.state_dict() if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.state_dict() def get_model_parameters(model): """Get parameters from model, handling DDP wrapper.""" - return model.module.parameters() if isinstance(model, DDP) else model.parameters() + return model.module.parameters() if 
isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.parameters() -def save_checkpoint(model, optimizer, global_step, config, is_main, ckpt_save_interval=None, ema_model=None): +def save_checkpoint(model, optimizer, global_step, config, is_main, ema_model=None): """Save a checkpoint with model state, optimizer state, EMA state, and metadata.""" if not is_main: return - # Use ckpt_save_interval if provided, otherwise use config.save_interval - save_interval = ckpt_save_interval if ckpt_save_interval is not None else config.save_interval - # Only save if it's time to save or if it's the final step - if (global_step % save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: + if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: # Ensure checkpoint_dir is a Path object and create the step-specific directory ckpt_dir = config.checkpoint_dir / f"{global_step}" ckpt_dir.mkdir(parents=True, exist_ok=True) @@ -291,53 +227,25 @@ def load_checkpoint(model, optimizer, checkpoint_dir, device, ema_model=None): ckpt_dir = checkpoint_dir / f"{latest_step}" # Load model state with error handling - try: - model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device, weights_only=False) - (model.module if isinstance(model, DDP) else model).load_state_dict(model_state_dict) - logging.info(f"Successfully loaded model state from step {latest_step}") - except Exception as e: - logging.error(f"Failed to load model state from step {latest_step}: {e}") - raise RuntimeError(f"Model checkpoint corrupted at step {latest_step}. 
Cannot resume training.") + model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device, weights_only=False) + (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model).load_state_dict(model_state_dict) + logging.info(f"Successfully loaded model state from step {latest_step}") # Load optimizer state with error handling and fallback - optimizer_loaded = False - try: - optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=False) - optimizer.load_state_dict(optimizer_state_dict) - optimizer_loaded = True - logging.info(f"Successfully loaded optimizer state from step {latest_step}") - except Exception as e: - logging.warning(f"Failed to load optimizer state from step {latest_step}: {e}") - logging.warning("Optimizer state corrupted. Will continue with fresh optimizer state.") - # Reset optimizer to fresh state - for param_group in optimizer.param_groups: - param_group['lr'] = param_group.get('lr', 1e-4) # Use default LR or current LR - optimizer.zero_grad() - optimizer_loaded = False + optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=False) + optimizer.load_state_dict(optimizer_state_dict) # Load EMA state if available - ema_loaded = False if ema_model is not None and (ckpt_dir / "ema_model.pt").exists(): - try: - ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", map_location=device, weights_only=False) - ema_model.load_state_dict(ema_state_dict) - ema_loaded = True - logging.info(f"Successfully loaded EMA state from step {latest_step}") - except Exception as e: - logging.warning(f"Failed to load EMA state from step {latest_step}: {e}") - logging.warning("EMA state corrupted. 
Will continue without EMA.") - ema_loaded = False + ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", map_location=device, weights_only=False) + ema_model.load_state_dict(ema_state_dict) + logging.info(f"Successfully loaded EMA state from step {latest_step}") # Load metadata (weights_only=False needed for older checkpoints that might contain JAX/Flax objects) - try: - metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) - global_step = metadata.get("global_step", latest_step) - logging.info(f"Successfully loaded metadata from step {latest_step}") - return global_step - except Exception as e: - logging.warning(f"Failed to load metadata from checkpoint: {e}") - logging.warning("Using checkpoint step number as global step") - return latest_step + metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) + global_step = metadata.get("global_step", latest_step) + logging.info(f"Successfully loaded metadata from step {latest_step}") + return global_step def get_latest_checkpoint_step(checkpoint_dir): @@ -414,97 +322,6 @@ def find_latest_valid_checkpoint(checkpoint_dir): return None -def cleanup_corrupted_checkpoints(checkpoint_dir, keep_last_n=3): - """Clean up corrupted checkpoints, keeping only the last N valid ones.""" - checkpoint_steps = [] - for d in checkpoint_dir.iterdir(): - if d.is_dir() and d.name.isdigit(): - checkpoint_steps.append(int(d.name)) - - if not checkpoint_steps: - return - - # Sort steps in descending order - checkpoint_steps.sort(reverse=True) - - valid_checkpoints = [] - corrupted_checkpoints = [] - - # Validate all checkpoints - for step in checkpoint_steps: - if validate_checkpoint_integrity(checkpoint_dir, step): - valid_checkpoints.append(step) - else: - corrupted_checkpoints.append(step) - - # Keep only the last N valid checkpoints - checkpoints_to_keep = valid_checkpoints[:keep_last_n] - checkpoints_to_remove = valid_checkpoints[keep_last_n:] + corrupted_checkpoints - 
- # Remove old valid checkpoints and all corrupted ones - for step in checkpoints_to_remove: - checkpoint_path = checkpoint_dir / f"{step}" - try: - import shutil - shutil.rmtree(checkpoint_path) - logging.info(f"Removed checkpoint at step {step}") - except Exception as e: - logging.warning(f"Failed to remove checkpoint at step {step}: {e}") - - logging.info(f"Checkpoint cleanup complete. Kept {len(checkpoints_to_keep)} valid checkpoints: {checkpoints_to_keep}") - - -def debug_unused_parameters(model, device): - """Debug function to identify unused parameters in the model.""" - if isinstance(model, DDP): - model = model.module - - logging.info("Checking for potentially unused parameters...") - - # Get all parameter names and their indices - param_info = {} - idx = 0 - for name, param in model.named_parameters(): - if param.requires_grad: - param_info[idx] = name - idx += 1 - - logging.info(f"Total trainable parameters: {len(param_info)}") - - # Check which parameters have gradients after a forward pass - # This is a diagnostic function that can be called if needed - return param_info - - -def check_model_parameters(model, device): - """Check for unused parameters and provide debugging information.""" - if isinstance(model, DDP): - model = model.module - - total_params = 0 - used_params = 0 - - for name, param in model.named_parameters(): - total_params += param.numel() - if param.requires_grad: - used_params += param.numel() - - logging.info(f"Model parameters: {total_params:,} total, {used_params:,} trainable") - - # Check for parameters that might be unused - unused_params = [] - for name, param in model.named_parameters(): - if param.requires_grad and param.grad is None: - unused_params.append(name) - - if unused_params: - logging.warning(f"Found {len(unused_params)} parameters that might be unused:") - for name in unused_params: # Show first 10 - logging.warning(f" - {name}") - # if len(unused_params) > 10: - # logging.warning(f" ... 
and {len(unused_params) - 10} more") - - def log_memory_usage(device, step, phase="unknown"): """Log detailed memory usage information.""" if not torch.cuda.is_available(): @@ -528,44 +345,11 @@ def log_memory_usage(device, step, phase="unknown"): logging.info(f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}") -def setup_memory_optimizations(model, device, enable_gradient_checkpointing=False): - """Setup memory optimization techniques for the model.""" - # Set memory optimization environment variables - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - os.environ["CUDA_LAUNCH_BLOCKING"] = "0" - - if enable_gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'): - model.gradient_checkpointing_enable() - logging.info("Enabled gradient checkpointing for memory optimization") - - # Enable memory efficient attention if available - if hasattr(model, 'config') and hasattr(model.config, 'attention_mode'): - model.config.attention_mode = 'flash_attention_2' - logging.info("Enabled Flash Attention 2 for memory efficiency") - - # Set memory efficient settings - if torch.cuda.is_available(): - # Enable memory efficient algorithms - torch.backends.cudnn.benchmark = False # Disable for memory efficiency - torch.backends.cudnn.deterministic = True # Enable for memory efficiency - - # Set memory fraction if needed - if device.index is not None: - torch.cuda.empty_cache() - logging.info(f"Cleared CUDA cache for device {device.index}") - - logging.info("Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to reduce memory fragmentation") - - -def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_interval: int = None, gradient_accumulation_steps: int = 1, mixed_precision: bool = True, max_memory_usage: float = None, 
enable_gradient_checkpointing: bool = False, cleanup_checkpoints: bool = False): +def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradient_checkpointing: bool = False): use_ddp, local_rank, device = setup_ddp() is_main = (not use_ddp) or (dist.get_rank() == 0) set_seed(config.seed, local_rank) - # Memory optimization: Set memory fraction if specified - if max_memory_usage is not None and torch.cuda.is_available(): - torch.cuda.set_per_process_memory_fraction(max_memory_usage / torch.cuda.get_device_properties(device).total_memory * 1e-9) - # Initialize checkpoint directory and wandb resuming = False if resume: @@ -577,17 +361,11 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte if latest_step is not None: resuming = True logging.info(f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}") - - # Clean up corrupted checkpoints if requested - if cleanup_checkpoints and is_main: - logging.info("Cleaning up corrupted checkpoints...") - cleanup_corrupted_checkpoints(exp_checkpoint_dir, keep_last_n=3) else: raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") else: raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") elif config.overwrite and config.checkpoint_dir.exists(): - import shutil shutil.rmtree(config.checkpoint_dir) logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") @@ -609,16 +387,12 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte dataset, data_conf = build_datasets(config) sampler = None if use_ddp: - sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True, drop_last=True) - - # Reduce batch size for gradient accumulation - effective_batch_size = config.batch_size // (dist.get_world_size() if use_ddp else 1) + sampler = 
torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=torch.distributed.get_world_size(), rank=torch.distributed.get_rank(), shuffle=True, drop_last=True) - # Memory-efficient data loading with reduced pin_memory for large datasets - pin_memory = False # Disable pin_memory to reduce memory usage - logging.info("Disabled pin_memory to reduce memory usage") + # Use full batch size since we removed gradient accumulation + effective_batch_size = config.batch_size // (torch.distributed.get_world_size() if use_ddp else 1) - loader = TorchDataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=pin_memory, drop_last=True, collate_fn=collate_to_numpy) + loader = torch.utils.data.DataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=True, drop_last=True, collate_fn=collate_to_numpy) # Log sample images to wandb on first batch if is_main and config.wandb_enabled and not resuming: @@ -638,18 +412,15 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte wandb.log({"camera_views": images_to_log}, step=0) # Clear sample batch from memory - del sample_batch, images_to_log torch.cuda.empty_cache() if torch.cuda.is_available() else None # Reset the loader iterator - loader = TorchDataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=pin_memory, drop_last=True, collate_fn=collate_to_numpy) - - # Test gradient checkpointing with a small forward pass (moved to after model creation) + loader = torch.utils.data.DataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=True, drop_last=True, collate_fn=collate_to_numpy) # Build model - if not isinstance(config.model, Pi0Config): + if not isinstance(config.model, 
openpi.models.pi0_config.Pi0Config): # Convert dataclass to Pi0Config if needed - model_cfg = Pi0Config( + model_cfg = openpi.models.pi0_config.Pi0Config( action_dim=config.model.action_dim, action_horizon=config.model.action_horizon, max_token_len=config.model.max_token_len, @@ -660,86 +431,19 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte else: model_cfg = config.model - model = PI0Pytorch(model_cfg).to(device) + model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device) - # Apply memory optimizations - setup_memory_optimizations(model, device, enable_gradient_checkpointing) + if enable_gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'): + model.gradient_checkpointing_enable() + logging.info("Enabled gradient checkpointing for memory optimization") # Log initial memory usage after model creation if is_main and torch.cuda.is_available(): log_memory_usage(device, 0, "after_model_creation") - # Log gradient checkpointing status if enabled - if enable_gradient_checkpointing and is_main: - if hasattr(model, 'get_gradient_checkpointing_status'): - status = model.get_gradient_checkpointing_status() - logging.info(f"Gradient checkpointing status: {status}") - - # Verify that gradient checkpointing is actually enabled - if hasattr(model, 'is_gradient_checkpointing_enabled'): - is_enabled = model.is_gradient_checkpointing_enabled() - logging.info(f"Gradient checkpointing is enabled: {is_enabled}") - - # Check if we're in training mode - logging.info(f"Model training mode: {model.training}") - - # Verify the underlying models have gradient checkpointing enabled - if hasattr(model, 'paligemma_with_expert'): - if hasattr(model.paligemma_with_expert, 'paligemma'): - if hasattr(model.paligemma_with_expert.paligemma, 'language_model'): - paligemma_gc = getattr(model.paligemma_with_expert.paligemma.language_model, 'gradient_checkpointing', False) - logging.info(f"PaliGemma language model gradient 
checkpointing: {paligemma_gc}") - - if hasattr(model.paligemma_with_expert.paligemma, 'vision_tower'): - vision_gc = getattr(model.paligemma_with_expert.paligemma.vision_tower, 'gradient_checkpointing', False) - logging.info(f"PaliGemma vision tower gradient checkpointing: {vision_gc}") - - if hasattr(model.paligemma_with_expert, 'gemma_expert'): - if hasattr(model.paligemma_with_expert.gemma_expert, 'model'): - gemma_gc = getattr(model.paligemma_with_expert.gemma_expert.model, 'gradient_checkpointing', False) - logging.info(f"Gemma expert model gradient checkpointing: {gemma_gc}") - else: - logging.info("Gradient checkpointing enabled but status check not available") - - # Test gradient checkpointing with a small forward pass - if is_main and enable_gradient_checkpointing: - logging.info("Testing gradient checkpointing with a small forward pass...") - try: - # Create a small test batch - test_batch = next(iter(loader)) - test_batch = batch_to_torch(test_batch, device) - test_actions = test_batch["actions"] - - # Record memory before forward pass - if torch.cuda.is_available(): - memory_before = torch.cuda.memory_allocated(device) / 1e9 - logging.info(f"Memory before test forward pass: {memory_before:.2f}GB") - - # Do a test forward pass - with torch.no_grad(): - test_observation = _model.Observation.from_dict(test_batch) - test_losses = model(test_observation, test_actions) - - # Record memory after forward pass - if torch.cuda.is_available(): - memory_after = torch.cuda.memory_allocated(device) / 1e9 - logging.info(f"Memory after test forward pass: {memory_after:.2f}GB") - logging.info(f"Memory difference: {memory_after - memory_before:.2f}GB") - - # Clear test data - del test_batch, test_actions, test_observation, test_losses - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - logging.info("Gradient checkpointing test completed successfully") - except Exception as e: - logging.warning(f"Gradient checkpointing test failed: {e}") - 
logging.warning("Continuing with training...") - if use_ddp: # Enable unused parameter detection to handle cases where some parameters don't participate in loss - model = DDP(model, device_ids=[device.index] if device.type == "cuda" else None, find_unused_parameters=True) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device.index] if device.type == "cuda" else None, find_unused_parameters=True) # Load weights from weight_loader if specified (for fine-tuning) if isinstance(config.weight_loader, str): @@ -747,8 +451,7 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte logging.info(f"Loading weights from: {weight_path}") model_path = os.path.join(weight_path, "model.safetensors") - from safetensors.torch import load_model - load_model((model.module if isinstance(model, DDP) else model), model_path) + safetensors.torch.load_model((model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path) logging.info(f"Loaded PyTorch weights from {weight_path}") # Optimizer + learning rate schedule from config @@ -769,20 +472,15 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, ckpt_save_inte # Initialize EMA if specified in config ema_model = None if config.ema_decay is not None: - try: - ema_model = PI0Pytorch(model_cfg).to(device) - - # Get the correct state dict from the main model - main_model_state_dict = get_model_state_dict(model) - - # Load the state dict into EMA model - ema_model.load_state_dict(main_model_state_dict) - ema_model.eval() - logging.info(f"Initialized EMA with decay {config.ema_decay}") - except Exception as e: - logging.error(f"Failed to initialize EMA model: {e}") - logging.error("Continuing without EMA...") - ema_model = None + ema_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device) + + # Get the correct state dict from the main model + main_model_state_dict = get_model_state_dict(model) + + # Load the state dict 
into EMA model + ema_model.load_state_dict(main_model_state_dict) + ema_model.eval() + logging.info(f"Initialized EMA with decay {config.ema_decay}") # Load checkpoint if resuming global_step = 0 @@ -798,37 +496,20 @@ def lr_schedule(step: int): cos = 0.5 * (1 + np.cos(np.pi * progress)) return end_lr + (peak_lr - end_lr) * cos - # Enable mixed precision training for memory optimization - scaler = torch.amp.GradScaler(enabled=mixed_precision and torch.cuda.is_available()) - - # Set memory efficient settings - if torch.cuda.is_available(): - # Enable memory efficient algorithms - torch.backends.cudnn.benchmark = False # Disable for memory efficiency - torch.backends.cudnn.deterministic = True # Enable for memory efficiency - - # Set memory fraction if needed - if device.index is not None: - torch.cuda.empty_cache() - logging.info(f"Cleared CUDA cache for device {device.index}") - model.train() start_time = time.time() infos = [] # Collect stats over log interval if is_main: - logging.info(f"Running on: {platform.node()} | world_size={dist.get_world_size() if use_ddp else 1}") + logging.info(f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}") logging.info(f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}") - logging.info(f"Memory optimizations: gradient_accumulation_steps={gradient_accumulation_steps}, mixed_precision={mixed_precision}, gradient_checkpointing={enable_gradient_checkpointing}") + logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") logging.info(f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}") logging.info(f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}") if config.ema_decay is not None: logging.info(f"EMA decay: 
{config.ema_decay}") # Training loop - iterate until we reach num_train_steps - pbar = tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) if is_main else None - - # Check model parameters after first few steps when gradients are available - parameters_checked = False + pbar = tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) if is_main else None while global_step < config.num_train_steps: if use_ddp: @@ -847,90 +528,58 @@ def lr_schedule(step: int): for pg in optim.param_groups: pg["lr"] = lr_schedule(global_step) - # Forward pass with mixed precision + # Forward pass observation = _model.Observation.from_dict(batch) - try: - with torch.amp.autocast('cuda', enabled=mixed_precision and torch.cuda.is_available()): - losses = model(observation, actions) - # Ensure losses is a tensor and handle different return types - if isinstance(losses, (list, tuple)): - losses = torch.stack(losses) - elif not isinstance(losses, torch.Tensor): - losses = torch.tensor(losses, device=device, dtype=torch.float32) - - loss = losses.mean() / gradient_accumulation_steps # Scale loss for gradient accumulation - - # Debug gradient checkpointing on first few steps - if global_step < 5 and is_main: - if hasattr(model, 'is_gradient_checkpointing_enabled'): - gc_enabled = model.is_gradient_checkpointing_enabled() - logging.info(f"Step {global_step}: Gradient checkpointing enabled: {gc_enabled}") - if torch.cuda.is_available(): - log_memory_usage(device, global_step, "after_forward") - except RuntimeError as e: - if "Expected to have finished reduction" in str(e) or "did not receive grad" in str(e): - logging.error(f"DDP error on rank {dist.get_rank() if use_ddp else 0}: {e}") - logging.error("This usually indicates unused parameters in the model.") - logging.error("Try setting TORCH_DISTRIBUTED_DEBUG=DETAIL for more information.") - raise - else: - raise - - # Backward pass with gradient scaling - 
scaler.scale(loss).backward() + losses = model(observation, actions) + # Ensure losses is a tensor and handle different return types + if isinstance(losses, (list, tuple)): + losses = torch.stack(losses) + elif not isinstance(losses, torch.Tensor): + losses = torch.tensor(losses, device=device, dtype=torch.float32) - # Aggressive memory cleanup after backward pass - if torch.cuda.is_available(): - # Clear intermediate activations that might still be in memory - torch.cuda.empty_cache() - gc.collect() - - # Log memory usage after backward pass for debugging - if global_step < 5 and is_main: + loss = losses.mean() + + # Backward pass + loss.backward() + + # Log memory usage after backward pass + if global_step < 5 and is_main: + if torch.cuda.is_available(): log_memory_usage(device, global_step, "after_backward") - # Gradient accumulation logic - if (global_step + 1) % gradient_accumulation_steps == 0: - # Unscale gradients for clipping - scaler.unscale_(optim) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) - - # Optimizer step - scaler.step(optim) - scaler.update() - optim.zero_grad(set_to_none=True) - - # Clear gradients more aggressively - for param in model.parameters(): - if param.grad is not None: - param.grad.detach_() - param.grad = None - - # Update EMA if enabled - if ema_model is not None: - try: - with torch.no_grad(): - # Get parameters from the correct model structure - main_model_params = get_model_parameters(model) - for param, ema_param in zip(main_model_params, ema_model.parameters()): - ema_param.data.mul_(config.ema_decay).add_(param.data, alpha=1 - config.ema_decay) - except Exception as e: - logging.warning(f"Failed to update EMA model: {e}") - # Continue training without EMA update - - # # Check model parameters after first few steps when gradients are available - # if not parameters_checked and global_step >= 16510 and is_main: - # check_model_parameters(model, device) - # parameters_checked 
= True - - # Collect stats (only on accumulation steps) - if (global_step + 1) % gradient_accumulation_steps == 0 and is_main: + # Gradient clipping + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) + + # Optimizer step + optim.step() + optim.zero_grad(set_to_none=True) + + # Clear gradients more aggressively + for param in model.parameters(): + if param.grad is not None: + param.grad.detach_() + param.grad = None + + # Update EMA if enabled + if ema_model is not None: + try: + with torch.no_grad(): + # Get parameters from the correct model structure + main_model_params = get_model_parameters(model) + for param, ema_param in zip(main_model_params, ema_model.parameters()): + ema_param.data.mul_(config.ema_decay).add_(param.data, alpha=1 - config.ema_decay) + except Exception as e: + logging.warning(f"Failed to update EMA model: {e}") + # Continue training without EMA update + + # Collect stats + if is_main: infos.append({ - "loss": loss.item() * gradient_accumulation_steps, # Unscale for logging + "loss": loss.item(), "learning_rate": optim.param_groups[0]['lr'], }) - if is_main and (global_step % config.log_interval == 0) and (global_step + 1) % gradient_accumulation_steps == 0: + if is_main and (global_step % config.log_interval == 0): elapsed = time.time() - start_time # Average stats over log interval @@ -952,7 +601,7 @@ def lr_schedule(step: int): infos = [] # Reset stats collection # Save checkpoint using the new mechanism - save_checkpoint(model, optim, global_step, config, is_main, ckpt_save_interval, ema_model) + save_checkpoint(model, optim, global_step, config, is_main, ema_model) global_step += 1 @@ -960,25 +609,10 @@ def lr_schedule(step: int): if pbar is not None: pbar.update(1) pbar.set_postfix({ - 'loss': f'{loss.item() * gradient_accumulation_steps:.4f}', + 'loss': f'{loss.item():.4f}', 'lr': f'{optim.param_groups[0]["lr"]:.2e}', 'step': global_step }) - - # Memory cleanup after each batch - del 
batch, actions, observation, losses, loss - - # More aggressive memory cleanup - if torch.cuda.is_available(): - torch.cuda.empty_cache() - # Force garbage collection - gc.collect() - - # Log memory usage for debugging gradient checkpointing - if is_main and global_step % 100 == 0: - memory_allocated = torch.cuda.memory_allocated(device) / 1e9 - memory_reserved = torch.cuda.memory_reserved(device) / 1e9 - logging.info(f"Step {global_step}: GPU memory allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB") # Close progress bar if pbar is not None: @@ -996,43 +630,17 @@ def main(): config = _config.cli() # Parse additional command line arguments for memory optimization - import argparse parser = argparse.ArgumentParser(add_help=False) parser.add_argument("--resume", action="store_true", default=False, help="Resume training from the latest checkpoint for the experiment (handles both PyTorch and JAX checkpoints)") - parser.add_argument("--cleanup_checkpoints", action="store_true", default=False, - help="Clean up corrupted checkpoints during resume (keeps last 3 valid checkpoints)") - parser.add_argument("--ckpt_save_interval", type=int, default=None, - help="Interval for saving checkpoints (overrides config.save_interval)") - parser.add_argument("--gradient_accumulation_steps", type=int, default=1, - help="Number of steps to accumulate gradients (default: 1)") - parser.add_argument("--mixed_precision", action="store_true", default=False, - help="Enable mixed precision training (default: True)") - parser.add_argument("--no_mixed_precision", action="store_true", default=True, - help="Disable mixed precision training") - parser.add_argument("--max_memory_usage", type=float, default=None, - help="Maximum GPU memory usage in GB (default: None, auto-detect)") - parser.add_argument("--gradckpt", action="store_true", default=False, + + parser.add_argument("--enable_gradient_checkpointing", action="store_true", default=True, help="Enable gradient checkpointing 
for memory optimization") - parser.add_argument("--ddp_debug_level", type=str, default="INFO", choices=["INFO", "DETAIL", "OFF"], - help="DDP debugging level (default: INFO)") args, _ = parser.parse_known_args() - - # Handle mixed precision flag - mixed_precision = args.mixed_precision and not args.no_mixed_precision - - # Set DDP debug level - if args.ddp_debug_level != "OFF": - os.environ["TORCH_DISTRIBUTED_DEBUG"] = args.ddp_debug_level train_loop(config, resume=args.resume, - ckpt_save_interval=args.ckpt_save_interval, - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=mixed_precision, - max_memory_usage=args.max_memory_usage, - enable_gradient_checkpointing=True, - cleanup_checkpoints=args.cleanup_checkpoints) + enable_gradient_checkpointing=args.enable_gradient_checkpointing) if __name__ == "__main__": diff --git a/scripts/train_single_example.py b/scripts/train_single_example.py index 9418342..a2e90bb 100644 --- a/scripts/train_single_example.py +++ b/scripts/train_single_example.py @@ -1,9 +1,39 @@ +#!/usr/bin/env python3 """ Train on a single example for debugging JAX vs PyTorch comparison. + This script creates a deterministic dataset with one example and trains on it to help debug differences between JAX and PyTorch implementations. 
+ +Usage examples: + +# Test pi05_droid model +python scripts/train_single_example.py \ + --model_name pi05_droid \ + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid_pytorch + +# Test pi0_droid model +python scripts/train_single_example.py \ + --model_name pi0_droid \ + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch + +# Test pi0_aloha_sim model +python scripts/train_single_example.py \ + --model_name pi0_aloha_sim \ + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch + +This script: +- Creates a fixed example with deterministic random values +- Uses the same noise and time values for both JAX and PyTorch +- Disables preprocessing for fair comparison +- Compares losses between implementations +- Provides detailed analysis of differences """ +import argparse import logging import numpy as np import torch @@ -11,11 +41,13 @@ import jax.numpy as jnp import flax.nnx as nnx import flax +import safetensors from unittest.mock import patch from openpi.models import model as _model from openpi.models.pi0_config import Pi0Config from openpi.models_pytorch.pi0_pytorch import PI0Pytorch +import openpi.training.config def setup_logging(): @@ -26,15 +58,15 @@ def setup_logging(): ) -def create_fixed_example(): +def create_fixed_example(model_config): """Create a fixed example for debugging.""" np.random.seed(42) batch_size = 1 - action_dim = 32 - action_horizon = 10 + action_dim = model_config.action_dim + action_horizon = model_config.action_horizon image_size = 224 - max_token_len = 48 + max_token_len = model_config.max_token_len # Create fixed images images = {} @@ 
-94,23 +126,22 @@ def mock_preprocess_observation_pytorch(observation, **kwargs): return observation -def test_pytorch_single_example(noise, time): +def test_pytorch_single_example(noise, time, model_name, pytorch_checkpoint_dir): """Test PyTorch training on single example.""" print("\n=== Testing PyTorch on Single Example ===") - # Create model - config = Pi0Config(action_dim=32, action_horizon=10, pi05=True) - model = PI0Pytorch(config) + # Create model using the training config + train_config = openpi.training.config.get_config(model_name) + model = PI0Pytorch(train_config.model) # Use train_config.model instead of train_config # Load pre-trained weights - weight_path = "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2/model.safetensors" - print(f"Loading PyTorch weights from: {weight_path}") + pytorch_checkpoint_dir = pytorch_checkpoint_dir + "/model.safetensors" + print(f"Loading PyTorch weights from: {pytorch_checkpoint_dir}") - from safetensors.torch import load_model - load_model(model, weight_path) + safetensors.torch.load_model(model, pytorch_checkpoint_dir) # Create fixed example - example = create_fixed_example() + example = create_fixed_example(train_config.model) # Convert to PyTorch tensors pytorch_example = {} @@ -146,213 +177,40 @@ def test_pytorch_single_example(noise, time): # Test forward pass with fixed noise and time model.eval() with torch.no_grad(): - #try: - # Use mock to disable preprocessing - with patch('openpi.models.model.preprocess_observation_pytorch', side_effect=mock_preprocess_observation_pytorch): - losses = model(observation, actions, noise=noise_tensor, time=time_tensor) - print(f"PyTorch forward pass successful!") - print(f"Losses shape: {losses.shape}") - print(f"Losses dtype: {losses.dtype}") - mean_loss = losses.to(torch.float32).mean().item() - print(f"Mean loss: {mean_loss:.6f}") - return True, losses - # except Exception as e: - # print(f"PyTorch forward pass failed: {e}") - # return 
False, None - - -def test_jax_single_example(noise, time, debug_single_layer=False): + try: + # Use mock to disable preprocessing + with patch('openpi.models.model.preprocess_observation_pytorch', side_effect=mock_preprocess_observation_pytorch): + losses = model(observation, actions, noise=noise_tensor, time=time_tensor) + print(f"PyTorch forward pass successful!") + print(f"Losses shape: {losses.shape}") + print(f"Losses dtype: {losses.dtype}") + mean_loss = losses.to(torch.float32).mean().item() + print(f"Mean loss: {mean_loss:.6f}") + return True, losses + except Exception as e: + print(f"PyTorch forward pass failed: {e}") + return False, None + + +def test_jax_single_example(noise, time, model_name, jax_checkpoint_dir): """Test JAX training on single example.""" print("\n=== Testing JAX on Single Example ===") - # Create model - config = Pi0Config(action_dim=32, action_horizon=10, pi05=True) - if debug_single_layer: - print("šŸ”§ Debug mode: Using only 1 encoder layer") - - # Create a custom model with modified siglip depth for debugging - if debug_single_layer: - # Import the Pi0 model class - from openpi.models.pi0 import Pi0 - import openpi.models.gemma as _gemma - import openpi.models.siglip as _siglip - import flax.nnx.bridge as nnx_bridge - - # Create the model manually with custom siglip variant - rng = jax.random.key(42) - rngs = flax.nnx.Rngs(rng) - - paligemma_config = _gemma.get_config(config.paligemma_variant) - action_expert_config = _gemma.get_config(config.action_expert_variant) - - # Create LLM - llm = nnx_bridge.ToNNX( - _gemma.Module( - configs=[paligemma_config, action_expert_config], - embed_dtype=config.dtype, - adarms=config.pi05, - ) - ) - llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False]) - - # Create custom siglip model with depth=1 - # We'll use the same variant but override the depth parameter - siglip_params = _siglip.decode_variant("So400m/14") - siglip_params["depth"] = 1 # 
Override depth to 1 for debugging - - img = nnx_bridge.ToNNX( - _siglip.Module( - num_classes=paligemma_config.width, - variant=None, # Don't use variant, use explicit params - pool_type="none", - scan=False, # Disable scan for single layer - dtype_mm=config.dtype, - **siglip_params, # Pass the modified parameters - ) - ) - img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs) - - # Create the full model - model = Pi0(config, rngs) - # Replace the siglip model with our custom one - model.PaliGemma.img = img - - print("šŸ”§ Created single-layer SigLIP model (depth=1) for debugging...") - else: - rng = jax.random.key(42) - model = config.create(rng) + # Create model using the training config + train_config = openpi.training.config.get_config(model_name) + rng = jax.random.key(42) + model = train_config.model.create(rng) # Use train_config.model instead of config.create # Load pre-trained weights - weight_path = "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" - print(f"Loading JAX weights from: {weight_path}") - - # try: - # Use the same approach as in policy_config.py - params = _model.restore_params(weight_path, dtype=jnp.bfloat16) - - # Filter params to only include the first encoder layer for debugging - if debug_single_layer: - filtered_params = {} - - # The parameters are nested, so we need to traverse the structure - def filter_nested_params(params_dict, key_path=""): - result = {} - for key, value in params_dict.items(): - current_path = f"{key_path}.{key}" if key_path else key - - if isinstance(value, dict): - # Recursive case - traverse deeper - filtered_sub = filter_nested_params(value, current_path) - if filtered_sub: # Only include if there are sub-parameters - result[key] = filtered_sub - else: - # Leaf case - check if this parameter should be included - if 'Transformer' in current_path: - # Only keep the first encoder block (encoderblock_0) and encoder_norm - if 'encoderblock_0' in 
current_path or 'encoder_norm' in current_path: - result[key] = value - else: - # Keep all non-Transformer params - result[key] = value - return result - - filtered_params = filter_nested_params(params) - params_to_use = filtered_params - print("āœ… JAX weights loaded successfully (first layer only)!") - print("āš ļø Note: Using So400m variant with depth=1 (modified from depth=27)") - print("āš ļø Only the first layer weights will be used, others will be randomly initialized") - - # Debug: Show what parameters we have - print(f"šŸ“‹ Available parameters for single-layer model:") - transformer_params = [] - for key in sorted(params_to_use.keys()): - if 'Transformer' in key: - transformer_params.append(key) - print(f" {key}: {params_to_use[key].shape}") - - if not transformer_params: - print(" No Transformer parameters found! Let's see all keys:") - for key in sorted(params_to_use.keys())[:20]: # Show first 20 keys - print(f" {key}") - - # Let's also check if PaliGemma has nested structure - if 'PaliGemma' in params_to_use: - print(" Checking PaliGemma structure:") - paligemma_params = params_to_use['PaliGemma'] - if hasattr(paligemma_params, 'keys'): - for subkey in sorted(paligemma_params.keys()): - print(f" PaliGemma.{subkey}") - if hasattr(paligemma_params[subkey], 'keys'): - for subsubkey in sorted(paligemma_params[subkey].keys()): - print(f" PaliGemma.{subkey}.{subsubkey}") - if hasattr(paligemma_params[subkey][subsubkey], 'keys'): - for subsubsubkey in sorted(paligemma_params[subkey][subsubkey].keys()): - if 'Transformer' in subsubsubkey: - print(f" PaliGemma.{subkey}.{subsubkey}.{subsubsubkey}") - - # The issue is that with scan=False, the model expects different parameter names - # We need to map from encoderblock_0 to encoderblock in the nested structure - def adapt_nested_params(params_dict, key_path=""): - result = {} - for key, value in params_dict.items(): - current_path = f"{key_path}.{key}" if key_path else key - - if isinstance(value, dict): - # 
Recursive case - traverse deeper - result[key] = adapt_nested_params(value, current_path) - else: - # Leaf case - adapt the key if needed - new_key = key - if 'Transformer' in current_path and 'encoderblock_0' in key: - # Map encoderblock_0 to encoderblock for non-scan mode - new_key = key.replace('encoderblock_0', 'encoderblock') - result[new_key] = value - return result - - adapted_params = adapt_nested_params(params_to_use) - params_to_use = adapted_params - print("šŸ”„ Adapted parameter names for non-scan mode") - print(f" Example mapping: encoderblock_0 -> encoderblock") - else: - params_to_use = params - print("āœ… JAX weights loaded successfully!") + print(f"Loading JAX weights from: {jax_checkpoint_dir}") + params = _model.restore_params(jax_checkpoint_dir, dtype=jnp.bfloat16) + params_to_use = params + print("āœ… JAX weights loaded successfully!") # Apply the params to the model using NNX state management import flax.nnx as nnx graphdef, model_state = nnx.split(model) - # Debug: Let me check what the model actually expects first - print(f"šŸ” Checking what the model expects...") - try: - print(f"šŸ“‹ Model parameter structure:") - model_transformer_params = [] - for key in sorted(model_state.keys()): - if 'Transformer' in key: - model_transformer_params.append(key) - print(f" {key}: shape {getattr(model_state[key], 'shape', 'no shape')}") - - if not model_transformer_params: - print(" No Transformer parameters found in model! 
Let's see all keys:") - for key in sorted(model_state.keys())[:20]: # Show first 20 keys - print(f" {key}") - - # Let's also check if PaliGemma has nested structure in model - if 'PaliGemma' in model_state: - print(" Checking PaliGemma structure in model:") - paligemma_state = model_state['PaliGemma'] - if hasattr(paligemma_state, 'keys'): - for subkey in sorted(paligemma_state.keys()): - print(f" PaliGemma.{subkey}") - if hasattr(paligemma_state[subkey], 'keys'): - for subsubkey in sorted(paligemma_state[subkey].keys()): - print(f" PaliGemma.{subkey}.{subsubkey}") - if hasattr(paligemma_state[subkey][subsubkey], 'keys'): - for subsubsubkey in sorted(paligemma_state[subkey][subsubkey].keys()): - if 'Transformer' in subsubsubkey: - print(f" PaliGemma.{subkey}.{subsubkey}.{subsubsubkey}") - except Exception as e: - print(f" Could not inspect model parameters: {e}") - # Now try to load parameters try: model_state.replace_by_pure_dict(params_to_use) @@ -362,12 +220,9 @@ def adapt_nested_params(params_dict, key_path=""): print(f"āŒ Parameter loading failed: {e}") print("šŸ”„ Continuing with random initialization...") model = nnx.merge(graphdef, model_state) - # except Exception as e: - # print(f"āŒ Failed to load JAX weights: {e}") - # print("Continuing with random initialization...") # Create fixed example - example = create_fixed_example() + example = create_fixed_example(train_config.model) # Convert to JAX arrays jax_example = {} @@ -395,20 +250,20 @@ def adapt_nested_params(params_dict, key_path=""): print(f"Time shape: {time_jax.shape}, dtype: {time_jax.dtype}") # Test forward pass with fixed noise and time - # try: - # Use the modified compute_loss method that accepts external noise and time - # Use mock to disable preprocessing - with patch('openpi.models.model.preprocess_observation', side_effect=mock_preprocess_observation): - losses = model.compute_loss(rng, observation, actions, train=False, noise=noise_jax, time=time_jax) - print(f"JAX forward pass 
successful!") - print(f"Losses shape: {losses.shape}") - print(f"Losses dtype: {losses.dtype}") - mean_loss = jnp.mean(losses).item() - print(f"Mean loss: {mean_loss:.6f}") - return True, losses - # except Exception as e: - # print(f"JAX forward pass failed: {e}") - # return False, None + try: + # Use the modified compute_loss method that accepts external noise and time + # Use mock to disable preprocessing + with patch('openpi.models.model.preprocess_observation', side_effect=mock_preprocess_observation): + losses = model.compute_loss(rng, observation, actions, train=False, noise=noise_jax, time=time_jax) + print(f"JAX forward pass successful!") + print(f"Losses shape: {losses.shape}") + print(f"Losses dtype: {losses.dtype}") + mean_loss = jnp.mean(losses).item() + print(f"Mean loss: {mean_loss:.6f}") + return True, losses + except Exception as e: + print(f"JAX forward pass failed: {e}") + return False, None def compare_losses(pytorch_loss, jax_loss): @@ -423,29 +278,6 @@ def compare_losses(pytorch_loss, jax_loss): print(f"PyTorch loss: {pytorch_loss}") print(f"JAX loss: {jax_loss}") - # # Handle tensor inputs by computing mean if needed - # if hasattr(pytorch_loss, 'mean'): - # pytorch_mean = pytorch_loss.to(torch.float32).mean().item() - # pytorch_std = pytorch_loss.to(torch.float32).std().item() - # print(f"PyTorch loss tensor - Mean: {pytorch_mean:.8f}, Std: {pytorch_std:.8f}") - # print(f"PyTorch loss shape: {pytorch_loss.shape}") - # else: - # pytorch_mean = float(pytorch_loss) - # pytorch_std = 0.0 - # print(f"PyTorch loss scalar: {pytorch_mean:.8f}") - - # if hasattr(jax_loss, 'mean'): - # jax_mean = jax_loss.mean().item() - # jax_std = jax_loss.std().item() - # print(f"JAX loss tensor - Mean: {jax_mean:.8f}, Std: {jax_std:.8f}") - # print(f"JAX loss shape: {jax_loss.shape}") - # else: - # jax_mean = float(jax_loss) - # jax_std = 0.0 - # print(f"JAX loss scalar: {jax_mean:.8f}") - - - # Additional tensor analysis if both are tensors pytorch_loss = 
pytorch_loss.to(torch.float32) jax_loss = jax_loss.astype(jnp.float32) @@ -513,28 +345,43 @@ def compare_losses(pytorch_loss, jax_loss): def main(): """Main function to test both implementations.""" + parser = argparse.ArgumentParser(description="Train on a single example for JAX vs PyTorch comparison") + parser.add_argument("--model_name", type=str, default="pi05_droid", + choices=["pi0_aloha_sim", "pi0_aloha_towel", "pi0_base", "pi05_droid", "pi0_droid", "pi0_libero", "pi05_libero"], + help="Model name to use") + parser.add_argument("--jax_checkpoint_dir", type=str, required=True, + help="Directory containing JAX model checkpoints") + parser.add_argument("--pytorch_checkpoint_dir", type=str, required=True, + help="Directory containing PyTorch model checkpoints") + args = parser.parse_args() + setup_logging() print("šŸš€ Testing Single Example Training for JAX vs PyTorch Comparison") print("=" * 70) - print("šŸ“ Loading pre-trained weights for both models...") + print(f"šŸ“ Model: {args.model_name}") + print(f"šŸ“ JAX checkpoint: {args.jax_checkpoint_dir}") + print(f"šŸ“ PyTorch checkpoint: {args.pytorch_checkpoint_dir}") print("šŸŽÆ Using fixed noise and time values for deterministic comparison...") - print("šŸ”§ Debug mode: JAX model will use only 1 encoder layer for faster debugging...") print("🚫 Preprocessing disabled: Image augmentations and resizing are bypassed for fair comparison...") + # Get model configuration + train_config = openpi.training.config.get_config(args.model_name) + model_config = train_config.model + # Generate fixed noise and time noise, time = create_fixed_noise_and_time( batch_size=1, - action_horizon=10, - action_dim=32 + action_horizon=model_config.action_horizon, + action_dim=model_config.action_dim ) # Test PyTorch - pytorch_success, pytorch_losses = test_pytorch_single_example(noise, time) + pytorch_success, pytorch_losses = test_pytorch_single_example(noise, time, args.model_name, args.pytorch_checkpoint_dir) 
torch.cuda.empty_cache() # Test JAX - jax_success, jax_losses = test_jax_single_example(noise, time, debug_single_layer=False) + jax_success, jax_losses = test_jax_single_example(noise, time, args.model_name, args.jax_checkpoint_dir) # Compare losses if pytorch_success and jax_success: diff --git a/src/openpi/models/model.py b/src/openpi/models/model.py index 2c07776..8784da5 100644 --- a/src/openpi/models/model.py +++ b/src/openpi/models/model.py @@ -208,7 +208,7 @@ def preprocess_observation( ) -def preprocess_observation_pytorch_torch_compile( +def preprocess_observation_pytorch( observation, *, train: bool = False, @@ -364,158 +364,6 @@ def __init__(self, **kwargs): ) -def preprocess_observation_pytorch( - observation: Observation, - *, - train: bool = False, - image_keys: Sequence[str] = IMAGE_KEYS, - image_resolution: tuple[int, int] = IMAGE_RESOLUTION, -) -> Observation: - """PyTorch version of preprocess_observation. Preprocesses observations with PyTorch tensors by performing - image resizing (if necessary) and filling in a default image mask (if necessary). - - Note: Image augmentation is not implemented for PyTorch tensors as augmax is JAX-specific. 
- """ - - if not set(image_keys).issubset(observation.images): - raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") - - batch_shape = observation.state.shape[:-1] - - out_images = {} - for key in image_keys: - image = observation.images[key] - - # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats - # Handle both [B, C, H, W] and [B, H, W, C] formats - is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 - - if is_channels_first: - # Convert [B, C, H, W] to [B, H, W, C] for processing - image = image.permute(0, 2, 3, 1) - - if image.shape[1:3] != image_resolution: - logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") - image = image_tools.resize_with_pad_torch(image, *image_resolution) - - if train: - # Convert from [-1, 1] to [0, 1] for PyTorch augmentations - image = image / 2.0 + 0.5 - - # Apply PyTorch-based augmentations - if "wrist" not in key: - # Geometric augmentations for non-wrist cameras - height, width = image.shape[1:3] - - # Random crop and resize - crop_height = int(height * 0.95) - crop_width = int(width * 0.95) - - # Random crop - max_h = height - crop_height - max_w = width - crop_width - if max_h > 0 and max_w > 0: - # Use tensor operations instead of .item() for torch.compile compatibility - start_h = torch.randint(0, max_h + 1, (1,), device=image.device) - start_w = torch.randint(0, max_w + 1, (1,), device=image.device) - image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :] - - # Resize back to original size - image = torch.nn.functional.interpolate( - image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] - size=(height, width), - mode='bilinear', - align_corners=False - ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - - # Random rotation (small angles) - # Use tensor operations instead of .item() for torch.compile compatibility - angle = torch.rand(1, device=image.device) * 10 
- 5 # Random angle between -5 and 5 degrees - if torch.abs(angle) > 0.1: # Only rotate if angle is significant - # Convert to radians - angle_rad = angle * torch.pi / 180.0 - - # Create rotation matrix - cos_a = torch.cos(angle_rad) - sin_a = torch.sin(angle_rad) - - # Apply rotation using grid_sample - grid_x = torch.linspace(-1, 1, width, device=image.device) - grid_y = torch.linspace(-1, 1, height, device=image.device) - - # Create meshgrid - grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij') - - # Expand to batch dimension - grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) - grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) - - # Apply rotation transformation - grid_x_rot = grid_x * cos_a - grid_y * sin_a - grid_y_rot = grid_x * sin_a + grid_y * cos_a - - # Stack and reshape for grid_sample - grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) - - image = torch.nn.functional.grid_sample( - image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] - grid, - mode='bilinear', - padding_mode='zeros', - align_corners=False - ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - - # Color augmentations for all cameras - # Random brightness - # Use tensor operations instead of .item() for torch.compile compatibility - brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3 - image = image * brightness_factor - - # Random contrast - # Use tensor operations instead of .item() for torch.compile compatibility - contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4 - mean = image.mean(dim=[1, 2, 3], keepdim=True) - image = (image - mean) * contrast_factor + mean - - # Random saturation (convert to HSV, modify S, convert back) - # For simplicity, we'll just apply a random scaling to the color channels - # Use tensor operations instead of .item() for torch.compile compatibility - saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # 
Random factor between 0.5 and 1.5 - gray = image.mean(dim=-1, keepdim=True) - image = gray + (image - gray) * saturation_factor - - # Clamp values to [0, 1] - image = torch.clamp(image, 0, 1) - - # Back to [-1, 1] - image = image * 2.0 - 1.0 - - # Convert back to [B, C, H, W] format if it was originally channels-first - if is_channels_first: - image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] - - out_images[key] = image - - # obtain mask - out_masks = {} - for key in out_images: - if key not in observation.image_masks: - # do not mask by default - out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) - else: - out_masks[key] = observation.image_masks[key] - - return Observation( - images=out_images, - image_masks=out_masks, - state=observation.state, - tokenized_prompt=observation.tokenized_prompt, - tokenized_prompt_mask=observation.tokenized_prompt_mask, - token_ar_mask=observation.token_ar_mask, - token_loss_mask=observation.token_loss_mask, - ) - - @dataclasses.dataclass(frozen=True) class BaseModelConfig(abc.ABC): """Configuration shared by all models. Specific models should inherit from this class, and implement the `create` diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 3d70ca0..bda0dfd 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -7,33 +7,6 @@ from transformers.models.auto import CONFIG_MAPPING -# TODO: compare this rope vs gemma rope -def apply_rope(x, positions, max_wavelength=10_000): - """ - Applies RoPE positions [B, L] to x [B, L, H, D]. 
- """ - d_half = x.shape[-1] // 2 - device = x.device - dtype = x.dtype - x = x.to(torch.float32) - - freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) - timescale = max_wavelength**freq_exponents - radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) - - radians = radians[..., None, :] - - sin = torch.sin(radians) # .to(dtype=dtype) - cos = torch.cos(radians) # .to(dtype=dtype) - - x1, x2 = x.split(d_half, dim=-1) - res = torch.empty_like(x) - res[..., :d_half] = x1 * cos - x2 * sin - res[..., d_half:] = x2 * cos + x1 * sin - - return res.to(dtype) - - class PaliGemmaWithExpertModel(nn.Module): def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): super().__init__() @@ -56,7 +29,6 @@ def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" vlm_config_hf.vision_config.torch_dtype = "float32" - # vlm_config_hf.vision_config._attn_implementation = "flash_attention_2" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 8dff228..29b3045 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -106,7 +106,6 @@ def __init__(self, config): torch.set_float32_matmul_precision('high') self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") - #self.forward = torch.compile(self.forward, mode="reduce-overhead") # Initialize gradient checkpointing flag self.gradient_checkpointing_enabled = False @@ -114,41 +113,18 @@ def __init__(self, config): def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True + 
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True - # Enable gradient checkpointing in the underlying models - if hasattr(self.paligemma_with_expert, 'paligemma'): - if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True - print("Enabled gradient checkpointing in PaliGemma language model") - - # Enable gradient checkpointing in the vision model (SigLIP) - if hasattr(self.paligemma_with_expert.paligemma, 'vision_tower'): - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True - print("Enabled gradient checkpointing in PaliGemma vision tower (SigLIP)") - - if hasattr(self.paligemma_with_expert, 'gemma_expert'): - if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): - self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True - print("Enabled gradient checkpointing in Gemma expert model") - - print("Enabled gradient checkpointing for PI0Pytorch model") + logging.info("Enabled gradient checkpointing for PI0Pytorch model") def gradient_checkpointing_disable(self): """Disable gradient checkpointing.""" self.gradient_checkpointing_enabled = False - - # Disable gradient checkpointing in the underlying models - if hasattr(self.paligemma_with_expert, 'paligemma'): - if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False - - # Disable gradient checkpointing in the vision model (SigLIP) - if hasattr(self.paligemma_with_expert.paligemma, 'vision_tower'): - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False - - if hasattr(self.paligemma_with_expert, 'gemma_expert'): - if hasattr(self.paligemma_with_expert.gemma_expert, 
'model'): - self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI0Pytorch model") @@ -156,39 +132,30 @@ def is_gradient_checkpointing_enabled(self): """Check if gradient checkpointing is enabled.""" return self.gradient_checkpointing_enabled - def get_gradient_checkpointing_status(self): - """Get detailed gradient checkpointing status of underlying models.""" - status = { - 'main_model': self.gradient_checkpointing_enabled, - 'paligemma_language_model': False, - 'paligemma_vision_model': False, - 'gemma_expert_model': False - } - - if hasattr(self.paligemma_with_expert, 'paligemma'): - if hasattr(self.paligemma_with_expert.paligemma, 'language_model'): - status['paligemma_language_model'] = getattr( - self.paligemma_with_expert.paligemma.language_model, - 'gradient_checkpointing', - False - ) - - if hasattr(self.paligemma_with_expert.paligemma, 'vision_tower'): - status['paligemma_vision_model'] = getattr( - self.paligemma_with_expert.paligemma.vision_tower, - 'gradient_checkpointing', - False - ) - - if hasattr(self.paligemma_with_expert, 'gemma_expert'): - if hasattr(self.paligemma_with_expert.gemma_expert, 'model'): - status['gemma_expert_model'] = getattr( - self.paligemma_with_expert.gemma_expert.model, - 'gradient_checkpointing', - False - ) - - return status + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + else: + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, 
att_2d_masks): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) + + def _preprocess_observation(self, observation, train=True): + """Helper method to preprocess observation.""" + observation = _model.preprocess_observation_pytorch(observation, train=train) + return ( + list(observation.images.values()), + list(observation.image_masks.values()), + observation.tokenized_prompt, + observation.tokenized_prompt_mask, + observation.state + ) def sample_noise(self, shape, device): noise = torch.normal( @@ -215,64 +182,29 @@ def embed_prefix( pad_masks = [] att_masks = [] - # Apply gradient checkpointing to image embedding if enabled - if self.gradient_checkpointing_enabled and self.training: - for ( - img, - img_mask, - ) in zip(images, img_masks): - # Use checkpoint for image embedding - def checkpointed_image_embed(img): - return self.paligemma_with_expert.embed_image(img) - - img_emb = torch.utils.checkpoint.checkpoint( - checkpointed_image_embed, - img, - use_reentrant=False, - preserve_rng_state=False - ) - - bsize, num_img_embs = img_emb.shape[:2] - img_mask = img_mask[:, None].expand(bsize, num_img_embs) - - embs.append(img_emb) - pad_masks.append(img_mask) - - # Create attention masks so that image tokens attend to each other - att_masks += [0] * num_img_embs - else: - for ( - img, - img_mask, - ) in zip(images, img_masks): - img_emb = self.paligemma_with_expert.embed_image(img) + # Process images + for img, img_mask in zip(images, img_masks): + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) - bsize, num_img_embs = img_emb.shape[:2] - img_mask = img_mask[:, None].expand(bsize, num_img_embs) + bsize, num_img_embs = img_emb.shape[:2] + img_mask = img_mask[:, None].expand(bsize, num_img_embs) - embs.append(img_emb) - pad_masks.append(img_mask) + 
embs.append(img_emb) + pad_masks.append(img_mask) - # Create attention masks so that image tokens attend to each other - att_masks += [0] * num_img_embs + # Create attention masks so that image tokens attend to each other + att_masks += [0] * num_img_embs - # Apply gradient checkpointing to language embedding if enabled - if self.gradient_checkpointing_enabled and self.training: - def checkpointed_lang_embed(lang_tokens): - lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) - lang_emb_dim = lang_emb.shape[-1] - return lang_emb * math.sqrt(lang_emb_dim) - - lang_emb = torch.utils.checkpoint.checkpoint( - checkpointed_lang_embed, - lang_tokens, - use_reentrant=False, - preserve_rng_state=False - ) - else: + # Process language tokens + def lang_embed_func(lang_tokens): lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) embs.append(lang_emb) pad_masks.append(lang_masks) @@ -287,7 +219,6 @@ def checkpointed_lang_embed(lang_tokens): # Get batch size from the first dimension of the concatenated tensors bsize = pad_masks.shape[0] - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) return embs, pad_masks, att_masks @@ -299,19 +230,11 @@ def embed_suffix(self, state, noisy_actions, timestep): att_masks = [] if not self.pi05: - # Embed state with gradient checkpointing if enabled - if self.gradient_checkpointing_enabled and self.training: - def checkpointed_state_proj(state): - return self.state_proj(state) - - state_emb = torch.utils.checkpoint.checkpoint( - checkpointed_state_proj, - state, - use_reentrant=False, - preserve_rng_state=False - ) - else: - state_emb = self.state_proj(state) + # Embed state + def state_proj_func(state): + return self.state_proj(state) + + state_emb = self._apply_checkpoint(state_proj_func, state) 
embs.append(state_emb[:, None, :]) bsize = state_emb.shape[0] @@ -330,66 +253,35 @@ def checkpointed_state_proj(state): ) time_emb = time_emb.type(dtype=timestep.dtype) - # Fuse timestep + action information using an MLP with gradient checkpointing if enabled - if self.gradient_checkpointing_enabled and self.training: - def checkpointed_action_proj(noisy_actions): - return self.action_in_proj(noisy_actions) - - action_emb = torch.utils.checkpoint.checkpoint( - checkpointed_action_proj, - noisy_actions, - use_reentrant=False, - preserve_rng_state=False - ) - else: - action_emb = self.action_in_proj(noisy_actions) + # Fuse timestep + action information using an MLP + def action_proj_func(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) if not self.pi05: time_emb = time_emb[:, None, :].expand_as(action_emb) action_time_emb = torch.cat([action_emb, time_emb], dim=2) - # Apply gradient checkpointing to MLP layers if enabled - if self.gradient_checkpointing_enabled and self.training: - def checkpointed_mlp(action_time_emb): - x = self.action_time_mlp_in(action_time_emb) - x = F.silu(x) # swish == silu - x = self.action_time_mlp_out(x) - return x - - action_time_emb = torch.utils.checkpoint.checkpoint( - checkpointed_mlp, - action_time_emb, - use_reentrant=False, - preserve_rng_state=False - ) - else: - action_time_emb = self.action_time_mlp_in(action_time_emb) - action_time_emb = F.silu(action_time_emb) # swish == silu - action_time_emb = self.action_time_mlp_out(action_time_emb) + # Apply MLP layers + def mlp_func(action_time_emb): + x = self.action_time_mlp_in(action_time_emb) + x = F.silu(x) # swish == silu + x = self.action_time_mlp_out(x) + return x + action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) adarms_cond = None else: - # time MLP (for adaRMS) with gradient checkpointing if enabled - if self.gradient_checkpointing_enabled and self.training: - def 
checkpointed_time_mlp(time_emb): - x = self.time_mlp_in(time_emb) - x = F.silu(x) # swish == silu - x = self.time_mlp_out(x) - x = F.silu(x) - return x - - time_emb = torch.utils.checkpoint.checkpoint( - checkpointed_time_mlp, - time_emb, - use_reentrant=False, - preserve_rng_state=False - ) - else: - time_emb = self.time_mlp_in(time_emb) - time_emb = F.silu(time_emb) # swish == silu - time_emb = self.time_mlp_out(time_emb) - time_emb = F.silu(time_emb) + # time MLP (for adaRMS) + def time_mlp_func(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) # swish == silu + x = self.time_mlp_out(x) + x = F.silu(x) + return x + time_emb = self._apply_checkpoint(time_mlp_func, time_emb) action_time_emb = action_emb adarms_cond = time_emb @@ -412,17 +304,7 @@ def checkpointed_time_mlp(time_emb): def forward(self, observation, actions, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" - # # Use torch.compile-compatible preprocessing if we're in a compiled context - # if torch._dynamo.is_compiling(): - # observation = _model.preprocess_observation_pytorch_torch_compile(observation, train=True) - # else: - # observation = _model.preprocess_observation_pytorch(observation, train=True) - observation = _model.preprocess_observation_pytorch_torch_compile(observation, train=True) - images = list(observation.images.values()) - img_masks = list(observation.image_masks.values()) - lang_tokens = observation.tokenized_prompt - lang_masks = observation.tokenized_prompt_mask - state = observation.state + images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True) if noise is None: noise = self.sample_noise(actions.shape, actions.device) @@ -444,31 +326,11 @@ def forward(self, observation, actions, noise=None, time=None) -> Tensor: att_2d_masks = make_att_2d_masks(pad_masks, att_masks) position_ids = torch.cumsum(pad_masks, dim=1) - 1 - # Add head dimension 
to attention mask: [B, seq_len, seq_len] -> [B, 1, seq_len, seq_len] - att_2d_masks_4d = att_2d_masks[:, None, :, :] - att_2d_masks_4d = torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) + # Prepare attention masks + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) # Apply gradient checkpointing if enabled - if self.gradient_checkpointing_enabled and self.training: - # Use torch.utils.checkpoint.checkpoint for the expensive forward pass - def checkpointed_forward(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): - (_, suffix_out), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_masks_4d, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - adarms_cond=[None, adarms_cond] - ) - return suffix_out - - suffix_out = torch.utils.checkpoint.checkpoint( - checkpointed_forward, - prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond, - use_reentrant=False, # More memory efficient - preserve_rng_state=False # More memory efficient - ) - else: + def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): (_, suffix_out), _ = self.paligemma_with_expert.forward( attention_mask=att_2d_masks_4d, position_ids=position_ids, @@ -477,23 +339,20 @@ def checkpointed_forward(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids use_cache=False, adarms_cond=[None, adarms_cond] ) + return suffix_out + + suffix_out = self._apply_checkpoint( + forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond + ) suffix_out = suffix_out[:, -self.config.action_horizon :] suffix_out = suffix_out.to(dtype=torch.float32) # Apply gradient checkpointing to final action projection if enabled - if self.gradient_checkpointing_enabled and self.training: - def checkpointed_action_out_proj(suffix_out): - return self.action_out_proj(suffix_out) - - v_t = torch.utils.checkpoint.checkpoint( - checkpointed_action_out_proj, - 
suffix_out, - use_reentrant=False, - preserve_rng_state=False - ) - else: - v_t = self.action_out_proj(suffix_out) + def action_out_proj_func(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) losses = F.mse_loss(u_t, v_t, reduction="none") return losses @@ -501,26 +360,19 @@ def checkpointed_action_out_proj(suffix_out): @torch.no_grad() def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor: """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" - # observation = _model.preprocess_observation(observation, train=False) bsize = observation.state.shape[0] if noise is None: actions_shape = (bsize, self.config.action_horizon, self.config.action_dim) noise = self.sample_noise(actions_shape, device) - images = list(observation.images.values()) - img_masks = list(observation.image_masks.values()) - lang_tokens = observation.tokenized_prompt - lang_masks = observation.tokenized_prompt_mask - state = observation.state + images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 # Compute image and language key value cache - # Add head dimension to attention mask: [B, seq_len, seq_len] -> [B, 1, seq_len, seq_len] - prefix_att_2d_masks_4d = prefix_att_2d_masks[:, None, :, :] - prefix_att_2d_masks_4d = torch.where(prefix_att_2d_masks_4d, 0.0, -2.3819763e38) + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" _, past_key_values = self.paligemma_with_expert.forward( @@ -575,9 +427,8 @@ def denoise_step( prefix_offsets = 
torch.sum(prefix_pad_masks, dim=-1)[:, None] position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - # Add head dimension to attention mask: [B, seq_len, seq_len] -> [B, 1, seq_len, seq_len] - full_att_2d_masks_4d = full_att_2d_masks[:, None, :, :] - full_att_2d_masks_4d = torch.where(full_att_2d_masks_4d, 0.0, -2.3819763e38) + # Prepare attention masks + full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" outputs_embeds, _ = self.paligemma_with_expert.forward( diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index e1ec12c..70bbdbf 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -724,7 +724,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=64, + batch_size=256, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, @@ -734,7 +734,7 @@ def __post_init__(self) -> None: optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, weight_loader=weight_loaders.CheckpointWeightLoader( - "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" + "gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/params" ), num_train_steps=30_000, ), @@ -746,7 +746,8 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=1024, + #batch_size=1024, # 2 nodes, 16 H100s + batch_size=64, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=2_500, peak_lr=1e-4, @@ -755,6 +756,7 @@ def __post_init__(self) -> None: ), optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, + #weight_loader="/path/to/pi05_libero_pytorch", weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2", num_train_steps=6_500, ), diff --git 
a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index 8355336..0832861 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -7,6 +7,7 @@ import jax import jax.numpy as jnp import lerobot.common.datasets.lerobot_dataset as lerobot_dataset +import logging import numpy as np import torch @@ -228,7 +229,7 @@ def create_data_loader( ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: """Create a data loader for training.""" data_config = config.data.create(config.assets_dirs, config.model) - print(f"data_config: {data_config}") + logging.info(f"data_config: {data_config}") if data_config.rlds_data_dir is not None: return create_rlds_data_loader( From b93f3636e1c5d55f04ed348a98a250777430458a Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 27 Aug 2025 08:25:42 -0700 Subject: [PATCH 18/32] try float32 --- examples/convert_jax_model_to_pytorch.py | 13 +- scripts/train_pytorch.py | 30 +- scripts/train_single_example.py | 447 ++++++++++++++---- src/openpi/models_pytorch/gemma_pytorch.py | 26 +- src/openpi/models_pytorch/pi0_pytorch.py | 10 +- .../models/gemma/modeling_gemma.py | 2 +- .../models/siglip/modeling_siglip.py | 2 +- src/openpi/training/config.py | 93 +++- 8 files changed, 481 insertions(+), 142 deletions(-) diff --git a/examples/convert_jax_model_to_pytorch.py b/examples/convert_jax_model_to_pytorch.py index a1aa768..fbcac9b 100644 --- a/examples/convert_jax_model_to_pytorch.py +++ b/examples/convert_jax_model_to_pytorch.py @@ -338,8 +338,12 @@ def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | } restore_dtype = dtype_map.get(restore_precision) if restore_precision else None + # Use CPU sharding to avoid GPU memory issues during checkpoint loading + cpu_device = jax.devices('cpu')[0] + cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_device) + # Use repository restore utility to load a pure dict of params (value suffix removed) - params = 
openpi.models.model.restore_params(params_dir, restore_type=jax.Array, dtype=restore_dtype) + params = openpi.models.model.restore_params(params_dir, restore_type=jax.Array, dtype=restore_dtype, sharding=cpu_sharding) # get params for PaliGemma pali_params = params["PaliGemma"] @@ -384,7 +388,8 @@ def load_jax_model_and_print_keys(checkpoint_dir: str): return item = {params_name: metadata[params_name]} - device = jax.local_devices()[0] + # Use CPU device to avoid GPU memory issues + device = jax.devices('cpu')[0] sharding = jax.sharding.SingleDeviceSharding(device) restored = checkpointer.restore( @@ -539,7 +544,7 @@ def __init__(self): pi05=True, ) elif "pi05_base" in checkpoint_dir: - pi0_config = Pi0Config( + pi0_config = openpi.models.pi0_config.Pi0Config( action_dim=32, action_horizon=50, pi05=True, @@ -563,7 +568,7 @@ def __init__(self): print(f"Warning: Could not load all parameters: {e}") print("Continuing with partial load...") - pi0_model = pi0_model.to(torch.bfloat16) + pi0_model = pi0_model.to(torch.float32) # Save the converted model using safetensors os.makedirs(output_path, exist_ok=True) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index e6d4a4e..44cbdcd 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -396,7 +396,9 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradien # Log sample images to wandb on first batch if is_main and config.wandb_enabled and not resuming: - sample_batch = next(iter(loader)) + # Create a separate iterator for sample batch to avoid consuming the main loader + sample_loader = torch.utils.data.DataLoader(dataset, batch_size=effective_batch_size, shuffle=False, sampler=sampler, num_workers=config.num_workers, pin_memory=True, drop_last=True, collate_fn=collate_to_numpy) + sample_batch = next(iter(sample_loader)) sample_batch = batch_to_torch(sample_batch, device) # Create sample images for wandb @@ -414,9 +416,6 @@ def train_loop(config: 
_config.TrainConfig, resume: bool = False, enable_gradien # Clear sample batch from memory torch.cuda.empty_cache() if torch.cuda.is_available() else None - # Reset the loader iterator - loader = torch.utils.data.DataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=True, drop_last=True, collate_fn=collate_to_numpy) - # Build model if not isinstance(config.model, openpi.models.pi0_config.Pi0Config): # Convert dataclass to Pi0Config if needed @@ -490,7 +489,9 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradien def lr_schedule(step: int): if step < warmup_steps: - return peak_lr * (step + 1) / warmup_steps + # Match JAX behavior: start from peak_lr / (warmup_steps + 1) + init_lr = peak_lr / (warmup_steps + 1) + return init_lr + (peak_lr - init_lr) * step / warmup_steps # cosine decay progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) cos = 0.5 * (1 + np.cos(np.pi * progress)) @@ -548,7 +549,7 @@ def lr_schedule(step: int): log_memory_usage(device, global_step, "after_backward") # Gradient clipping - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) # Optimizer step optim.step() @@ -577,6 +578,7 @@ def lr_schedule(step: int): infos.append({ "loss": loss.item(), "learning_rate": optim.param_groups[0]['lr'], + "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm, }) if is_main and (global_step % config.log_interval == 0): @@ -586,16 +588,24 @@ def lr_schedule(step: int): avg_loss = sum(info["loss"] for info in infos) / len(infos) avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) - logging.info(f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s") + avg_grad_norm = None + if any('grad_norm' in info 
for info in infos): + vals = [info['grad_norm'] for info in infos if 'grad_norm' in info and info['grad_norm'] is not None] + if len(vals) > 0: + avg_grad_norm = sum(vals) / len(vals) + logging.info(f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s" if avg_grad_norm is not None else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s") # Log to wandb - if config.wandb_enabled and len(infos) > 1: - wandb.log({ + if config.wandb_enabled and len(infos) > 0: + log_payload = { "loss": avg_loss, "learning_rate": avg_lr, "step": global_step, "time_per_step": elapsed / config.log_interval, - }, step=global_step) + } + if avg_grad_norm is not None: + log_payload["grad_norm"] = avg_grad_norm + wandb.log(log_payload, step=global_step) start_time = time.time() infos = [] # Reset stats collection diff --git a/scripts/train_single_example.py b/scripts/train_single_example.py index a2e90bb..80dc712 100644 --- a/scripts/train_single_example.py +++ b/scripts/train_single_example.py @@ -16,20 +16,42 @@ # Test pi0_droid model python scripts/train_single_example.py \ --model_name pi0_droid \ - --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid \ - --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi0_droid \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi0_droid_pytorch # Test pi0_aloha_sim model python scripts/train_single_example.py \ --model_name pi0_aloha_sim \ - --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim \ - --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi0_aloha_sim \ + --pytorch_checkpoint_dir 
/home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi0_aloha_sim_pytorch + +# Test pi05_libero model with pickle file (FASTEST for debugging) +python scripts/train_single_example.py \ + --model_name pi05_libero \ + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_libero \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_libero_pytorch \ + --load_pickle ./libero_sample.pkl + +# Test pi05_libero model with small dataset (RECOMMENDED for first-time setup) +python scripts/train_single_example.py \ + --model_name pi05_libero \ + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_libero \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_libero_pytorch \ + +Data Loading Options (in order of speed): +1. --load_pickle: Load from pickle file (instant, fastest for repeated debugging) + + +Setup Workflow: +1. First time: Run save_libero_sample.py to create pickle file (takes 10-30 minutes) +2. 
Subsequent runs: Use --load_pickle for instant loading (takes ~1 second) This script: -- Creates a fixed example with deterministic random values +- Loads a real example from the pi05_libero dataset using JAX data loader - Uses the same noise and time values for both JAX and PyTorch - Disables preprocessing for fair comparison - Compares losses between implementations +- Tests forward pass, backward pass, and another forward pass - Provides detailed analysis of differences """ @@ -43,11 +65,18 @@ import flax import safetensors from unittest.mock import patch +import time +import signal +import sys +import optax -from openpi.models import model as _model +import openpi.models.model as _model from openpi.models.pi0_config import Pi0Config from openpi.models_pytorch.pi0_pytorch import PI0Pytorch import openpi.training.config +from openpi.training.data_loader import create_data_loader +import openpi.training.optimizer +from openpi.shared import nnx_utils def setup_logging(): @@ -58,6 +87,46 @@ def setup_logging(): ) +def load_sample_from_pickle(pickle_path: str): + """Load a sample from a pickle file saved by save_libero_sample.py.""" + print(f"šŸ”„ Loading sample from pickle file: {pickle_path}") + start_time = time.time() + + try: + import pickle + with open(pickle_path, 'rb') as f: + sample_data = pickle.load(f) + + elapsed_time = time.time() - start_time + print(f"āœ… Sample loaded successfully in {elapsed_time:.2f} seconds") + + # Extract the data in the same format as create_fixed_example + observation_data = sample_data['observation'] + + # Return in the same format as create_fixed_example + example = { + "image": observation_data.get("image", {}), + "image_mask": observation_data.get("image_mask", {}), + "state": observation_data.get("state", np.array([])), + "actions": observation_data.get("actions", np.array([])), + "tokenized_prompt": observation_data.get("tokenized_prompt", np.array([])), + "tokenized_prompt_mask": 
observation_data.get("tokenized_prompt_mask", np.array([])), + } + + print(f" - Image keys: {list(example['image'].keys()) if example['image'] else 'None'}") + print(f" - State shape: {example['state'].shape}") + print(f" - Actions shape: {example['actions'].shape}") + + return example + + except Exception as e: + elapsed_time = time.time() - start_time + print(f"āŒ Failed to load pickle file after {elapsed_time:.2f} seconds: {e}") + import traceback + traceback.print_exc() + raise + + def create_fixed_example(model_config): """Create a fixed example for debugging.""" np.random.seed(42) @@ -126,7 +195,7 @@ def mock_preprocess_observation_pytorch(observation, **kwargs): return observation -def test_pytorch_single_example(noise, time, model_name, pytorch_checkpoint_dir): +def test_pytorch_single_example(noise, time, model_name, pytorch_checkpoint_dir, load_pickle=None): """Test PyTorch training on single example.""" print("\n=== Testing PyTorch on Single Example ===") @@ -140,8 +209,13 @@ def test_pytorch_single_example(noise, time, model_name, pytorch_checkpoint_dir) safetensors.torch.load_model(model, pytorch_checkpoint_dir) - # Create fixed example - example = create_fixed_example(train_config.model) + # Load data based on arguments + if load_pickle: + print(f"šŸŽÆ Loading sample from pickle file: {load_pickle}") + example = load_sample_from_pickle(load_pickle) + else: + print("šŸŽÆ Using fixed example for testing") + example = create_fixed_example(train_config.model) # Convert to PyTorch tensors pytorch_example = {} @@ -164,35 +238,111 @@ def test_pytorch_single_example(noise, time, model_name, pytorch_checkpoint_dir) time_tensor = torch.from_numpy(time) # Create observation - observation = _model.Observation.from_dict(pytorch_example) - actions = pytorch_example["actions"] + observation_torch = _model.Observation.from_dict(pytorch_example) + actions_torch = pytorch_example["actions"] - print(f"Observation state shape: {observation.state.shape}") - 
print(f"Observation state dtype: {observation.state.dtype}") - print(f"Actions shape: {actions.shape}") - print(f"Actions dtype: {actions.dtype}") + print(f"Observation state shape: {observation_torch.state.shape}") + print(f"Observation state dtype: {observation_torch.state.dtype}") + print(f"Actions shape: {actions_torch.shape}") + print(f"Actions dtype: {actions_torch.dtype}") print(f"Noise shape: {noise_tensor.shape}, dtype: {noise_tensor.dtype}") print(f"Time shape: {time_tensor.shape}, dtype: {time_tensor.dtype}") - # Test forward pass with fixed noise and time - model.eval() - with torch.no_grad(): - try: - # Use mock to disable preprocessing - with patch('openpi.models.model.preprocess_observation_pytorch', side_effect=mock_preprocess_observation_pytorch): - losses = model(observation, actions, noise=noise_tensor, time=time_tensor) - print(f"PyTorch forward pass successful!") - print(f"Losses shape: {losses.shape}") - print(f"Losses dtype: {losses.dtype}") - mean_loss = losses.to(torch.float32).mean().item() - print(f"Mean loss: {mean_loss:.6f}") - return True, losses - except Exception as e: - print(f"PyTorch forward pass failed: {e}") - return False, None - - -def test_jax_single_example(noise, time, model_name, jax_checkpoint_dir): + # Setup optimizer from config + print(f"Setting up optimizer from config: {type(train_config.optimizer).__name__}") + # Use the exact same optimizer creation code as in train_pytorch.py + warmup_steps = train_config.lr_schedule.warmup_steps + peak_lr = train_config.lr_schedule.peak_lr + decay_steps = train_config.lr_schedule.decay_steps + end_lr = train_config.lr_schedule.decay_lr + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=peak_lr, + betas=(train_config.optimizer.b1, train_config.optimizer.b2), + eps=train_config.optimizer.eps, + weight_decay=train_config.optimizer.weight_decay + ) + print(f"āœ… Optimizer created: {type(optimizer).__name__}") + print(f" - Learning rate: {peak_lr}") + print(f" - Betas: 
({train_config.optimizer.b1}, {train_config.optimizer.b2})") + print(f" - Epsilon: {train_config.optimizer.eps}") + print(f" - Weight decay: {train_config.optimizer.weight_decay}") + + # Define learning rate schedule function (same as in train_pytorch.py) + def lr_schedule(step: int): + if step < warmup_steps: + # Match JAX behavior: start from peak_lr / (warmup_steps + 1) + init_lr = peak_lr / (warmup_steps + 1) + return init_lr + (peak_lr - init_lr) * step / warmup_steps + # cosine decay + progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) + cos = 0.5 * (1 + np.cos(np.pi * progress)) + return end_lr + (peak_lr - end_lr) * cos + + # Test forward pass, backward pass, and another forward pass + model.train() # Set to training mode for backward pass + + try: + # Use mock to disable preprocessing + with patch('openpi.models.model.preprocess_observation_pytorch', side_effect=mock_preprocess_observation_pytorch): + # First forward pass + print("šŸ”„ First forward pass...") + losses_1 = model(observation_torch, actions_torch, noise=noise_tensor, time=time_tensor) + loss_1 = losses_1.mean() + print(f"First forward pass successful! 
Loss: {loss_1.item():.6f}") + + # Backward pass + print("šŸ”„ Backward pass...") + loss_1.backward() + print("Backward pass successful!") + + # Check gradients + total_grad_norm = 0 + param_count = 0 + for param in model.parameters(): + if param.grad is not None: + total_grad_norm += param.grad.norm().item() ** 2 + param_count += 1 + total_grad_norm = total_grad_norm ** 0.5 + print(f"Gradient norm: {total_grad_norm:.6f} (from {param_count} parameters)") + + # Gradient clipping (same as in train_pytorch.py) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=train_config.optimizer.clip_gradient_norm) + print(f" - Gradient clipping applied, clipped norm: {grad_norm:.6f}") + + # Optimizer step to update parameters + print("šŸ”„ Optimizer step...") + # Update learning rate using the schedule (same as in train_pytorch.py) + current_lr = lr_schedule(0) # Use step 0 for this test + for pg in optimizer.param_groups: + pg["lr"] = current_lr + print(f" - Updated learning rate to: {current_lr:.2e}") + + optimizer.step() + print("Optimizer step successful!") + + # Clear gradients for next iteration + optimizer.zero_grad() + + # Second forward pass + print("šŸ”„ Second forward pass...") + losses_2 = model(observation_torch, actions_torch, noise=noise_tensor, time=time_tensor) + loss_2 = losses_2.mean() + print(f"Second forward pass successful! 
Loss: {loss_2.item():.6f}") + + # Compare losses + loss_diff = abs(loss_1.item() - loss_2.item()) + print(f"Loss difference between passes: {loss_diff:.8f}") + + return True, losses_1, losses_2 + + except Exception as e: + print(f"PyTorch forward/backward pass failed: {e}") + return False, None, None + + +def test_jax_single_example(noise, time, model_name, jax_checkpoint_dir, load_pickle=None): """Test JAX training on single example.""" print("\n=== Testing JAX on Single Example ===") @@ -201,6 +351,7 @@ def test_jax_single_example(noise, time, model_name, jax_checkpoint_dir): rng = jax.random.key(42) model = train_config.model.create(rng) # Use train_config.model instead of config.create + jax_checkpoint_dir = jax_checkpoint_dir + '/params/' # Load pre-trained weights print(f"Loading JAX weights from: {jax_checkpoint_dir}") params = _model.restore_params(jax_checkpoint_dir, dtype=jnp.bfloat16) @@ -221,8 +372,13 @@ def test_jax_single_example(noise, time, model_name, jax_checkpoint_dir): print("šŸ”„ Continuing with random initialization...") model = nnx.merge(graphdef, model_state) - # Create fixed example - example = create_fixed_example(train_config.model) + # Load data based on arguments + if load_pickle: + print(f"šŸŽÆ Loading sample from pickle file: {load_pickle}") + example = load_sample_from_pickle(load_pickle) + else: + print("šŸŽÆ Using fixed example for testing") + example = create_fixed_example(train_config.model) # Convert to JAX arrays jax_example = {} @@ -249,110 +405,192 @@ def test_jax_single_example(noise, time, model_name, jax_checkpoint_dir): print(f"Noise shape: {noise_jax.shape}, dtype: {noise_jax.dtype}") print(f"Time shape: {time_jax.shape}, dtype: {time_jax.dtype}") - # Test forward pass with fixed noise and time + # Get learning rate from config + lr = train_config.lr_schedule.peak_lr if hasattr(train_config.lr_schedule, 'peak_lr') else 1e-4 + print(f"Using learning rate from config: {lr}") + + # Create optimizer using the exact same code 
as in train.py + print("Setting up JAX optimizer from config...") + tx = openpi.training.optimizer.create_optimizer(train_config.optimizer, train_config.lr_schedule, weight_decay_mask=None) + print(f"āœ… JAX optimizer created: {type(tx).__name__}") + + # Initialize optimizer state + # Get trainable parameters from the model using NNX + params = nnx.state(model) + trainable_params = params.filter(train_config.trainable_filter) + opt_state = tx.init(trainable_params) + print(f"āœ… Optimizer state initialized") + + # Test forward pass, backward pass, and another forward pass try: - # Use the modified compute_loss method that accepts external noise and time # Use mock to disable preprocessing with patch('openpi.models.model.preprocess_observation', side_effect=mock_preprocess_observation): - losses = model.compute_loss(rng, observation, actions, train=False, noise=noise_jax, time=time_jax) - print(f"JAX forward pass successful!") - print(f"Losses shape: {losses.shape}") - print(f"Losses dtype: {losses.dtype}") - mean_loss = jnp.mean(losses).item() - print(f"Mean loss: {mean_loss:.6f}") - return True, losses + # JIT compile the compute_loss method for memory efficiency + print("šŸ”„ JIT compiling compute_loss method...") + jitted_compute_loss = nnx_utils.module_jit(model.compute_loss) + + # First forward pass + print("šŸ”„ First forward pass...") + losses_1 = jitted_compute_loss(rng, observation, actions, train=True, noise=noise_jax, time=time_jax) + loss_1 = losses_1.mean() + print(f"First forward pass successful! 
Loss: {loss_1.item():.6f}") + + # Use the same approach as in train.py for gradient computation and parameter updates + print("šŸ”„ Computing gradients and updating parameters...") + + # Define loss function for gradient computation using JIT compiled method + def loss_fn(model, rng, observation, actions): + chunked_loss = jitted_compute_loss(rng, observation, actions, train=True, noise=noise_jax, time=time_jax) + return jnp.mean(chunked_loss) + + # Filter out frozen params and compute gradients (same as train.py) + diff_state = nnx.DiffState(0, train_config.trainable_filter) + loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, rng, observation, actions) + + print("Gradients computed successfully!") + + # Check gradient norms using JAX's global_norm + try: + grad_norm = optax.global_norm(grads) + print(f"Gradient norm: {grad_norm.item():.6f}") + except Exception as e: + print(f"Could not compute gradient norm: {e}") + print("Continuing without gradient norm...") + + # Update parameters using optimizer (same as train.py) + print("šŸ”„ Updating parameters with optimizer...") + params = nnx.state(model).filter(train_config.trainable_filter) + updates, new_opt_state = tx.update(grads, opt_state, params) + new_params = optax.apply_updates(params, updates) + + # Update the model in place (same as train.py) + nnx.update(model, new_params) + opt_state = new_opt_state + print("Parameter update successful!") + + # Second forward pass + print("šŸ”„ Second forward pass...") + losses_2 = jitted_compute_loss(rng, observation, actions, train=True, noise=noise_jax, time=time_jax) + loss_2 = losses_2.mean() + print(f"Second forward pass successful! 
Loss: {loss_2.item():.6f}") + + # Compare losses + loss_diff = abs(loss_1.item() - loss_2.item()) + print(f"Loss difference between passes: {loss_diff:.8f}") + + return True, losses_1, losses_2 + except Exception as e: - print(f"JAX forward pass failed: {e}") - return False, None + print(f"JAX forward/backward pass failed: {e}") + import traceback + traceback.print_exc() + return False, None, None -def compare_losses(pytorch_loss, jax_loss): +def compare_losses(pytorch_loss_1, pytorch_loss_2, jax_loss_1, jax_loss_2): """Compare losses and compute relative differences.""" - if pytorch_loss is None or jax_loss is None: + if pytorch_loss_1 is None or jax_loss_1 is None: return print("\n" + "=" * 70) print("šŸ“Š LOSS COMPARISON") print("=" * 70) - print(f"PyTorch loss: {pytorch_loss}") - print(f"JAX loss: {jax_loss}") + print(f"PyTorch first pass loss: {pytorch_loss_1}") + print(f"PyTorch second pass loss: {pytorch_loss_2}") + print(f"JAX first pass loss: {jax_loss_1}") + print(f"JAX second pass loss: {jax_loss_2}") # Additional tensor analysis if both are tensors - pytorch_loss = pytorch_loss.to(torch.float32) - jax_loss = jax_loss.astype(jnp.float32) - if hasattr(pytorch_loss, 'shape') and hasattr(jax_loss, 'shape'): + pytorch_loss_1 = pytorch_loss_1.to(torch.float32) + pytorch_loss_2 = pytorch_loss_2.to(torch.float32) + jax_loss_1 = jax_loss_1.astype(jnp.float32) + jax_loss_2 = jax_loss_2.astype(jnp.float32) + + if hasattr(pytorch_loss_1, 'shape') and hasattr(jax_loss_1, 'shape'): print(f"\nšŸ“ Tensor Analysis:") # Check if shapes match - if pytorch_loss.shape == jax_loss.shape: - print(f"āœ… Tensor shapes match: {pytorch_loss.shape}") + if pytorch_loss_1.shape == jax_loss_1.shape: + print(f"āœ… Tensor shapes match: {pytorch_loss_1.shape}") # Element-wise comparison - if hasattr(pytorch_loss, 'flatten') and hasattr(jax_loss, 'flatten'): + if hasattr(pytorch_loss_1, 'flatten') and hasattr(jax_loss_1, 'flatten'): # Convert to numpy for element-wise analysis try: - 
pytorch_flat = pytorch_loss.detach().cpu().numpy().flatten() - jax_flat = jax_loss.flatten() - - # Element-wise differences - element_diff = np.abs(pytorch_flat - jax_flat) - print(f"element_diff[0]: {element_diff[0:2048*816:2048]}") - max_element_diff = np.max(element_diff) - mean_element_diff = np.mean(element_diff) - - print(f" Max element-wise difference: {max_element_diff:.8f}") - print(f" Mean element-wise difference: {mean_element_diff:.8f}") + pytorch_flat_1 = pytorch_loss_1.detach().cpu().numpy().flatten() + pytorch_flat_2 = pytorch_loss_2.detach().cpu().numpy().flatten() + jax_flat_1 = jax_loss_1.flatten() + jax_flat_2 = jax_loss_2.flatten() + + # Element-wise differences between implementations + element_diff_1 = np.abs(pytorch_flat_1 - jax_flat_1) + element_diff_2 = np.abs(pytorch_flat_2 - jax_flat_2) + + max_element_diff_1 = np.max(element_diff_1) + mean_element_diff_1 = np.mean(element_diff_1) + max_element_diff_2 = np.max(element_diff_2) + mean_element_diff_2 = np.mean(element_diff_2) + + print(f" First pass - Max element-wise difference: {max_element_diff_1:.8f}") + print(f" First pass - Mean element-wise difference: {mean_element_diff_1:.8f}") + print(f" Second pass - Max element-wise difference: {max_element_diff_2:.8f}") + print(f" Second pass - Mean element-wise difference: {mean_element_diff_2:.8f}") # Element-wise relative differences # Avoid division by zero by adding small epsilon epsilon = 1e-12 - pytorch_flat_safe = pytorch_flat + epsilon - jax_flat_safe = jax_flat + epsilon + pytorch_flat_1_safe = pytorch_flat_1 + epsilon + jax_flat_1_safe = jax_flat_1 + epsilon + pytorch_flat_2_safe = pytorch_flat_2 + epsilon + jax_flat_2_safe = jax_flat_2 + epsilon # Compute relative differences for each element - rel_diff_pytorch_elements = (element_diff / np.abs(pytorch_flat_safe)) * 100 - rel_diff_jax_elements = (element_diff / np.abs(jax_flat_safe)) * 100 + rel_diff_pytorch_1 = (element_diff_1 / np.abs(pytorch_flat_1_safe)) * 100 + rel_diff_jax_1 = 
(element_diff_1 / np.abs(jax_flat_1_safe)) * 100 + rel_diff_pytorch_2 = (element_diff_2 / np.abs(pytorch_flat_2_safe)) * 100 + rel_diff_jax_2 = (element_diff_2 / np.abs(jax_flat_2_safe)) * 100 # Compute mean of relative differences - mean_rel_diff_pytorch = np.mean(rel_diff_pytorch_elements) - mean_rel_diff_jax = np.mean(rel_diff_jax_elements) + mean_rel_diff_pytorch_1 = np.mean(rel_diff_pytorch_1) + mean_rel_diff_jax_1 = np.mean(rel_diff_jax_1) + mean_rel_diff_pytorch_2 = np.mean(rel_diff_pytorch_2) + mean_rel_diff_jax_2 = np.mean(rel_diff_jax_2) - print(f" Mean relative difference (w.r.t. PyTorch elements): {mean_rel_diff_pytorch:.4f}%") - print(f" Mean relative difference (w.r.t. JAX elements): {mean_rel_diff_jax:.4f}%") + print(f" First pass - Mean relative difference (w.r.t. PyTorch): {mean_rel_diff_pytorch_1:.4f}%") + print(f" First pass - Mean relative difference (w.r.t. JAX): {mean_rel_diff_jax_1:.4f}%") + print(f" Second pass - Mean relative difference (w.r.t. PyTorch): {mean_rel_diff_pytorch_2:.4f}%") + print(f" Second pass - Mean relative difference (w.r.t. 
JAX): {mean_rel_diff_jax_2:.4f}%") # Count elements with significant differences significant_threshold = 1e-4 - significant_count = np.sum(element_diff > significant_threshold) - total_elements = len(element_diff) - significant_percentage = (significant_count / total_elements) * 100 - - print(f" Elements with diff > {significant_threshold}: {significant_count}/{total_elements} ({significant_percentage:.2f}%)") + significant_count_1 = np.sum(element_diff_1 > significant_threshold) + significant_count_2 = np.sum(element_diff_2 > significant_threshold) + total_elements = len(element_diff_1) + significant_percentage_1 = (significant_count_1 / total_elements) * 100 + significant_percentage_2 = (significant_count_2 / total_elements) * 100 - # Additional relative difference analysis - significant_rel_threshold = 1.0 # 1% - significant_rel_count_pytorch = np.sum(rel_diff_pytorch_elements > significant_rel_threshold) - significant_rel_count_jax = np.sum(rel_diff_jax_elements > significant_rel_threshold) - - print(f" Elements with rel diff > {significant_rel_threshold}% (w.r.t. PyTorch): {significant_rel_count_pytorch}/{total_elements} ({(significant_rel_count_pytorch/total_elements)*100:.2f}%)") - print(f" Elements with rel diff > {significant_rel_threshold}% (w.r.t. 
JAX): {significant_rel_count_jax}/{total_elements} ({(significant_rel_count_jax/total_elements)*100:.2f}%)") + print(f" First pass - Elements with diff > {significant_threshold}: {significant_count_1}/{total_elements} ({significant_percentage_1:.2f}%)") + print(f" Second pass - Elements with diff > {significant_threshold}: {significant_count_2}/{total_elements} ({significant_percentage_2:.2f}%)") except Exception as e: print(f" āš ļø Could not perform element-wise analysis: {e}") else: - print(f"āŒ Tensor shapes don't match: PyTorch {pytorch_loss.shape} vs JAX {jax_loss.shape}") + print(f"āŒ Tensor shapes don't match: PyTorch {pytorch_loss_1.shape} vs JAX {jax_loss_1.shape}") def main(): """Main function to test both implementations.""" parser = argparse.ArgumentParser(description="Train on a single example for JAX vs PyTorch comparison") - parser.add_argument("--model_name", type=str, default="pi05_droid", + parser.add_argument("--model_name", type=str, default="pi05_libero", choices=["pi0_aloha_sim", "pi0_aloha_towel", "pi0_base", "pi05_droid", "pi0_droid", "pi0_libero", "pi05_libero"], help="Model name to use") parser.add_argument("--jax_checkpoint_dir", type=str, required=True, help="Directory containing JAX model checkpoints") parser.add_argument("--pytorch_checkpoint_dir", type=str, required=True, help="Directory containing PyTorch model checkpoints") + parser.add_argument("--load_pickle", type=str, default=None, + help="Load sample from pickle file (fastest option)") args = parser.parse_args() setup_logging() @@ -362,7 +600,7 @@ def main(): print(f"šŸ“ Model: {args.model_name}") print(f"šŸ“ JAX checkpoint: {args.jax_checkpoint_dir}") print(f"šŸ“ PyTorch checkpoint: {args.pytorch_checkpoint_dir}") - print("šŸŽÆ Using fixed noise and time values for deterministic comparison...") + print("🚫 Preprocessing disabled: Image augmentations and resizing are bypassed for fair comparison...") # Get model configuration @@ -376,16 +614,25 @@ def main(): 
action_dim=model_config.action_dim ) + # Test JAX + jax_success, jax_losses_1, jax_losses_2 = test_jax_single_example( + noise, time, args.model_name, args.jax_checkpoint_dir, args.load_pickle + ) + + # Clear JAX memory + jax.clear_caches() + # Test PyTorch - pytorch_success, pytorch_losses = test_pytorch_single_example(noise, time, args.model_name, args.pytorch_checkpoint_dir) + pytorch_success, pytorch_losses_1, pytorch_losses_2 = test_pytorch_single_example( + noise, time, args.model_name, args.pytorch_checkpoint_dir, args.load_pickle + ) torch.cuda.empty_cache() - # Test JAX - jax_success, jax_losses = test_jax_single_example(noise, time, args.model_name, args.jax_checkpoint_dir) + # Compare losses if pytorch_success and jax_success: - compare_losses(pytorch_losses, jax_losses) + compare_losses(pytorch_losses_1, pytorch_losses_2, jax_losses_1, jax_losses_2) # Summary print("\n" + "=" * 70) @@ -394,6 +641,7 @@ def main(): if pytorch_success and jax_success: print("āœ… Both JAX and PyTorch implementations work on the single example!") + print("āœ… Forward pass, backward pass, and second forward pass completed successfully!") print("šŸ” Loss comparison completed above.") elif pytorch_success: print("āŒ PyTorch works but JAX failed. Check JAX implementation.") @@ -405,10 +653,11 @@ def main(): print("\nšŸ’” Next steps:") print("1. Run this script to verify both implementations work") print("2. Analyze the loss comparison results above") - print("3. If losses differ significantly, investigate the differences") - print("4. Check if the noise and time handling is consistent between implementations") - print("5. Use the same example in full training runs") - print("6. Note: Preprocessing (image augmentations) is disabled for this comparison") + print("3. Check if the backward pass gradients are reasonable") + print("4. If losses differ significantly, investigate the differences") + print("5. 
Check if the noise and time handling is consistent between implementations") + print("6. Use the same example in full training runs") + print("7. Note: Preprocessing (image augmentations) is disabled for this comparison") if __name__ == "__main__": diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index bda0dfd..c91ed21 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -51,20 +51,20 @@ def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): self.to_bfloat16_for_selected_params() def to_bfloat16_for_selected_params(self): - self = self.to(dtype=torch.bfloat16) + self = self.to(dtype=torch.float32) - params_to_keep_float32 = [ - "vision_tower.vision_model.embeddings.patch_embedding.weight", - "vision_tower.vision_model.embeddings.patch_embedding.bias", - "vision_tower.vision_model.embeddings.position_embedding.weight", - "input_layernorm", - "post_attention_layernorm", - "model.norm", - ] + # params_to_keep_float32 = [ + # "vision_tower.vision_model.embeddings.patch_embedding.weight", + # "vision_tower.vision_model.embeddings.patch_embedding.bias", + # "vision_tower.vision_model.embeddings.position_embedding.weight", + # "input_layernorm", + # "post_attention_layernorm", + # "model.norm", + # ] - for name, param in self.named_parameters(): - if any(selector in name for selector in params_to_keep_float32): - param.data = param.data.to(dtype=torch.float32) + # for name, param in self.named_parameters(): + # if any(selector in name for selector in params_to_keep_float32): + # param.data = param.data.to(dtype=torch.float32) def embed_image(self, image: torch.Tensor): return self.paligemma.model.get_image_features(image) @@ -194,7 +194,7 @@ def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_id out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) after_first_residual = out_emb.clone() out_emb, 
gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) - out_emb = out_emb.to(dtype=torch.bfloat16) + #out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 29b3045..7928d66 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -44,9 +44,11 @@ def create_sinusoidal_pos_embedding( def sample_beta(alpha, beta, bsize, device): - gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) - gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) - return gamma1 / (gamma1 + gamma2) + alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) + beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + dist = torch.distributions.Beta(alpha_t, beta_t) + samples = dist.sample((bsize,)) + return samples def make_att_2d_masks(pad_masks, att_masks): @@ -318,7 +320,7 @@ def forward(self, observation, actions, noise=None, time=None) -> Tensor: prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) - suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + #suffix_embs = suffix_embs.to(dtype=torch.bfloat16) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) diff --git a/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py b/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py index 3ad389a..434c3fa 100644 --- a/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py @@ -498,7 +498,7 @@ def forward( # embed positions hidden_states = 
inputs_embeds - hidden_states = hidden_states.to(torch.bfloat16) + #hidden_states = hidden_states.to(torch.bfloat16) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) diff --git a/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py b/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py index d7dc902..077571e 100644 --- a/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py +++ b/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py @@ -773,7 +773,7 @@ def forward( ) hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - hidden_states = hidden_states.to(torch.bfloat16) + #hidden_states = hidden_states.to(torch.bfloat16) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 70bbdbf..7a9c2c4 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -379,7 +379,7 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig ) data_transforms = _transforms.Group( - inputs=[droid_policy.DroidInputs(action_dim=model_config.action_dim, model_type=model_config.model_type)], + inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)], outputs=[droid_policy.DroidOutputs()], ) @@ -734,7 +734,7 @@ def __post_init__(self) -> None: optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, weight_loader=weight_loaders.CheckpointWeightLoader( - "gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/params" + "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base" ), num_train_steps=30_000, ), @@ -746,19 +746,22 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - #batch_size=1024, # 2 nodes, 16 
H100s - batch_size=64, + batch_size=256, # 2 nodes, 16 H100s lr_schedule=_optimizer.CosineDecaySchedule( - warmup_steps=2_500, - peak_lr=1e-4, + warmup_steps=10_000, + peak_lr=5e-5, decay_steps=1_000_000, - decay_lr=1e-4, + decay_lr=5e-5, ), - optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), + optimizer=_optimizer.AdamW(clip_gradient_norm=1000.0), ema_decay=0.999, #weight_loader="/path/to/pi05_libero_pytorch", - weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2", - num_train_steps=6_500, + # weight_loader=weight_loaders.CheckpointWeightLoader( + # "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" + # ), + #weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2", + weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch_float32", + num_train_steps=30_000, ), # # Fine-tuning Aloha configs. @@ -859,6 +862,76 @@ def __post_init__(self) -> None: keep_period=20_000, num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally ), + TrainConfig( + # This config is for fine-tuning pi05 on the *full* DROID dataset. + # We use RLDS data loading to make training on this large dataset tractable. + # For fine-tuning on your own DROID dataset, see below. + name="pi05_full_droid_finetune", + model=pi0_config.Pi0Config( + pi05=True, + action_dim=32, + action_horizon=16, + ), + data=RLDSDroidDataConfig( + repo_id="droid", + # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). 
+ rlds_data_dir="/mnt/pi-data/kevin", + action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, + assets=AssetsConfig( + assets_dir="gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/assets/", + asset_id="droid", + ), + ), + weight_loader=weight_loaders.CheckpointWeightLoader( + "gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/params" + ), + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=1_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + num_train_steps=100_000, + batch_size=256, + log_interval=100, + save_interval=5000, + keep_period=10_000, + num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally + ), + TrainConfig( + # This config is for fine-tuning pi05 on the *full* DROID dataset. + # We use RLDS data loading to make training on this large dataset tractable. + # For fine-tuning on your own DROID dataset, see below. + name="pi05_full_droid_finetune_pytorch", + model=pi0_config.Pi0Config( + pi05=True, + action_dim=32, + action_horizon=16, + ), + data=RLDSDroidDataConfig( + repo_id="droid", + # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). 
+ rlds_data_dir="/mnt/pi-data/kevin", + action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, + assets=AssetsConfig( + assets_dir="gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/assets/", + asset_id="droid", + ), + ), + weight_loader='/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2', + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=1_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + num_train_steps=100_000, + batch_size=256, + log_interval=100, + save_interval=5000, + keep_period=10_000, + num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally + ), TrainConfig( # This config is for fine-tuning pi05-DROID on a custom (smaller) DROID dataset. # Here, we use LeRobot data format (like for all other fine-tuning examples) From 6775bdd0a007e81e40516c2514a5551f12f96ecf Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 27 Aug 2025 11:02:40 -0700 Subject: [PATCH 19/32] reuse jax dataloader --- scripts/train_pytorch.py | 67 +++++++---------- src/openpi/models/model.py | 5 +- src/openpi/training/data_loader.py | 115 ++++++++++++++++++++++------- 3 files changed, 117 insertions(+), 70 deletions(-) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 44cbdcd..904dc72 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -127,32 +127,12 @@ def set_seed(seed: int, local_rank: int): def build_datasets(config: _config.TrainConfig): - # Reuse existing dataset + transforms pipeline - data_conf = config.data.create(config.assets_dirs, config.model) - dataset = _data.create_torch_dataset(data_conf, config.model.action_horizon, config.model) - print(f"data_conf: {data_conf}") - dataset = _data.transform_dataset(dataset, data_conf) - return dataset, data_conf + # Use the unified data loader with PyTorch framework + data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True) + return data_loader, 
data_loader.data_config() -def collate_to_numpy(batch_list: list[Dict[str, Any]]) -> Dict[str, Any]: - # Recursively stack leaves with numpy - def stack_leaf(*xs): - return np.stack([np.asarray(x) for x in xs], axis=0) - # Memory-efficient collation - result = torch.utils.data.default_collate(batch_list) if not isinstance(batch_list[0], dict) else _tree_map_multi(stack_leaf, batch_list) - - return result - - -def _tree_map_multi(func, batch_list): - # batch_list is a list of dicts with same structure; reduce by zipping leaves - def recurse(keys, items): - if isinstance(items[0], dict): - return {k: recurse(keys + [k], [it[k] for it in items]) for k in items[0].keys()} - return func(*items) - return recurse([], batch_list) def batch_to_torch(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]: @@ -383,22 +363,19 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradien if is_main: init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) - # Build dataset + sampler + loader - dataset, data_conf = build_datasets(config) - sampler = None - if use_ddp: - sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=torch.distributed.get_world_size(), rank=torch.distributed.get_rank(), shuffle=True, drop_last=True) - - # Use full batch size since we removed gradient accumulation - effective_batch_size = config.batch_size // (torch.distributed.get_world_size() if use_ddp else 1) - - loader = torch.utils.data.DataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=True, drop_last=True, collate_fn=collate_to_numpy) + # Build data loader using the unified data loader + data_loader, data_conf = build_datasets(config) + loader = data_loader # Log sample images to wandb on first batch if is_main and config.wandb_enabled and not resuming: - # Create a separate iterator for sample batch to avoid consuming the main loader - 
sample_loader = torch.utils.data.DataLoader(dataset, batch_size=effective_batch_size, shuffle=False, sampler=sampler, num_workers=config.num_workers, pin_memory=True, drop_last=True, collate_fn=collate_to_numpy) - sample_batch = next(iter(sample_loader)) + # Create a separate data loader for sample batch to avoid consuming the main loader + sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False) + sample_batch = next(iter(sample_data_loader)) + # Convert observation and actions to torch tensors + observation, actions = sample_batch + sample_batch = observation.to_dict() + sample_batch["actions"] = actions sample_batch = batch_to_torch(sample_batch, device) # Create sample images for wandb @@ -513,24 +490,30 @@ def lr_schedule(step: int): pbar = tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) if is_main else None while global_step < config.num_train_steps: - if use_ddp: - sampler.set_epoch(global_step // len(loader)) + # Set epoch for distributed training + if use_ddp and hasattr(loader, 'set_epoch'): + loader.set_epoch(global_step // len(loader)) for batch in loader: # Check if we've reached the target number of steps if global_step >= config.num_train_steps: break - # Convert dict batch directly to torch tensors (bypass Observation.from_dict for PyTorch) - batch = batch_to_torch(batch, device) - actions = batch["actions"] + # The unified data loader returns (observation, actions) tuple + observation, actions = batch + + # Convert observation and actions to torch tensors + observation_dict = observation.to_dict() + observation_dict["actions"] = actions + batch_torch = batch_to_torch(observation_dict, device) + actions = batch_torch["actions"] # Update LR for pg in optim.param_groups: pg["lr"] = lr_schedule(global_step) # Forward pass - observation = _model.Observation.from_dict(batch) + observation = _model.Observation.from_dict(batch_torch) losses = model(observation, actions) # 
Ensure losses is a tensor and handle different return types if isinstance(losses, (list, tuple)): diff --git a/src/openpi/models/model.py b/src/openpi/models/model.py index 8784da5..9f463d9 100644 --- a/src/openpi/models/model.py +++ b/src/openpi/models/model.py @@ -4,7 +4,7 @@ import enum import logging import pathlib -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Union import augmax from flax import nnx @@ -24,7 +24,8 @@ logger = logging.getLogger("openpi") -ArrayT = TypeVar("ArrayT", at.Array, jax.ShapeDtypeStruct) +# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays) +ArrayT = TypeVar("ArrayT", bound=Union[jax.Array, torch.Tensor, np.ndarray]) class ModelType(enum.Enum): diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index 0832861..eef81a1 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -226,8 +226,18 @@ def create_data_loader( shuffle: bool = False, num_batches: int | None = None, skip_norm_stats: bool = False, + framework: str = "jax", ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: - """Create a data loader for training.""" + """Create a data loader for training. + + Args: + config: The training configuration. + sharding: The sharding to use for the data loader (JAX only). + shuffle: Whether to shuffle the data. + num_batches: Determines the number of batches to return. + skip_norm_stats: Whether to skip data normalization. + framework: The framework to use ("jax" or "pytorch"). 
+ """ data_config = config.data.create(config.assets_dirs, config.model) logging.info(f"data_config: {data_config}") @@ -240,6 +250,7 @@ def create_data_loader( shuffle=shuffle, num_batches=num_batches, skip_norm_stats=skip_norm_stats, + framework=framework, ) return create_torch_data_loader( data_config, @@ -252,6 +263,7 @@ def create_data_loader( num_workers=config.num_workers, seed=config.seed, skip_norm_stats=skip_norm_stats, + framework=framework, ) @@ -267,6 +279,7 @@ def create_torch_data_loader( num_batches: int | None = None, num_workers: int = 0, seed: int = 0, + framework: str = "jax", ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: """Create a data loader for training. @@ -288,17 +301,20 @@ def create_torch_data_loader( dataset = create_torch_dataset(data_config, action_horizon, model_config) dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats) + # Use TorchDataLoader for both frameworks with different configurations + local_batch_size = batch_size if framework == "pytorch" else batch_size // jax.process_count() data_loader = TorchDataLoader( dataset, - local_batch_size=batch_size // jax.process_count(), - sharding=sharding, + local_batch_size=local_batch_size, + sharding=None if framework == "pytorch" else sharding, shuffle=shuffle, num_batches=num_batches, num_workers=num_workers, seed=seed, + framework=framework, ) - return DataLoaderImpl(data_config, data_loader) + return DataLoaderImpl(data_config, data_loader, framework=framework) def create_rlds_data_loader( @@ -310,6 +326,7 @@ def create_rlds_data_loader( skip_norm_stats: bool = False, shuffle: bool = False, num_batches: int | None = None, + framework: str = "jax", ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: """Create an RLDS data loader for training. @@ -327,19 +344,24 @@ def create_rlds_data_loader( number of batches in the dataset, the data loader will loop over the dataset. If not provided, will iterate over the dataset indefinitely. 
""" - dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle) - dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True) - - data_loader = RLDSDataLoader( - dataset, - sharding=sharding, - num_batches=num_batches, - ) + if framework == "pytorch": + raise NotImplementedError("PyTorch RLDS data loader is not supported yet") + else: + dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle) + dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True) + + data_loader = RLDSDataLoader( + dataset, + sharding=sharding, + num_batches=num_batches, + ) - return DataLoaderImpl(data_config, data_loader) + return DataLoaderImpl(data_config, data_loader, framework=framework) class TorchDataLoader: + """Torch data loader implementation.""" + def __init__( self, dataset, @@ -350,6 +372,7 @@ def __init__( num_batches: int | None = None, num_workers: int = 0, seed: int = 0, + framework: str = "jax", ): """Create a PyTorch data loader. @@ -372,14 +395,14 @@ def __init__( if len(dataset) < local_batch_size: raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).") - if sharding is None: - # Use data parallel sharding by default. - sharding = jax.sharding.NamedSharding( + # Store sharding - None for PyTorch, JAX sharding for JAX + self._sharding = sharding + if sharding is None and framework == "jax": + # Use data parallel sharding by default for JAX only. 
+ self._sharding = jax.sharding.NamedSharding( jax.sharding.Mesh(jax.devices(), ("B",)), jax.sharding.PartitionSpec("B"), ) - - self._sharding = sharding self._num_batches = num_batches mp_context = None @@ -388,6 +411,12 @@ def __init__( generator = torch.Generator() generator.manual_seed(seed) + # Choose collate function based on framework + if framework == "jax": + collate_fn = _collate_fn_jax + else: + collate_fn = _collate_fn_pytorch + self._data_loader = torch.utils.data.DataLoader( typing.cast(torch.utils.data.Dataset, dataset), batch_size=local_batch_size, @@ -395,7 +424,7 @@ def __init__( num_workers=num_workers, multiprocessing_context=mp_context, persistent_workers=num_workers > 0, - collate_fn=_collate_fn, + collate_fn=collate_fn, worker_init_fn=_worker_init_fn, drop_last=True, generator=generator, @@ -417,16 +446,38 @@ def __iter__(self): except StopIteration: break # We've exhausted the dataset. Create a new iterator and start over. num_items += 1 - yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch) + # For JAX, convert to sharded arrays; for PyTorch, return as-is + if hasattr(self, '_sharding') and self._sharding is not None: + yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch) + else: + yield batch -def _collate_fn(items): - """Collate the batch elements into batched numpy arrays.""" - # Make sure to convert to numpy arrays before stacking since some of the incoming elements - # may be JAX arrays. 
+def _collate_fn_jax(items): + """Collate function for JAX.""" return jax.tree.map(lambda *x: np.stack(np.asarray(x), axis=0), *items) +def _collate_fn_pytorch(items): + """Collate function for PyTorch.""" + def stack_leaf(*xs): + return np.stack([np.asarray(x) for x in xs], axis=0) + + if not isinstance(items[0], dict): + import torch + return torch.utils.data.default_collate(items) + + def _tree_map_multi(func, batch_list): + # batch_list is a list of dicts with same structure; reduce by zipping leaves + def recurse(keys, items): + if isinstance(items[0], dict): + return {k: recurse(keys + [k], [it[k] for it in items]) for k in items[0].keys()} + return func(*items) + return recurse([], batch_list) + + return _tree_map_multi(stack_leaf, items) + + def _worker_init_fn(worker_id: int) -> None: """Tell JAX inside the worker process not to preallocate the GPU memory.""" # NOTE: This is called after jax is imported inside the worker process. This @@ -480,13 +531,25 @@ def __iter__(self): class DataLoaderImpl(DataLoader): - def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader): + def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader, framework: str = "jax"): self._data_config = data_config self._data_loader = data_loader + self._framework = framework def data_config(self) -> _config.DataConfig: return self._data_config def __iter__(self): for batch in self._data_loader: - yield _model.Observation.from_dict(batch), batch["actions"] + if self._framework == "pytorch": + # For PyTorch, convert to torch tensors + import torch + observation_dict = batch.copy() + actions = observation_dict.pop("actions") + # Convert numpy arrays to torch tensors + observation_dict = {k: torch.from_numpy(v) if hasattr(v, 'numpy') else v for k, v in observation_dict.items()} + actions = torch.from_numpy(actions) if hasattr(actions, 'numpy') else actions + yield 
_model.Observation.from_dict(observation_dict), actions + else: + # For JAX, return as-is (numpy arrays are now accepted by the type annotations) + yield _model.Observation.from_dict(batch), batch["actions"] From 5f3aaba7d19ee77774042c2863999f24fc7306d9 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 27 Aug 2025 12:53:21 -0700 Subject: [PATCH 20/32] further simplify dataloader --- scripts/train_pytorch.py | 5 ++- src/openpi/training/data_loader.py | 61 ++++++------------------------ 2 files changed, 15 insertions(+), 51 deletions(-) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 904dc72..66bf0f1 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -384,7 +384,8 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradien batch_size = next(iter(sample_batch['image'].values())).shape[0] for i in range(min(5, batch_size)): # Concatenate all camera views horizontally for this batch item - img_concatenated = torch.cat([img[i] for img in sample_batch['image'].values()], axis=1) + # Convert from NCHW to NHWC format for wandb + img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch['image'].values()], axis=1) img_concatenated = img_concatenated.cpu().numpy() images_to_log.append(wandb.Image(img_concatenated)) @@ -479,7 +480,7 @@ def lr_schedule(step: int): infos = [] # Collect stats over log interval if is_main: logging.info(f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}") - logging.info(f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}") + logging.info(f"Training config: batch_size={config.batch_size}, effective_batch_size={batch_size}, num_train_steps={config.num_train_steps}") logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") logging.info(f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, 
decay_steps={decay_steps}, end_lr={end_lr:.2e}") logging.info(f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}") diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index eef81a1..1cbb6c6 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -1,4 +1,5 @@ from collections.abc import Iterator, Sequence +from typing import Literal import multiprocessing import os import typing @@ -226,7 +227,7 @@ def create_data_loader( shuffle: bool = False, num_batches: int | None = None, skip_norm_stats: bool = False, - framework: str = "jax", + framework: Literal["jax", "pytorch"], ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: """Create a data loader for training. @@ -314,7 +315,7 @@ def create_torch_data_loader( framework=framework, ) - return DataLoaderImpl(data_config, data_loader, framework=framework) + return DataLoaderImpl(data_config, data_loader) def create_rlds_data_loader( @@ -356,7 +357,7 @@ def create_rlds_data_loader( num_batches=num_batches, ) - return DataLoaderImpl(data_config, data_loader, framework=framework) + return DataLoaderImpl(data_config, data_loader) class TorchDataLoader: @@ -411,12 +412,6 @@ def __init__( generator = torch.Generator() generator.manual_seed(seed) - # Choose collate function based on framework - if framework == "jax": - collate_fn = _collate_fn_jax - else: - collate_fn = _collate_fn_pytorch - self._data_loader = torch.utils.data.DataLoader( typing.cast(torch.utils.data.Dataset, dataset), batch_size=local_batch_size, @@ -424,7 +419,7 @@ def __init__( num_workers=num_workers, multiprocessing_context=mp_context, persistent_workers=num_workers > 0, - collate_fn=collate_fn, + collate_fn=_collate_fn, worker_init_fn=_worker_init_fn, drop_last=True, generator=generator, @@ -446,36 +441,16 @@ def __iter__(self): except StopIteration: break # We've exhausted the dataset. 
Create a new iterator and start over. num_items += 1 - # For JAX, convert to sharded arrays; for PyTorch, return as-is - if hasattr(self, '_sharding') and self._sharding is not None: + # For JAX, convert to sharded arrays; for PyTorch, return torch tensors + if self._sharding is not None: yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch) else: - yield batch + yield jax.tree.map(torch.as_tensor, batch) -def _collate_fn_jax(items): +def _collate_fn(items): """Collate function for JAX.""" - return jax.tree.map(lambda *x: np.stack(np.asarray(x), axis=0), *items) - - -def _collate_fn_pytorch(items): - """Collate function for PyTorch.""" - def stack_leaf(*xs): - return np.stack([np.asarray(x) for x in xs], axis=0) - - if not isinstance(items[0], dict): - import torch - return torch.utils.data.default_collate(items) - - def _tree_map_multi(func, batch_list): - # batch_list is a list of dicts with same structure; reduce by zipping leaves - def recurse(keys, items): - if isinstance(items[0], dict): - return {k: recurse(keys + [k], [it[k] for it in items]) for k in items[0].keys()} - return func(*items) - return recurse([], batch_list) - - return _tree_map_multi(stack_leaf, items) + return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items) def _worker_init_fn(worker_id: int) -> None: @@ -531,25 +506,13 @@ def __iter__(self): class DataLoaderImpl(DataLoader): - def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader, framework: str = "jax"): + def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader): self._data_config = data_config self._data_loader = data_loader - self._framework = framework def data_config(self) -> _config.DataConfig: return self._data_config def __iter__(self): for batch in self._data_loader: - if self._framework == "pytorch": - # For PyTorch, convert to torch tensors - import torch - observation_dict 
= batch.copy() - actions = observation_dict.pop("actions") - # Convert numpy arrays to torch tensors - observation_dict = {k: torch.from_numpy(v) if hasattr(v, 'numpy') else v for k, v in observation_dict.items()} - actions = torch.from_numpy(actions) if hasattr(actions, 'numpy') else actions - yield _model.Observation.from_dict(observation_dict), actions - else: - # For JAX, return as-is (numpy arrays are now accepted by the type annotations) - yield _model.Observation.from_dict(batch), batch["actions"] + yield _model.Observation.from_dict(batch), batch["actions"] From bde41bec9e71731667da84840feb53476dfdb0b5 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 27 Aug 2025 13:17:04 -0700 Subject: [PATCH 21/32] fix batch size --- scripts/train_pytorch.py | 6 +++++- src/openpi/training/data_loader.py | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 66bf0f1..4483ce9 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -364,8 +364,12 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradien init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) # Build data loader using the unified data loader + # Calculate effective batch size per GPU for DDP + effective_batch_size = config.batch_size // (torch.distributed.get_world_size() if use_ddp else 1) + config.batch_size = effective_batch_size # Update config for data loader data_loader, data_conf = build_datasets(config) loader = data_loader + logging.info(f"Using batch size per GPU: {effective_batch_size} (total batch size: {effective_batch_size * (torch.distributed.get_world_size() if use_ddp else 1)})") # Log sample images to wandb on first batch if is_main and config.wandb_enabled and not resuming: @@ -480,7 +484,7 @@ def lr_schedule(step: int): infos = [] # Collect stats over log interval if is_main: logging.info(f"Running on: {platform.node()} | 
world_size={torch.distributed.get_world_size() if use_ddp else 1}") - logging.info(f"Training config: batch_size={config.batch_size}, effective_batch_size={batch_size}, num_train_steps={config.num_train_steps}") + logging.info(f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}") logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") logging.info(f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}") logging.info(f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}") diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index 1cbb6c6..addfd74 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -302,7 +302,9 @@ def create_torch_data_loader( dataset = create_torch_dataset(data_config, action_horizon, model_config) dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats) - # Use TorchDataLoader for both frameworks with different configurations + # Use TorchDataLoader for both frameworks + # For PyTorch, batch_size is already per-GPU (calculated in train_pytorch.py) + # For JAX, we need to divide by process count local_batch_size = batch_size if framework == "pytorch" else batch_size // jax.process_count() data_loader = TorchDataLoader( dataset, From 974c45cb09ef10b1e45f4a16dc498646f512dc4f Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 27 Aug 2025 13:20:48 -0700 Subject: [PATCH 22/32] fix batch size again --- scripts/train_pytorch.py | 13 ++++++------- src/openpi/training/data_loader.py | 5 +++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 4483ce9..8de711b 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -126,9 
+126,10 @@ def set_seed(seed: int, local_rank: int): torch.cuda.manual_seed_all(seed + local_rank) -def build_datasets(config: _config.TrainConfig): +def build_datasets(config: _config.TrainConfig, effective_batch_size: int | None = None): # Use the unified data loader with PyTorch framework - data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True) + batch_size = effective_batch_size if effective_batch_size is not None else config.batch_size + data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True, batch_size_override=batch_size) return data_loader, data_loader.data_config() @@ -366,15 +367,13 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradien # Build data loader using the unified data loader # Calculate effective batch size per GPU for DDP effective_batch_size = config.batch_size // (torch.distributed.get_world_size() if use_ddp else 1) - config.batch_size = effective_batch_size # Update config for data loader - data_loader, data_conf = build_datasets(config) - loader = data_loader - logging.info(f"Using batch size per GPU: {effective_batch_size} (total batch size: {effective_batch_size * (torch.distributed.get_world_size() if use_ddp else 1)})") + logging.info(f"Using batch size per GPU: {effective_batch_size} (total batch size: {config.batch_size})") + data_loader, data_conf = build_datasets(config, effective_batch_size) # Log sample images to wandb on first batch if is_main and config.wandb_enabled and not resuming: # Create a separate data loader for sample batch to avoid consuming the main loader - sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False) + sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False, batch_size_override=effective_batch_size) sample_batch = next(iter(sample_data_loader)) # Convert observation and actions to torch tensors observation, actions = sample_batch diff --git 
a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index addfd74..f641558 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -228,6 +228,7 @@ def create_data_loader( num_batches: int | None = None, skip_norm_stats: bool = False, framework: Literal["jax", "pytorch"], + batch_size_override: int | None = None, ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: """Create a data loader for training. @@ -246,7 +247,7 @@ def create_data_loader( return create_rlds_data_loader( data_config, action_horizon=config.model.action_horizon, - batch_size=config.batch_size, + batch_size=batch_size_override if batch_size_override is not None else config.batch_size, sharding=sharding, shuffle=shuffle, num_batches=num_batches, @@ -257,7 +258,7 @@ def create_data_loader( data_config, model_config=config.model, action_horizon=config.model.action_horizon, - batch_size=config.batch_size, + batch_size=batch_size_override if batch_size_override is not None else config.batch_size, sharding=sharding, shuffle=shuffle, num_batches=num_batches, From a644c955a3ba62d918a3a6af094ae54bfcf5c564 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 27 Aug 2025 16:40:17 -0700 Subject: [PATCH 23/32] further clean up --- examples/convert_jax_model_to_pytorch.py | 9 +- scripts/train_pytorch.py | 51 +++--- scripts/train_single_example.py | 4 +- src/openpi/models/model.py | 156 ------------------ src/openpi/models_pytorch/gemma_pytorch.py | 38 +++-- src/openpi/models_pytorch/pi0_pytorch.py | 10 +- .../models/gemma/modeling_gemma.py | 17 +- .../models/siglip/modeling_siglip.py | 4 +- src/openpi/policies/policy_config.py | 1 + src/openpi/training/config.py | 38 ++--- src/openpi/training/data_loader.py | 37 ++++- 11 files changed, 121 insertions(+), 244 deletions(-) diff --git a/examples/convert_jax_model_to_pytorch.py b/examples/convert_jax_model_to_pytorch.py index fbcac9b..07365a4 100644 --- a/examples/convert_jax_model_to_pytorch.py +++ 
b/examples/convert_jax_model_to_pytorch.py @@ -568,7 +568,12 @@ def __init__(self): print(f"Warning: Could not load all parameters: {e}") print("Continuing with partial load...") - pi0_model = pi0_model.to(torch.float32) + if precision == "float32": + pi0_model = pi0_model.to(torch.float32) + elif precision == "bfloat16": + pi0_model = pi0_model.to(torch.bfloat16) + else: + raise ValueError(f"Invalid precision: {precision}") # Save the converted model using safetensors os.makedirs(output_path, exist_ok=True) @@ -615,7 +620,7 @@ def main(): parser.add_argument( "--precision", choices=["float32", "bfloat16", "float16"], - default="float32", + default="bfloat16", type=str, help="Precision for model conversion" ) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 8de711b..f0f7b20 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -126,10 +126,9 @@ def set_seed(seed: int, local_rank: int): torch.cuda.manual_seed_all(seed + local_rank) -def build_datasets(config: _config.TrainConfig, effective_batch_size: int | None = None): +def build_datasets(config: _config.TrainConfig): # Use the unified data loader with PyTorch framework - batch_size = effective_batch_size if effective_batch_size is not None else config.batch_size - data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True, batch_size_override=batch_size) + data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True) return data_loader, data_loader.data_config() @@ -326,14 +325,14 @@ def log_memory_usage(device, step, phase="unknown"): logging.info(f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}") -def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradient_checkpointing: bool = False): +def train_loop(config: 
_config.TrainConfig): use_ddp, local_rank, device = setup_ddp() is_main = (not use_ddp) or (dist.get_rank() == 0) set_seed(config.seed, local_rank) # Initialize checkpoint directory and wandb resuming = False - if resume: + if config.resume: # Find checkpoint directory based on experiment name exp_checkpoint_dir = config.checkpoint_dir if exp_checkpoint_dir.exists(): @@ -366,14 +365,18 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradien # Build data loader using the unified data loader # Calculate effective batch size per GPU for DDP - effective_batch_size = config.batch_size // (torch.distributed.get_world_size() if use_ddp else 1) - logging.info(f"Using batch size per GPU: {effective_batch_size} (total batch size: {config.batch_size})") - data_loader, data_conf = build_datasets(config, effective_batch_size) + # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size + world_size = torch.distributed.get_world_size() if use_ddp else 1 + effective_batch_size = config.batch_size // world_size + logging.info(f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})") + + # Pass the original batch size to data loader - it will handle DDP splitting internally + loader, _ = build_datasets(config) # Log sample images to wandb on first batch if is_main and config.wandb_enabled and not resuming: # Create a separate data loader for sample batch to avoid consuming the main loader - sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False, batch_size_override=effective_batch_size) + sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False) sample_batch = next(iter(sample_data_loader)) # Convert observation and actions to torch tensors observation, actions = sample_batch @@ -401,6 +404,7 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradien if not 
isinstance(config.model, openpi.models.pi0_config.Pi0Config): # Convert dataclass to Pi0Config if needed model_cfg = openpi.models.pi0_config.Pi0Config( + dtype=config.pytorch_training_precision, action_dim=config.model.action_dim, action_horizon=config.model.action_horizon, max_token_len=config.model.max_token_len, @@ -413,9 +417,13 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradien model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device) - if enable_gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'): + if hasattr(model, 'gradient_checkpointing_enable'): + enable_gradient_checkpointing = True model.gradient_checkpointing_enable() logging.info("Enabled gradient checkpointing for memory optimization") + else: + enable_gradient_checkpointing = False + logging.info("Gradient checkpointing is not supported for this model") # Log initial memory usage after model creation if is_main and torch.cuda.is_available(): @@ -426,13 +434,12 @@ def train_loop(config: _config.TrainConfig, resume: bool = False, enable_gradien model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device.index] if device.type == "cuda" else None, find_unused_parameters=True) # Load weights from weight_loader if specified (for fine-tuning) - if isinstance(config.weight_loader, str): - weight_path = config.weight_loader - logging.info(f"Loading weights from: {weight_path}") + if config.pytorch_weight_path is not None: + logging.info(f"Loading weights from: {config.pytorch_weight_path}") - model_path = os.path.join(weight_path, "model.safetensors") + model_path = os.path.join(config.pytorch_weight_path, "model.safetensors") safetensors.torch.load_model((model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path) - logging.info(f"Loaded PyTorch weights from {weight_path}") + logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}") # Optimizer + learning rate 
schedule from config warmup_steps = config.lr_schedule.warmup_steps @@ -625,19 +632,7 @@ def lr_schedule(step: int): def main(): init_logging() config = _config.cli() - - # Parse additional command line arguments for memory optimization - parser = argparse.ArgumentParser(add_help=False) - parser.add_argument("--resume", action="store_true", default=False, - help="Resume training from the latest checkpoint for the experiment (handles both PyTorch and JAX checkpoints)") - - parser.add_argument("--enable_gradient_checkpointing", action="store_true", default=True, - help="Enable gradient checkpointing for memory optimization") - args, _ = parser.parse_known_args() - - train_loop(config, - resume=args.resume, - enable_gradient_checkpointing=args.enable_gradient_checkpointing) + train_loop(config) if __name__ == "__main__": diff --git a/scripts/train_single_example.py b/scripts/train_single_example.py index 80dc712..b919ff7 100644 --- a/scripts/train_single_example.py +++ b/scripts/train_single_example.py @@ -284,8 +284,8 @@ def lr_schedule(step: int): model.train() # Set to training mode for backward pass try: - # Use mock to disable preprocessing - with patch('openpi.models.model.preprocess_observation_pytorch', side_effect=mock_preprocess_observation_pytorch): + # Mock preprocess_observation_pytorch to avoid augmentations + with patch('openpi.models_pytorch.preprocessing_pytorch.preprocess_observation_pytorch', side_effect=mock_preprocess_observation_pytorch): # First forward pass print("šŸ”„ First forward pass...") losses_1 = model(observation_torch, actions_torch, noise=noise_tensor, time=time_tensor) diff --git a/src/openpi/models/model.py b/src/openpi/models/model.py index 9f463d9..e5b879b 100644 --- a/src/openpi/models/model.py +++ b/src/openpi/models/model.py @@ -209,162 +209,6 @@ def preprocess_observation( ) -def preprocess_observation_pytorch( - observation, - *, - train: bool = False, - image_keys: Sequence[str] = IMAGE_KEYS, - image_resolution: tuple[int, 
int] = IMAGE_RESOLUTION, -): - """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations. - - This function avoids complex type annotations that can cause torch.compile issues. - """ - if not set(image_keys).issubset(observation.images): - raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") - - batch_shape = observation.state.shape[:-1] - - out_images = {} - for key in image_keys: - image = observation.images[key] - - # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats - # Handle both [B, C, H, W] and [B, H, W, C] formats - is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 - - if is_channels_first: - # Convert [B, C, H, W] to [B, H, W, C] for processing - image = image.permute(0, 2, 3, 1) - - if image.shape[1:3] != image_resolution: - logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") - image = image_tools.resize_with_pad_torch(image, *image_resolution) - - if train: - # Convert from [-1, 1] to [0, 1] for PyTorch augmentations - image = image / 2.0 + 0.5 - - # Apply PyTorch-based augmentations - if "wrist" not in key: - # Geometric augmentations for non-wrist cameras - height, width = image.shape[1:3] - - # Random crop and resize - crop_height = int(height * 0.95) - crop_width = int(width * 0.95) - - # Random crop - max_h = height - crop_height - max_w = width - crop_width - if max_h > 0 and max_w > 0: - # Use tensor operations instead of .item() for torch.compile compatibility - start_h = torch.randint(0, max_h + 1, (1,), device=image.device) - start_w = torch.randint(0, max_w + 1, (1,), device=image.device) - image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :] - - # Resize back to original size - image = torch.nn.functional.interpolate( - image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] - size=(height, width), - mode='bilinear', - align_corners=False 
- ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - - # Random rotation (small angles) - # Use tensor operations instead of .item() for torch.compile compatibility - angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees - if torch.abs(angle) > 0.1: # Only rotate if angle is significant - # Convert to radians - angle_rad = angle * torch.pi / 180.0 - - # Create rotation matrix - cos_a = torch.cos(angle_rad) - sin_a = torch.sin(angle_rad) - - # Apply rotation using grid_sample - grid_x = torch.linspace(-1, 1, width, device=image.device) - grid_y = torch.linspace(-1, 1, height, device=image.device) - - # Create meshgrid - grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij') - - # Expand to batch dimension - grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) - grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) - - # Apply rotation transformation - grid_x_rot = grid_x * cos_a - grid_y * sin_a - grid_y_rot = grid_x * sin_a + grid_y * cos_a - - # Stack and reshape for grid_sample - grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) - - image = torch.nn.functional.grid_sample( - image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] - grid, - mode='bilinear', - padding_mode='zeros', - align_corners=False - ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - - # Color augmentations for all cameras - # Random brightness - # Use tensor operations instead of .item() for torch.compile compatibility - brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3 - image = image * brightness_factor - - # Random contrast - # Use tensor operations instead of .item() for torch.compile compatibility - contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4 - mean = image.mean(dim=[1, 2, 3], keepdim=True) - image = (image - mean) * contrast_factor + mean - - # Random saturation (convert to HSV, modify S, convert back) - # 
For simplicity, we'll just apply a random scaling to the color channels - # Use tensor operations instead of .item() for torch.compile compatibility - saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # Random factor between 0.5 and 1.5 - gray = image.mean(dim=-1, keepdim=True) - image = gray + (image - gray) * saturation_factor - - # Clamp values to [0, 1] - image = torch.clamp(image, 0, 1) - - # Back to [-1, 1] - image = image * 2.0 - 1.0 - - # Convert back to [B, C, H, W] format if it was originally channels-first - if is_channels_first: - image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] - - out_images[key] = image - - # obtain mask - out_masks = {} - for key in out_images: - if key not in observation.image_masks: - # do not mask by default - out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) - else: - out_masks[key] = observation.image_masks[key] - - # Create a simple object with the required attributes instead of using the complex Observation class - class SimpleProcessedObservation: - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - return SimpleProcessedObservation( - images=out_images, - image_masks=out_masks, - state=observation.state, - tokenized_prompt=observation.tokenized_prompt, - tokenized_prompt_mask=observation.tokenized_prompt_mask, - token_ar_mask=observation.token_ar_mask, - token_loss_mask=observation.token_loss_mask, - ) - - @dataclasses.dataclass(frozen=True) class BaseModelConfig(abc.ABC): """Configuration shared by all models. 
Specific models should inherit from this class, and implement the `create` diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index c91ed21..9b4fd21 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -5,10 +5,11 @@ from transformers.models.gemma import modeling_gemma from transformers.models.auto import CONFIG_MAPPING +from typing import Literal class PaliGemmaWithExpertModel(nn.Module): - def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): + def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False], precision: Literal["bfloat16", "float32"] = "bfloat16"): super().__init__() vlm_config_hf = CONFIG_MAPPING["paligemma"]() @@ -50,21 +51,26 @@ def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): self.to_bfloat16_for_selected_params() - def to_bfloat16_for_selected_params(self): - self = self.to(dtype=torch.float32) + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): + if precision == "bfloat16": + self = self.to(dtype=torch.bfloat16) + elif precision == "float32": + self = self.to(dtype=torch.float32) + else: + raise ValueError(f"Invalid precision: {precision}") - # params_to_keep_float32 = [ - # "vision_tower.vision_model.embeddings.patch_embedding.weight", - # "vision_tower.vision_model.embeddings.patch_embedding.bias", - # "vision_tower.vision_model.embeddings.position_embedding.weight", - # "input_layernorm", - # "post_attention_layernorm", - # "model.norm", - # ] + params_to_keep_float32 = [ + "vision_tower.vision_model.embeddings.patch_embedding.weight", + "vision_tower.vision_model.embeddings.patch_embedding.bias", + "vision_tower.vision_model.embeddings.position_embedding.weight", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] - # for name, param in self.named_parameters(): - # if any(selector in name for 
selector in params_to_keep_float32): - # param.data = param.data.to(dtype=torch.float32) + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) def embed_image(self, image: torch.Tensor): return self.paligemma.model.get_image_features(image) @@ -194,7 +200,9 @@ def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_id out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) after_first_residual = out_emb.clone() out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) - #out_emb = out_emb.to(dtype=torch.bfloat16) + # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + if layer.mlp.up_proj.weight.dtype == torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 7928d66..fa2c2b9 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -8,7 +8,7 @@ import openpi.models.gemma as _gemma from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel -import openpi.models.model as _model +import openpi.models_pytorch.preprocessing_pytorch as _preprocessing def get_safe_dtype(target_dtype, device_type): @@ -93,7 +93,7 @@ def __init__(self, config): paligemma_config = _gemma.get_config(config.paligemma_variant) action_expert_config = _gemma.get_config(config.action_expert_variant) - self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_config, action_expert_config, use_adarms=[False, True] if self.pi05 else [False, False]) + self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_config, action_expert_config, use_adarms=[False, True] if self.pi05 else [False, False], precision=config.dtype) self.action_in_proj = nn.Linear(32, action_expert_config.width) self.action_out_proj = 
nn.Linear(action_expert_config.width, 32) @@ -150,7 +150,7 @@ def _prepare_attention_masks_4d(self, att_2d_masks): def _preprocess_observation(self, observation, train=True): """Helper method to preprocess observation.""" - observation = _model.preprocess_observation_pytorch(observation, train=train) + observation = _preprocessing.preprocess_observation_pytorch(observation, train=train) return ( list(observation.images.values()), list(observation.image_masks.values()), @@ -320,7 +320,9 @@ def forward(self, observation, actions, noise=None, time=None) -> Tensor: prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) - #suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) diff --git a/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py b/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py index 434c3fa..34e8422 100644 --- a/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py @@ -457,6 +457,10 @@ def forward( adarms_cond: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: + """ + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. 
+ """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -498,7 +502,9 @@ def forward( # embed positions hidden_states = inputs_embeds - #hidden_states = hidden_states.to(torch.bfloat16) + # Convert to bfloat16 if the first layer uses bfloat16 + if len(self.layers) > 0 and self.layers[0].input_layernorm.weight.dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.bfloat16) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -609,6 +615,9 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. + Example: ```python @@ -713,6 +722,9 @@ def forward( Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. """ transformer_outputs: BaseModelOutputWithPast = self.model( @@ -809,6 +821,9 @@ def forward( Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. 
""" outputs: BaseModelOutputWithPast = self.model( diff --git a/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py b/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py index 077571e..3ea8acd 100644 --- a/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py +++ b/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py @@ -773,7 +773,9 @@ def forward( ) hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - #hidden_states = hidden_states.to(torch.bfloat16) + # Convert to bfloat16 if the encoder uses bfloat16 + if len(self.encoder.layers) > 0 and self.encoder.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.bfloat16) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, diff --git a/src/openpi/policies/policy_config.py b/src/openpi/policies/policy_config.py index f0601f4..43d1e87 100644 --- a/src/openpi/policies/policy_config.py +++ b/src/openpi/policies/policy_config.py @@ -52,6 +52,7 @@ def create_trained_policy( logging.info("Loading model...") if is_pytorch: model = train_config.model.load_pytorch(train_config, weight_path) + model.paligemma_with_expert.to_bfloat16_for_selected_params(train_config.pytorch_inference_precision) else: model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16)) data_config = train_config.data.create(train_config.assets_dirs, train_config.model) diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 7a9c2c4..9edd044 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -6,7 +6,7 @@ import difflib import logging import pathlib -from typing import Any, Protocol, TypeAlias +from typing import Any, Protocol, TypeAlias, Literal import etils.epath as epath import flax.nnx as nnx @@ -461,6 +461,14 @@ class TrainConfig: # A 
weight loader can optionally load (possibly partial) weights from disk after the model is initialized. weight_loader: weight_loaders.WeightLoader = dataclasses.field(default_factory=weight_loaders.NoOpWeightLoader) + # Optional path to a PyTorch checkpoint to load weights from. + pytorch_weight_path: str | None = None + + # Precision for PyTorch training. + pytorch_training_precision: Literal["bfloat16", "float32"] = "bfloat16" + # Precision for PyTorch inference. + pytorch_inference_precision: Literal["bfloat16", "float32"] = "bfloat16" + lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule) optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW) ema_decay: float | None = 0.99 @@ -734,33 +742,9 @@ def __post_init__(self) -> None: optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, weight_loader=weight_loaders.CheckpointWeightLoader( - "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base" - ), - num_train_steps=30_000, - ), - TrainConfig( - name="pi05_libero_pytorch", - model=pi0_config.Pi0Config(pi05=True, action_horizon=10, discrete_state_input=False), - data=LeRobotLiberoDataConfig( - repo_id="physical-intelligence/libero", - base_config=DataConfig(prompt_from_task=True), - extra_delta_transform=False, + "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" ), - batch_size=256, # 2 nodes, 16 H100s - lr_schedule=_optimizer.CosineDecaySchedule( - warmup_steps=10_000, - peak_lr=5e-5, - decay_steps=1_000_000, - decay_lr=5e-5, - ), - optimizer=_optimizer.AdamW(clip_gradient_norm=1000.0), - ema_decay=0.999, - #weight_loader="/path/to/pi05_libero_pytorch", - # weight_loader=weight_loaders.CheckpointWeightLoader( - # "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" - # ), - #weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2", - 
weight_loader="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch_float32", + pytorch_weight_path="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch_float32", num_train_steps=30_000, ), # diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index f641558..b9c6c52 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -228,7 +228,6 @@ def create_data_loader( num_batches: int | None = None, skip_norm_stats: bool = False, framework: Literal["jax", "pytorch"], - batch_size_override: int | None = None, ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: """Create a data loader for training. @@ -247,7 +246,7 @@ def create_data_loader( return create_rlds_data_loader( data_config, action_horizon=config.model.action_horizon, - batch_size=batch_size_override if batch_size_override is not None else config.batch_size, + batch_size=config.batch_size, sharding=sharding, shuffle=shuffle, num_batches=num_batches, @@ -258,7 +257,7 @@ def create_data_loader( data_config, model_config=config.model, action_horizon=config.model.action_horizon, - batch_size=batch_size_override if batch_size_override is not None else config.batch_size, + batch_size=config.batch_size, sharding=sharding, shuffle=shuffle, num_batches=num_batches, @@ -304,14 +303,34 @@ def create_torch_data_loader( dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats) # Use TorchDataLoader for both frameworks - # For PyTorch, batch_size is already per-GPU (calculated in train_pytorch.py) - # For JAX, we need to divide by process count - local_batch_size = batch_size if framework == "pytorch" else batch_size // jax.process_count() + # For PyTorch DDP, create DistributedSampler and divide batch size by world size + # For JAX, divide by process count + sampler = None + if framework == "pytorch": + try: + import torch.distributed as dist + if dist.is_initialized(): + 
sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=shuffle, + drop_last=True, + ) + local_batch_size = batch_size // dist.get_world_size() + else: + local_batch_size = batch_size + except ImportError: + local_batch_size = batch_size + else: + local_batch_size = batch_size // jax.process_count() + data_loader = TorchDataLoader( dataset, local_batch_size=local_batch_size, sharding=None if framework == "pytorch" else sharding, - shuffle=shuffle, + shuffle=(sampler is None and shuffle), # Don't shuffle if using sampler + sampler=sampler, num_batches=num_batches, num_workers=num_workers, seed=seed, @@ -373,6 +392,7 @@ def __init__( *, sharding: jax.sharding.Sharding | None = None, shuffle: bool = False, + sampler: torch.utils.data.Sampler | None = None, num_batches: int | None = None, num_workers: int = 0, seed: int = 0, @@ -418,7 +438,8 @@ def __init__( self._data_loader = torch.utils.data.DataLoader( typing.cast(torch.utils.data.Dataset, dataset), batch_size=local_batch_size, - shuffle=shuffle, + shuffle=(sampler is None and shuffle), # Don't shuffle if using sampler + sampler=sampler, num_workers=num_workers, multiprocessing_context=mp_context, persistent_workers=num_workers > 0, From db8b5f7eb6ab24e8b9d9d05edc1b6c7b058e071a Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 27 Aug 2025 16:40:47 -0700 Subject: [PATCH 24/32] add missing file --- .../models_pytorch/preprocessing_pytorch.py | 171 ++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 src/openpi/models_pytorch/preprocessing_pytorch.py diff --git a/src/openpi/models_pytorch/preprocessing_pytorch.py b/src/openpi/models_pytorch/preprocessing_pytorch.py new file mode 100644 index 0000000..d67d24b --- /dev/null +++ b/src/openpi/models_pytorch/preprocessing_pytorch.py @@ -0,0 +1,171 @@ +import logging +from collections.abc import Sequence +import torch + +from openpi.shared import image_tools + +logger = 
logging.getLogger("openpi") + +# Constants moved from model.py +IMAGE_KEYS = ( + "base_0_rgb", + "left_wrist_0_rgb", + "right_wrist_0_rgb", +) + +IMAGE_RESOLUTION = (224, 224) + +def preprocess_observation_pytorch( + observation, + *, + train: bool = False, + image_keys: Sequence[str] = IMAGE_KEYS, + image_resolution: tuple[int, int] = IMAGE_RESOLUTION, +): + """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations. + + This function avoids complex type annotations that can cause torch.compile issues. + """ + if not set(image_keys).issubset(observation.images): + raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") + + batch_shape = observation.state.shape[:-1] + + out_images = {} + for key in image_keys: + image = observation.images[key] + + # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats + # Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + image = image.permute(0, 2, 3, 1) + + if image.shape[1:3] != image_resolution: + logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") + image = image_tools.resize_with_pad_torch(image, *image_resolution) + + if train: + # Convert from [-1, 1] to [0, 1] for PyTorch augmentations + image = image / 2.0 + 0.5 + + # Apply PyTorch-based augmentations + if "wrist" not in key: + # Geometric augmentations for non-wrist cameras + height, width = image.shape[1:3] + + # Random crop and resize + crop_height = int(height * 0.95) + crop_width = int(width * 0.95) + + # Random crop + max_h = height - crop_height + max_w = width - crop_width + if max_h > 0 and max_w > 0: + # Use tensor operations instead of .item() for torch.compile compatibility + start_h = torch.randint(0, max_h + 1, (1,), device=image.device) + start_w = 
torch.randint(0, max_w + 1, (1,), device=image.device) + image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :] + + # Resize back to original size + image = torch.nn.functional.interpolate( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + size=(height, width), + mode='bilinear', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Random rotation (small angles) + # Use tensor operations instead of .item() for torch.compile compatibility + angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees + if torch.abs(angle) > 0.1: # Only rotate if angle is significant + # Convert to radians + angle_rad = angle * torch.pi / 180.0 + + # Create rotation matrix + cos_a = torch.cos(angle_rad) + sin_a = torch.sin(angle_rad) + + # Apply rotation using grid_sample + grid_x = torch.linspace(-1, 1, width, device=image.device) + grid_y = torch.linspace(-1, 1, height, device=image.device) + + # Create meshgrid + grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij') + + # Expand to batch dimension + grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) + grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) + + # Apply rotation transformation + grid_x_rot = grid_x * cos_a - grid_y * sin_a + grid_y_rot = grid_x * sin_a + grid_y * cos_a + + # Stack and reshape for grid_sample + grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) + + image = torch.nn.functional.grid_sample( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + grid, + mode='bilinear', + padding_mode='zeros', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Color augmentations for all cameras + # Random brightness + # Use tensor operations instead of .item() for torch.compile compatibility + brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3 + image = image * brightness_factor + + # Random contrast + # Use 
tensor operations instead of .item() for torch.compile compatibility + contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4 + mean = image.mean(dim=[1, 2, 3], keepdim=True) + image = (image - mean) * contrast_factor + mean + + # Random saturation (convert to HSV, modify S, convert back) + # For simplicity, we'll just apply a random scaling to the color channels + # Use tensor operations instead of .item() for torch.compile compatibility + saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # Random factor between 0.5 and 1.5 + gray = image.mean(dim=-1, keepdim=True) + image = gray + (image - gray) * saturation_factor + + # Clamp values to [0, 1] + image = torch.clamp(image, 0, 1) + + # Back to [-1, 1] + image = image * 2.0 - 1.0 + + # Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + out_images[key] = image + + # obtain mask + out_masks = {} + for key in out_images: + if key not in observation.image_masks: + # do not mask by default + out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) + else: + out_masks[key] = observation.image_masks[key] + + # Create a simple object with the required attributes instead of using the complex Observation class + class SimpleProcessedObservation: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + return SimpleProcessedObservation( + images=out_images, + image_masks=out_masks, + state=observation.state, + tokenized_prompt=observation.tokenized_prompt, + tokenized_prompt_mask=observation.tokenized_prompt_mask, + token_ar_mask=observation.token_ar_mask, + token_loss_mask=observation.token_loss_mask, + ) From c51b0619ec0a86b6eb38daf6b565e2655d3624a8 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Thu, 28 Aug 2025 03:40:20 -0700 Subject: [PATCH 25/32] add documentation, check 
transformers_replace --- README.md | 73 +++++++++++++++++++++++- src/openpi/models_pytorch/pi0_pytorch.py | 6 ++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 7300a78..742000e 100644 --- a/README.md +++ b/README.md @@ -185,10 +185,10 @@ We provide more examples for how to fine-tune and run inference with our models openpi now provides PyTorch implementations of π₀ and π₀.₅ models alongside the original JAX versions (π₀-FAST is not currently supported in PyTorch). The PyTorch models offer greater deployment flexibility and seamless integration with PyTorch-based ML stacks, while maintaining feature parity with their JAX counterparts. ### Setup -1. Upgrade the transformers library to 4.53.2 +1. Upgrade the transformers library to 4.53.2 and torch to 2.7.1 - The required version is already specified in pyproject.toml - - If you set up your environment previously, reinstall it to ensure you have transformers 4.53.2 - - You can verify the version with `pip show transformers` + - If you set up your environment previously, reinstall it to ensure you have transformers 4.53.2 and torch 2.7.1 + - You can verify the version with `uv pip show transformers` and `uv pip show torch` 2. Apply the transformers library patches ```bash @@ -235,6 +235,73 @@ uv run scripts/serve_policy.py policy:checkpoint \ --policy.dir=/path/to/converted/pytorch/checkpoint ``` +### Finetuning with PyTorch + +To finetune a model in PyTorch: + +1. Convert the JAX base model to PyTorch format: + ```bash + python examples/convert_jax_model_to_pytorch.py \ + --checkpoint_dir /path/to/jax/base/model \ + --output_path /path/to/pytorch/base/model + ``` + +2. Specify the converted PyTorch model path in your config using `pytorch_weight_path` + +3. 
Launch training using one of these modes: + +```bash +# Single GPU training: +uv run scripts/train_pytorch.py --exp_name --save_interval + +# Example: +uv run scripts/train_pytorch.py debug --exp_name pytorch_test +uv run scripts/train_pytorch.py debug --exp_name pytorch_test --resume # Resume from latest checkpoint + +# Multi-GPU training (single node): +torchrun --standalone --nnodes=1 --nproc_per_node= scripts/train_pytorch.py --exp_name + +# Example: +torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test +torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume + +# Multi-Node Training: +torchrun \ + --nnodes= \ + --nproc_per_node= \ + --node_rank= \ + --master_addr= \ + --master_port= \ + scripts/train_pytorch.py --exp_name= --save_interval +``` + +### Precision Settings + +JAX and PyTorch implementations handle precision as follows: + +**JAX:** +1. Inference: weights and activations are bfloat16 except for selected layers in float32 +2. Training: weights and gradients are float32, activations are bfloat16 + +**PyTorch:** +1. Inference: matches JAX - weights and activations are bfloat16 except for selected layers in float32 +2. Training: supports either all float32 or the same mixed precision as inference (default) + +### Validation Results + +We have validated the PyTorch implementation against JAX: + +1. Converting a JAX-finetuned pi05_libero model (bfloat16 precision) to PyTorch and evaluate: + - JAX checkpoint: 92.0% on Libero-10 + - Converted PyTorch checkpoint: 92.2% on Libero-10 + +2. Finetuning pi05_libero from pi05_base model in PyTorch: + - Using float32: (convert JAX ckpt to pytorch in float32) Loss curves match JAX throughout training + - Using bfloat16: (convert JAX ckpt to pytorch in bfloat16) Loss curves match after 10k steps (in 30k total steps) + - Final performance: 91.4% on Libero-10 + +3. 
We compared inference speed between JAX and PyTorch and they are comparable. + ## Troubleshooting We will collect common issues and their solutions here. If you encounter an issue, please check here first. If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines). diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index fa2c2b9..dad7aae 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -111,6 +111,12 @@ def __init__(self, config): # Initialize gradient checkpointing flag self.gradient_checkpointing_enabled = False + try: + from transformers.models.siglip import check + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError("TransformersReplace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`.") + except ImportError: + raise ValueError("TransformersReplace is not installed correctly. 
Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`.") def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" From ed03aee0e75af82839df9ddc99246ed9d7d394f4 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Thu, 28 Aug 2025 03:42:25 -0700 Subject: [PATCH 26/32] change pi0 back --- src/openpi/models/pi0.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/openpi/models/pi0.py b/src/openpi/models/pi0.py index 7160983..6843f6d 100644 --- a/src/openpi/models/pi0.py +++ b/src/openpi/models/pi0.py @@ -187,24 +187,15 @@ def embed_suffix( @override def compute_loss( - self, - rng: at.KeyArrayLike, - observation: _model.Observation, - actions: _model.Actions, - *, - train: bool = False, - noise: at.Float[at.Array, "*b ah ad"] | None = None, - time: at.Float[at.Array, "*b"] | None = None + self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False ) -> at.Float[at.Array, "*b ah"]: preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3) observation = _model.preprocess_observation(preprocess_rng, observation, train=train) batch_shape = actions.shape[:-2] # Use provided noise and time if available, otherwise generate them - if noise is None: - noise = jax.random.normal(noise_rng, actions.shape) - if time is None: - time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001 + noise = jax.random.normal(noise_rng, actions.shape) + time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001 time_expanded = time[..., None, None] x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions @@ -221,7 +212,7 @@ def compute_loss( ) v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :]) - return jnp.square(v_t - u_t) + return jnp.mean(jnp.square(v_t - u_t), axis=-1) @override def 
sample_actions( From b42d3ac320875b2fbb35f3cbdc8e2d03b29297da Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Thu, 28 Aug 2025 03:46:51 -0700 Subject: [PATCH 27/32] minor changes --- src/openpi/models/pi0.py | 1 - src/openpi/training/config.py | 58 ------------------------------ src/openpi/training/data_loader.py | 4 ++- 3 files changed, 3 insertions(+), 60 deletions(-) diff --git a/src/openpi/models/pi0.py b/src/openpi/models/pi0.py index 6843f6d..ae7c459 100644 --- a/src/openpi/models/pi0.py +++ b/src/openpi/models/pi0.py @@ -193,7 +193,6 @@ def compute_loss( observation = _model.preprocess_observation(preprocess_rng, observation, train=train) batch_shape = actions.shape[:-2] - # Use provided noise and time if available, otherwise generate them noise = jax.random.normal(noise_rng, actions.shape) time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001 time_expanded = time[..., None, None] diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 9edd044..0ffb70f 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -882,64 +882,6 @@ def __post_init__(self) -> None: keep_period=10_000, num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally ), - TrainConfig( - # This config is for fine-tuning pi05 on the *full* DROID dataset. - # We use RLDS data loading to make training on this large dataset tractable. - # For fine-tuning on your own DROID dataset, see below. - name="pi05_full_droid_finetune_pytorch", - model=pi0_config.Pi0Config( - pi05=True, - action_dim=32, - action_horizon=16, - ), - data=RLDSDroidDataConfig( - repo_id="droid", - # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). 
- rlds_data_dir="/mnt/pi-data/kevin", - action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, - assets=AssetsConfig( - assets_dir="gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/assets/", - asset_id="droid", - ), - ), - weight_loader='/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2', - lr_schedule=_optimizer.CosineDecaySchedule( - warmup_steps=1_000, - peak_lr=5e-5, - decay_steps=1_000_000, - decay_lr=5e-5, - ), - num_train_steps=100_000, - batch_size=256, - log_interval=100, - save_interval=5000, - keep_period=10_000, - num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally - ), - TrainConfig( - # This config is for fine-tuning pi05-DROID on a custom (smaller) DROID dataset. - # Here, we use LeRobot data format (like for all other fine-tuning examples) - # To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py - name="pi05_droid_finetune", - model=pi0_config.Pi0Config( - pi05=True, - action_dim=32, # pi05 is trained with 32-dim actions - action_horizon=16, - ), - data=LeRobotDROIDDataConfig( - # Replace with your custom DROID LeRobot dataset repo id. - repo_id="your_hf_username/my_droid_dataset", - base_config=DataConfig(prompt_from_task=True), - assets=AssetsConfig( - # Important: reuse the original DROID norm stats during fine-tuning! - assets_dir="gs://openpi-assets-preview/checkpoints/pi05_droid/assets", - asset_id="droid", - ), - ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets-preview/checkpoints/pi05_droid/params"), - num_train_steps=20_000, - batch_size=32, - ), # # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment. 
# diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index b9c6c52..920edba 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -473,7 +473,9 @@ def __iter__(self): def _collate_fn(items): - """Collate function for JAX.""" + """Collate the batch elements into batched numpy arrays.""" + # Make sure to convert to numpy arrays before stacking since some of the incoming elements + # may be JAX arrays. return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items) From a30c8d600ca738d9dd4d32378d33840787bfe250 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Thu, 28 Aug 2025 03:51:39 -0700 Subject: [PATCH 28/32] further cleanup --- src/openpi/training/config.py | 34 ++++++++++++++++++++++++++++++ src/openpi/training/data_loader.py | 24 +++++++++------------ 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 0ffb70f..b7c10f7 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -882,6 +882,40 @@ def __post_init__(self) -> None: keep_period=10_000, num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally ), + TrainConfig( + # This config is for fine-tuning pi05 on the *full* DROID dataset. + # We use RLDS data loading to make training on this large dataset tractable. + # For fine-tuning on your own DROID dataset, see below. + name="pi05_full_droid_finetune_pytorch", + model=pi0_config.Pi0Config( + pi05=True, + action_dim=32, + action_horizon=16, + ), + data=RLDSDroidDataConfig( + repo_id="droid", + # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). 
+ rlds_data_dir="/mnt/pi-data/kevin", + action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, + assets=AssetsConfig( + assets_dir="gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/assets/", + asset_id="droid", + ), + ), + weight_loader='/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2', + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=1_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + num_train_steps=100_000, + batch_size=256, + log_interval=100, + save_interval=5000, + keep_period=10_000, + num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally + ), # # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment. # diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index 920edba..6c1e916 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -307,20 +307,16 @@ def create_torch_data_loader( # For JAX, divide by process count sampler = None if framework == "pytorch": - try: - import torch.distributed as dist - if dist.is_initialized(): - sampler = torch.utils.data.distributed.DistributedSampler( - dataset, - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=shuffle, - drop_last=True, - ) - local_batch_size = batch_size // dist.get_world_size() - else: - local_batch_size = batch_size - except ImportError: + if torch.distributed.is_initialized(): + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=torch.distributed.get_world_size(), + rank=torch.distributed.get_rank(), + shuffle=shuffle, + drop_last=True, + ) + local_batch_size = batch_size // torch.distributed.get_world_size() + else: local_batch_size = batch_size else: local_batch_size = batch_size // jax.process_count() From babea0e63aeb97ab7d7460ccc7f03243d8d98534 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Thu, 28 Aug 
2025 04:12:37 -0700 Subject: [PATCH 29/32] further cleanup --- README.md | 2 +- examples/convert_jax_model_to_pytorch.py | 6 +-- scripts/train_pytorch.py | 69 +----------------------- src/openpi/policies/policy_config.py | 2 +- src/openpi/training/config.py | 2 - 5 files changed, 4 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index 742000e..9126e25 100644 --- a/README.md +++ b/README.md @@ -285,7 +285,7 @@ JAX and PyTorch implementations handle precision as follows: **PyTorch:** 1. Inference: matches JAX - weights and activations are bfloat16 except for selected layers in float32 -2. Training: supports either all float32 or the same mixed precision as inference (default) +2. Training: supports either all float32 or the same mixed precision as inference (default) You can change it by setting `pytorch_training_precision` in the config ### Validation Results diff --git a/examples/convert_jax_model_to_pytorch.py b/examples/convert_jax_model_to_pytorch.py index 07365a4..739b909 100644 --- a/examples/convert_jax_model_to_pytorch.py +++ b/examples/convert_jax_model_to_pytorch.py @@ -562,11 +562,7 @@ def __init__(self): all_params = {**paligemma_params, **gemma_params, **projection_params} # Load state dict - try: - pi0_model.load_state_dict(all_params) - except Exception as e: - print(f"Warning: Could not load all parameters: {e}") - print("Continuing with partial load...") + pi0_model.load_state_dict(all_params, strict=False) if precision == "float32": pi0_model = pi0_model.to(torch.float32) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index f0f7b20..53ac763 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -132,9 +132,6 @@ def build_datasets(config: _config.TrainConfig): return data_loader, data_loader.data_config() - - - def batch_to_torch(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]: # Memory-efficient conversion: convert to torch tensors and move to device in one step batch = 
jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(device), batch) @@ -238,70 +235,6 @@ def get_latest_checkpoint_step(checkpoint_dir): return max(checkpoint_steps) if checkpoint_steps else None -def validate_checkpoint_integrity(checkpoint_dir, step): - """Validate that a checkpoint at the given step is complete and uncorrupted.""" - ckpt_dir = checkpoint_dir / f"{step}" - - required_files = ["pytorch_model.pt", "optimizer.pt", "metadata.pt"] - optional_files = ["ema_model.pt"] - - # Check if all required files exist - for file_name in required_files: - file_path = ckpt_dir / file_name - if not file_path.exists(): - logging.warning(f"Required checkpoint file missing: {file_path}") - return False - - # Try to validate file integrity by attempting to load them - try: - # Test model file - device = torch.device("cpu") # Use CPU for validation to avoid GPU memory issues - model_state = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device, weights_only=False) - if not isinstance(model_state, dict): - logging.warning(f"Model checkpoint file corrupted at step {step}") - return False - - # Test optimizer file - optimizer_state = torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=False) - if not isinstance(optimizer_state, dict): - logging.warning(f"Optimizer checkpoint file corrupted at step {step}") - return False - - # Test metadata file - metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) - if not isinstance(metadata, dict) or "global_step" not in metadata: - logging.warning(f"Metadata checkpoint file corrupted at step {step}") - return False - - logging.info(f"Checkpoint at step {step} validated successfully") - return True - - except Exception as e: - logging.warning(f"Checkpoint validation failed at step {step}: {e}") - return False - - -def find_latest_valid_checkpoint(checkpoint_dir): - """Find the latest checkpoint that passes integrity validation.""" - checkpoint_steps = [] - for d in 
checkpoint_dir.iterdir(): - if d.is_dir() and d.name.isdigit(): - checkpoint_steps.append(int(d.name)) - - if not checkpoint_steps: - return None - - # Sort steps in descending order to check latest first - checkpoint_steps.sort(reverse=True) - - for step in checkpoint_steps: - if validate_checkpoint_integrity(checkpoint_dir, step): - return step - - logging.error("No valid checkpoints found in directory") - return None - - def log_memory_usage(device, step, phase="unknown"): """Log detailed memory usage information.""" if not torch.cuda.is_available(): @@ -337,7 +270,7 @@ def train_loop(config: _config.TrainConfig): exp_checkpoint_dir = config.checkpoint_dir if exp_checkpoint_dir.exists(): # Use validation to find the latest working checkpoint - latest_step = find_latest_valid_checkpoint(exp_checkpoint_dir) + latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) if latest_step is not None: resuming = True logging.info(f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}") diff --git a/src/openpi/policies/policy_config.py b/src/openpi/policies/policy_config.py index 43d1e87..3e8a2b7 100644 --- a/src/openpi/policies/policy_config.py +++ b/src/openpi/policies/policy_config.py @@ -52,7 +52,7 @@ def create_trained_policy( logging.info("Loading model...") if is_pytorch: model = train_config.model.load_pytorch(train_config, weight_path) - model.paligemma_with_expert.to_bfloat16_for_selected_params(train_config.pytorch_inference_precision) + model.paligemma_with_expert.to_bfloat16_for_selected_params('bfloat16') else: model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16)) data_config = train_config.data.create(train_config.assets_dirs, train_config.model) diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index b7c10f7..b503cf4 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -466,8 +466,6 @@ class TrainConfig: # Precision 
for PyTorch training. pytorch_training_precision: Literal["bfloat16", "float32"] = "bfloat16" - # Precision for PyTorch inference. - pytorch_inference_precision: Literal["bfloat16", "float32"] = "bfloat16" lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule) optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW) From f4a5847917b11beba9e167280a1a135ca5703e82 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Thu, 28 Aug 2025 04:21:28 -0700 Subject: [PATCH 30/32] fix a bug --- .../transformers_replace/models/gemma/modeling_gemma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py b/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py index 34e8422..d052977 100644 --- a/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py @@ -503,7 +503,7 @@ def forward( # embed positions hidden_states = inputs_embeds # Convert to bfloat16 if the first layer uses bfloat16 - if len(self.layers) > 0 and self.layers[0].input_layernorm.weight.dtype == torch.bfloat16: + if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.bfloat16) # create position embeddings to be shared across the decoder layers From b8597b6e5019355120c5e15cb9cb64926b781678 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Thu, 28 Aug 2025 04:25:14 -0700 Subject: [PATCH 31/32] add check.py --- .../transformers_replace/models/siglip/check.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 src/openpi/models_pytorch/transformers_replace/models/siglip/check.py diff --git a/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py b/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py new file mode 100644 
index 0000000..4bb3c96 --- /dev/null +++ b/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py @@ -0,0 +1,4 @@ +import transformers + +def check_whether_transformers_replace_is_installed_correctly(): + return transformers.__version__ == "4.53.2" \ No newline at end of file From cde785653b6ac0548c8356d568c198d17beedba2 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Thu, 28 Aug 2025 09:30:42 -0700 Subject: [PATCH 32/32] fix float32 --- README.md | 2 +- scripts/train_pytorch.py | 108 +++++++++++++++------ src/openpi/models_pytorch/gemma_pytorch.py | 3 +- src/openpi/training/config.py | 3 +- 4 files changed, 81 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 9126e25..5f200fb 100644 --- a/README.md +++ b/README.md @@ -285,7 +285,7 @@ JAX and PyTorch implementations handle precision as follows: **PyTorch:** 1. Inference: matches JAX - weights and activations are bfloat16 except for selected layers in float32 -2. Training: supports either all float32 or the same mixed precision as inference (default) You can change it by setting `pytorch_training_precision` in the config +2. Training: supports either all float32 or the same mixed precision as inference (default). You can change it by setting `pytorch_training_precision` in the config. Per-GPU batch size can be 64 with mixed precision but 16 with float32 on an A100 or H100 with 80GB of memory. Further optimizations to reduce memory consumption can be done in the future. 
### Validation Results diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index 53ac763..6c6fd60 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -191,38 +191,79 @@ def save_checkpoint(model, optimizer, global_step, config, is_main, ema_model=No def load_checkpoint(model, optimizer, checkpoint_dir, device, ema_model=None): - """Load the latest checkpoint and return the global step.""" - checkpoint_steps = [] - for d in checkpoint_dir.iterdir(): - if d.is_dir() and d.name.isdigit(): - checkpoint_steps.append(int(d.name)) - - if not checkpoint_steps: - raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") - - latest_step = max(checkpoint_steps) - ckpt_dir = checkpoint_dir / f"{latest_step}" - - # Load model state with error handling - model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device, weights_only=False) - (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model).load_state_dict(model_state_dict) - logging.info(f"Successfully loaded model state from step {latest_step}") - - # Load optimizer state with error handling and fallback - optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=False) - optimizer.load_state_dict(optimizer_state_dict) - - # Load EMA state if available - if ema_model is not None and (ckpt_dir / "ema_model.pt").exists(): - ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", map_location=device, weights_only=False) - ema_model.load_state_dict(ema_state_dict) - logging.info(f"Successfully loaded EMA state from step {latest_step}") - - # Load metadata (weights_only=False needed for older checkpoints that might contain JAX/Flax objects) - metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) - global_step = metadata.get("global_step", latest_step) - logging.info(f"Successfully loaded metadata from step {latest_step}") - return global_step + """Load the latest 
checkpoint and return the global step.""" + checkpoint_steps = [] + for d in checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + if not checkpoint_steps: + raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") + + latest_step = max(checkpoint_steps) + ckpt_dir = checkpoint_dir / f"{latest_step}" + + # Clear memory before loading checkpoints + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "before_loading_checkpoint") + + try: + # Load model state with error handling + logging.info("Loading model state...") + model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device, weights_only=False) + (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model).load_state_dict(model_state_dict) + del model_state_dict + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_model") + + # Load optimizer state with error handling + logging.info("Loading optimizer state...") + optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=False) + optimizer.load_state_dict(optimizer_state_dict) + del optimizer_state_dict + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_optimizer") + + # Load EMA state if available + if ema_model is not None and (ckpt_dir / "ema_model.pt").exists(): + logging.info("Loading EMA state...") + # Clear as much memory as possible before loading EMA + torch.cuda.empty_cache() + gc.collect() + + ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", map_location=device, weights_only=False) + ema_model.load_state_dict(ema_state_dict) + del ema_state_dict + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_ema") + logging.info(f"Successfully loaded EMA state from step {latest_step}") + + # Load metadata + 
logging.info("Loading metadata...") + metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) + global_step = metadata.get("global_step", latest_step) + del metadata + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_metadata") + + logging.info(f"Successfully loaded all checkpoint components from step {latest_step}") + return global_step + + except RuntimeError as e: + if "out of memory" in str(e): + # Clear memory and provide detailed error message + torch.cuda.empty_cache() + gc.collect() + logging.error(f"Out of memory error while loading checkpoint: {str(e)}") + log_memory_usage(device, latest_step, "after_oom_error") + raise RuntimeError(f"Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True") from e + raise def get_latest_checkpoint_step(checkpoint_dir): @@ -347,6 +388,8 @@ def train_loop(config: _config.TrainConfig): ) else: model_cfg = config.model + # Update dtype to match pytorch_training_precision + object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision) model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device) @@ -427,6 +470,7 @@ def lr_schedule(step: int): logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") logging.info(f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}") logging.info(f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}") + logging.info(f"Training precision: {model_cfg.dtype}") if config.ema_decay is not None: logging.info(f"EMA decay: {config.ema_decay}") diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 9b4fd21..e6d7f37 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -49,13 
+49,14 @@ def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False], self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) self.gemma_expert.model.embed_tokens = None - self.to_bfloat16_for_selected_params() + self.to_bfloat16_for_selected_params(precision) def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): if precision == "bfloat16": self = self.to(dtype=torch.bfloat16) elif precision == "float32": self = self.to(dtype=torch.float32) + return else: raise ValueError(f"Invalid precision: {precision}") diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index b503cf4..64635a7 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -730,7 +730,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=256, + batch_size=16, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, @@ -743,6 +743,7 @@ def __post_init__(self) -> None: "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" ), pytorch_weight_path="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch_float32", + pytorch_training_precision="float32", num_train_steps=30_000, ), #