diff --git a/README.md b/README.md index 7300a78..5f200fb 100644 --- a/README.md +++ b/README.md @@ -185,10 +185,10 @@ We provide more examples for how to fine-tune and run inference with our models openpi now provides PyTorch implementations of π₀ and π₀.₅ models alongside the original JAX versions (π₀-FAST is not currently supported in PyTorch). The PyTorch models offer greater deployment flexibility and seamless integration with PyTorch-based ML stacks, while maintaining feature parity with their JAX counterparts. ### Setup -1. Upgrade the transformers library to 4.53.2 +1. Upgrade the transformers library to 4.53.2 and torch to 2.7.1 - The required version is already specified in pyproject.toml - - If you set up your environment previously, reinstall it to ensure you have transformers 4.53.2 - - You can verify the version with `pip show transformers` + - If you set up your environment previously, reinstall it to ensure you have transformers 4.53.2 and torch 2.7.1 + - You can verify the version with `uv pip show transformers` and `uv pip show torch` 2. Apply the transformers library patches ```bash @@ -235,6 +235,73 @@ uv run scripts/serve_policy.py policy:checkpoint \ --policy.dir=/path/to/converted/pytorch/checkpoint ``` +### Finetuning with PyTorch + +To finetune a model in PyTorch: + +1. Convert the JAX base model to PyTorch format: + ```bash + python examples/convert_jax_model_to_pytorch.py \ + --checkpoint_dir /path/to/jax/base/model \ + --output_path /path/to/pytorch/base/model + ``` + +2. Specify the converted PyTorch model path in your config using `pytorch_weight_path` + +3. 
Launch training using one of these modes: + +```bash +# Single GPU training: +uv run scripts/train_pytorch.py --exp_name --save_interval + +# Example: +uv run scripts/train_pytorch.py debug --exp_name pytorch_test +uv run scripts/train_pytorch.py debug --exp_name pytorch_test --resume # Resume from latest checkpoint + +# Multi-GPU training (single node): +torchrun --standalone --nnodes=1 --nproc_per_node= scripts/train_pytorch.py --exp_name + +# Example: +torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test +torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume + +# Multi-Node Training: +torchrun \ + --nnodes= \ + --nproc_per_node= \ + --node_rank= \ + --master_addr= \ + --master_port= \ + scripts/train_pytorch.py --exp_name= --save_interval +``` + +### Precision Settings + +JAX and PyTorch implementations handle precision as follows: + +**JAX:** +1. Inference: weights and activations are bfloat16 except for selected layers in float32 +2. Training: weights and gradients are float32, activations are bfloat16 + +**PyTorch:** +1. Inference: matches JAX - weights and activations are bfloat16 except for selected layers in float32 +2. Training: supports either all float32 or the same mixed precision as inference (default). You can change this by setting `pytorch_training_precision` in the config. On an 80GB A100 or H100, the per-GPU batch size can be 64 with mixed precision but only 16 with float32. Further optimizations to reduce memory consumption may be added in the future. + +### Validation Results + +We have validated the PyTorch implementation against JAX: + +1. Converting a JAX-finetuned pi05_libero model (bfloat16 precision) to PyTorch and evaluating it: + - JAX checkpoint: 92.0% on Libero-10 + - Converted PyTorch checkpoint: 92.2% on Libero-10 + +2. 
Finetuning pi05_libero from the pi05_base model in PyTorch: + - Using float32: (convert JAX ckpt to pytorch in float32) Loss curves match JAX throughout training + - Using bfloat16: (convert JAX ckpt to pytorch in bfloat16) Loss curves match after 10k steps (of 30k total steps) + - Final performance: 91.4% on Libero-10 + +3. We compared inference speed between JAX and PyTorch and found them comparable. + ## Troubleshooting We will collect common issues and their solutions here. If you encounter an issue, please check here first. If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines). diff --git a/examples/convert_jax_model_to_pytorch.py b/examples/convert_jax_model_to_pytorch.py index 219402d..739b909 100644 --- a/examples/convert_jax_model_to_pytorch.py +++ b/examples/convert_jax_model_to_pytorch.py @@ -8,10 +8,12 @@ Usage: # Just inspect keys: - python convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only # Convert to PyTorch: - python convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output + python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output Example: # pi0_droid @@ -336,8 +338,12 @@ def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | } restore_dtype = dtype_map.get(restore_precision) if restore_precision else None + # Use CPU sharding to avoid GPU memory issues during checkpoint loading + cpu_device = jax.devices('cpu')[0] + cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_device) + # Use repository restore utility to load a pure dict 
of params (value suffix removed) - params = openpi.models.model.restore_params(params_dir, restore_type=jax.Array, dtype=restore_dtype) + params = openpi.models.model.restore_params(params_dir, restore_type=jax.Array, dtype=restore_dtype, sharding=cpu_sharding) # get params for PaliGemma pali_params = params["PaliGemma"] @@ -382,7 +388,8 @@ def load_jax_model_and_print_keys(checkpoint_dir: str): return item = {params_name: metadata[params_name]} - device = jax.local_devices()[0] + # Use CPU device to avoid GPU memory issues + device = jax.devices('cpu')[0] sharding = jax.sharding.SingleDeviceSharding(device) restored = checkpointer.restore( @@ -536,6 +543,12 @@ def __init__(self): action_horizon=10, pi05=True, ) + elif "pi05_base" in checkpoint_dir: + pi0_config = openpi.models.pi0_config.Pi0Config( + action_dim=32, + action_horizon=50, + pi05=True, + ) else: pi0_config = openpi.models.pi0_config.Pi0Config( action_dim=8, @@ -549,13 +562,14 @@ def __init__(self): all_params = {**paligemma_params, **gemma_params, **projection_params} # Load state dict - try: - pi0_model.load_state_dict(all_params) - except Exception as e: - print(f"Warning: Could not load all parameters: {e}") - print("Continuing with partial load...") + pi0_model.load_state_dict(all_params, strict=False) - pi0_model = pi0_model.to(torch.bfloat16) + if precision == "float32": + pi0_model = pi0_model.to(torch.float32) + elif precision == "bfloat16": + pi0_model = pi0_model.to(torch.bfloat16) + else: + raise ValueError(f"Invalid precision: {precision}") # Save the converted model using safetensors os.makedirs(output_path, exist_ok=True) @@ -602,7 +616,7 @@ def main(): parser.add_argument( "--precision", choices=["float32", "bfloat16", "float16"], - default="float32", + default="bfloat16", type=str, help="Precision for model conversion" ) diff --git a/examples/inference_example.py b/examples/inference_example.py index 74eaa4d..72e9d76 100644 --- a/examples/inference_example.py +++ 
b/examples/inference_example.py @@ -5,13 +5,13 @@ This demonstrates the basic usage patterns for both implementations. pi0_droid -python examples/inference_example.py --model_name pi0_droid --jax_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_droid --pytorch_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch +python examples/inference_example.py --model_name pi0_droid --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch pi0_aloha_sim -python examples/inference_example.py --model_name pi0_aloha_sim --jax_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --pytorch_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch +python examples/inference_example.py --model_name pi0_aloha_sim --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch pi05_droid -python examples/inference_example.py --model_name pi05_droid --jax_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid --pytorch_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid_pytorch +python examples/inference_example.py --model_name pi05_droid --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid_pytorch """ import argparse diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py new file mode 100644 index 0000000..6c6fd60 --- /dev/null +++ b/scripts/train_pytorch.py @@ -0,0 +1,616 @@ +""" +PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support. 
+This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs +entirely in PyTorch using the `PI0Pytorch` model and your existing config/data +pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. + +Usage +Single GPU: + python scripts/train_pytorch.py --exp_name --save_interval + Example: + python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test + python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint +Multi-GPU (single node): + torchrun --standalone --nnodes=1 --nproc_per_node= scripts/train_pytorch.py --exp_name + Example: + torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test + torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume +Multi-Node Training: + torchrun \ + --nnodes= --nproc_per_node= --node_rank= \ + --master_addr= --master_port= \ + scripts/train_pytorch.py --exp_name= --save_interval + +""" + +import argparse +import dataclasses +import gc +import logging +import os +import platform +import shutil +import time +from typing import Any, Dict + +import jax +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.parallel +import torch.utils.data +import torch.utils.data.distributed +import tqdm +import wandb +import safetensors.torch + +import openpi.training.config as _config +import openpi.training.data_loader as _data +import openpi.models.model as _model +import openpi.models_pytorch.pi0_pytorch +import openpi.models.pi0_config + + +def init_logging(): + level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"} + + class CustomFormatter(logging.Formatter): + def format(self, record): + record.levelname = level_mapping.get(record.levelname, record.levelname) + return super().format(record) + + formatter = CustomFormatter( + 
fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)", + datefmt="%H:%M:%S", + ) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + if not logger.handlers: + ch = logging.StreamHandler() + ch.setFormatter(formatter) + logger.addHandler(ch) + else: + logger.handlers[0].setFormatter(formatter) + + +def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True): + """Initialize wandb logging.""" + if not enabled: + wandb.init(mode="disabled") + return + + ckpt_dir = config.checkpoint_dir + if not ckpt_dir.exists(): + raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.") + + if resuming: + run_id = (ckpt_dir / "wandb_id.txt").read_text().strip() + wandb.init(id=run_id, resume="must", project=config.project_name) + else: + wandb.init( + name=config.exp_name, + config=dataclasses.asdict(config), + project=config.project_name, + ) + (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id) + + +def setup_ddp(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + use_ddp = world_size > 1 + if use_ddp and not torch.distributed.is_initialized(): + backend = "nccl" if torch.cuda.is_available() else "gloo" + torch.distributed.init_process_group(backend=backend, init_method="env://") + + # Set up debugging environment variables for DDP issues + if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None: + os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO" + + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) + device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): + torch.cuda.set_device(device) + return use_ddp, local_rank, device + + +def cleanup_ddp(): + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +def set_seed(seed: int, local_rank: int): + torch.manual_seed(seed + local_rank) + np.random.seed(seed + local_rank) + if 
torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed + local_rank) + + +def build_datasets(config: _config.TrainConfig): + # Use the unified data loader with PyTorch framework + data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True) + return data_loader, data_loader.data_config() + + +def batch_to_torch(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]: + # Memory-efficient conversion: convert to torch tensors and move to device in one step + batch = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(device), batch) + + # Convert to float32 for memory efficiency (avoid float64) + batch['state'] = batch['state'].to(dtype=torch.float32) + batch['actions'] = batch['actions'].to(dtype=torch.float32) + + return batch + + +def get_model_state_dict(model): + """Get state dict from model, handling DDP wrapper.""" + return model.module.state_dict() if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.state_dict() + + +def get_model_parameters(model): + """Get parameters from model, handling DDP wrapper.""" + return model.module.parameters() if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.parameters() + + +def save_checkpoint(model, optimizer, global_step, config, is_main, ema_model=None): + """Save a checkpoint with model state, optimizer state, EMA state, and metadata.""" + if not is_main: + return + + # Only save if it's time to save or if it's the final step + if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: + # Ensure checkpoint_dir is a Path object and create the step-specific directory + ckpt_dir = config.checkpoint_dir / f"{global_step}" + ckpt_dir.mkdir(parents=True, exist_ok=True) + + # Save model state + state_dict = get_model_state_dict(model) + torch.save(state_dict, ckpt_dir / "pytorch_model.pt") + + # Save optimizer state + torch.save(optimizer.state_dict(), ckpt_dir / "optimizer.pt") + + # Save EMA 
state if available + if ema_model is not None: + torch.save(ema_model.state_dict(), ckpt_dir / "ema_model.pt") + + # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues) + metadata = { + "global_step": global_step, + "config": dataclasses.asdict(config), + "timestamp": time.time(), + } + torch.save(metadata, ckpt_dir / "metadata.pt") + + logging.info(f"Saved checkpoint at step {global_step} -> {ckpt_dir}") + + # Log checkpoint to wandb + if config.wandb_enabled: + wandb.log({"checkpoint_step": global_step}, step=global_step) + + +def load_checkpoint(model, optimizer, checkpoint_dir, device, ema_model=None): + """Load the latest checkpoint and return the global step.""" + checkpoint_steps = [] + for d in checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + if not checkpoint_steps: + raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") + + latest_step = max(checkpoint_steps) + ckpt_dir = checkpoint_dir / f"{latest_step}" + + # Clear memory before loading checkpoints + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "before_loading_checkpoint") + + try: + # Load model state with error handling + logging.info("Loading model state...") + model_state_dict = torch.load(ckpt_dir / "pytorch_model.pt", map_location=device, weights_only=False) + (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model).load_state_dict(model_state_dict) + del model_state_dict + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_model") + + # Load optimizer state with error handling + logging.info("Loading optimizer state...") + optimizer_state_dict = torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=False) + optimizer.load_state_dict(optimizer_state_dict) + del optimizer_state_dict + torch.cuda.empty_cache() + gc.collect() + 
log_memory_usage(device, latest_step, "after_loading_optimizer") + + # Load EMA state if available + if ema_model is not None and (ckpt_dir / "ema_model.pt").exists(): + logging.info("Loading EMA state...") + # Clear as much memory as possible before loading EMA + torch.cuda.empty_cache() + gc.collect() + + ema_state_dict = torch.load(ckpt_dir / "ema_model.pt", map_location=device, weights_only=False) + ema_model.load_state_dict(ema_state_dict) + del ema_state_dict + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_ema") + logging.info(f"Successfully loaded EMA state from step {latest_step}") + + # Load metadata + logging.info("Loading metadata...") + metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) + global_step = metadata.get("global_step", latest_step) + del metadata + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_metadata") + + logging.info(f"Successfully loaded all checkpoint components from step {latest_step}") + return global_step + + except RuntimeError as e: + if "out of memory" in str(e): + # Clear memory and provide detailed error message + torch.cuda.empty_cache() + gc.collect() + logging.error(f"Out of memory error while loading checkpoint: {str(e)}") + log_memory_usage(device, latest_step, "after_oom_error") + raise RuntimeError(f"Out of memory while loading checkpoint. 
Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True") from e + raise + + +def get_latest_checkpoint_step(checkpoint_dir): + """Get the latest checkpoint step number from a checkpoint directory.""" + checkpoint_steps = [] + for d in checkpoint_dir.iterdir(): + if d.is_dir() and d.name.isdigit(): + checkpoint_steps.append(int(d.name)) + + return max(checkpoint_steps) if checkpoint_steps else None + + +def log_memory_usage(device, step, phase="unknown"): + """Log detailed memory usage information.""" + if not torch.cuda.is_available(): + return + + memory_allocated = torch.cuda.memory_allocated(device) / 1e9 + memory_reserved = torch.cuda.memory_reserved(device) / 1e9 + memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device) + memory_free = memory_free / 1e9 + + # Get more detailed memory info + memory_stats = torch.cuda.memory_stats(device) + max_memory_allocated = memory_stats.get('allocated_bytes.all.peak', 0) / 1e9 + max_memory_reserved = memory_stats.get('reserved_bytes.all.peak', 0) / 1e9 + + # Get DDP info if available + ddp_info = "" + if dist.is_initialized(): + ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}" + + logging.info(f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}") + + +def train_loop(config: _config.TrainConfig): + use_ddp, local_rank, device = setup_ddp() + is_main = (not use_ddp) or (dist.get_rank() == 0) + set_seed(config.seed, local_rank) + + # Initialize checkpoint directory and wandb + resuming = False + if config.resume: + # Find checkpoint directory based on experiment name + exp_checkpoint_dir = config.checkpoint_dir + if exp_checkpoint_dir.exists(): + # Use validation to find the latest working checkpoint + latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) + if 
latest_step is not None: + resuming = True + logging.info(f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}") + else: + raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") + else: + raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") + elif config.overwrite and config.checkpoint_dir.exists(): + shutil.rmtree(config.checkpoint_dir) + logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") + + # Create checkpoint directory with experiment name + if not resuming: + # For new runs, create experiment-specific checkpoint directory + exp_checkpoint_dir = config.checkpoint_dir + exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) + logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") + else: + # For resume, checkpoint_dir is already set to the experiment directory + logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") + + # Initialize wandb (only on main process) + if is_main: + init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) + + # Build data loader using the unified data loader + # Calculate effective batch size per GPU for DDP + # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size + world_size = torch.distributed.get_world_size() if use_ddp else 1 + effective_batch_size = config.batch_size // world_size + logging.info(f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})") + + # Pass the original batch size to data loader - it will handle DDP splitting internally + loader, _ = build_datasets(config) + + # Log sample images to wandb on first batch + if is_main and config.wandb_enabled and not resuming: + # Create a separate data loader for sample batch to avoid consuming the main loader + sample_data_loader = 
_data.create_data_loader(config, framework="pytorch", shuffle=False) + sample_batch = next(iter(sample_data_loader)) + # Convert observation and actions to torch tensors + observation, actions = sample_batch + sample_batch = observation.to_dict() + sample_batch["actions"] = actions + sample_batch = batch_to_torch(sample_batch, device) + + # Create sample images for wandb + images_to_log = [] + # Get batch size from the first image tensor + batch_size = next(iter(sample_batch['image'].values())).shape[0] + for i in range(min(5, batch_size)): + # Concatenate all camera views horizontally for this batch item + # Convert from NCHW to NHWC format for wandb + img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch['image'].values()], axis=1) + img_concatenated = img_concatenated.cpu().numpy() + images_to_log.append(wandb.Image(img_concatenated)) + + wandb.log({"camera_views": images_to_log}, step=0) + + # Clear sample batch from memory + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + # Build model + if not isinstance(config.model, openpi.models.pi0_config.Pi0Config): + # Convert dataclass to Pi0Config if needed + model_cfg = openpi.models.pi0_config.Pi0Config( + dtype=config.pytorch_training_precision, + action_dim=config.model.action_dim, + action_horizon=config.model.action_horizon, + max_token_len=config.model.max_token_len, + paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"), + action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"), + pi05=getattr(config.model, "pi05", False), + ) + else: + model_cfg = config.model + # Update dtype to match pytorch_training_precision + object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision) + + model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device) + + if hasattr(model, 'gradient_checkpointing_enable'): + enable_gradient_checkpointing = True + model.gradient_checkpointing_enable() + logging.info("Enabled 
gradient checkpointing for memory optimization") + else: + enable_gradient_checkpointing = False + logging.info("Gradient checkpointing is not supported for this model") + + # Log initial memory usage after model creation + if is_main and torch.cuda.is_available(): + log_memory_usage(device, 0, "after_model_creation") + + if use_ddp: + # Enable unused parameter detection to handle cases where some parameters don't participate in loss + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device.index] if device.type == "cuda" else None, find_unused_parameters=True) + + # Load weights from weight_loader if specified (for fine-tuning) + if config.pytorch_weight_path is not None: + logging.info(f"Loading weights from: {config.pytorch_weight_path}") + + model_path = os.path.join(config.pytorch_weight_path, "model.safetensors") + safetensors.torch.load_model((model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path) + logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}") + + # Optimizer + learning rate schedule from config + warmup_steps = config.lr_schedule.warmup_steps + peak_lr = config.lr_schedule.peak_lr + decay_steps = config.lr_schedule.decay_steps + end_lr = config.lr_schedule.decay_lr + + # Create optimizer with config parameters + optim = torch.optim.AdamW( + model.parameters(), + lr=peak_lr, + betas=(config.optimizer.b1, config.optimizer.b2), + eps=config.optimizer.eps, + weight_decay=config.optimizer.weight_decay + ) + + # Initialize EMA if specified in config + ema_model = None + if config.ema_decay is not None: + ema_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device) + + # Get the correct state dict from the main model + main_model_state_dict = get_model_state_dict(model) + + # Load the state dict into EMA model + ema_model.load_state_dict(main_model_state_dict) + ema_model.eval() + logging.info(f"Initialized EMA with decay {config.ema_decay}") + + # 
Load checkpoint if resuming + global_step = 0 + if resuming: + global_step = load_checkpoint(model, optim, config.checkpoint_dir, device, ema_model) + logging.info(f"Resumed training from step {global_step}") + + def lr_schedule(step: int): + if step < warmup_steps: + # Match JAX behavior: start from peak_lr / (warmup_steps + 1) + init_lr = peak_lr / (warmup_steps + 1) + return init_lr + (peak_lr - init_lr) * step / warmup_steps + # cosine decay + progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) + cos = 0.5 * (1 + np.cos(np.pi * progress)) + return end_lr + (peak_lr - end_lr) * cos + + model.train() + start_time = time.time() + infos = [] # Collect stats over log interval + if is_main: + logging.info(f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}") + logging.info(f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}") + logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") + logging.info(f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}") + logging.info(f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}") + logging.info(f"Training precision: {model_cfg.dtype}") + if config.ema_decay is not None: + logging.info(f"EMA decay: {config.ema_decay}") + + # Training loop - iterate until we reach num_train_steps + pbar = tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) if is_main else None + + while global_step < config.num_train_steps: + # Set epoch for distributed training + if use_ddp and hasattr(loader, 'set_epoch'): + loader.set_epoch(global_step // len(loader)) + + for batch in loader: + # Check if we've reached the target number of steps + if global_step >= 
config.num_train_steps: + break + + # The unified data loader returns (observation, actions) tuple + observation, actions = batch + + # Convert observation and actions to torch tensors + observation_dict = observation.to_dict() + observation_dict["actions"] = actions + batch_torch = batch_to_torch(observation_dict, device) + actions = batch_torch["actions"] + + # Update LR + for pg in optim.param_groups: + pg["lr"] = lr_schedule(global_step) + + # Forward pass + observation = _model.Observation.from_dict(batch_torch) + losses = model(observation, actions) + # Ensure losses is a tensor and handle different return types + if isinstance(losses, (list, tuple)): + losses = torch.stack(losses) + elif not isinstance(losses, torch.Tensor): + losses = torch.tensor(losses, device=device, dtype=torch.float32) + + loss = losses.mean() + + # Backward pass + loss.backward() + + # Log memory usage after backward pass + if global_step < 5 and is_main: + if torch.cuda.is_available(): + log_memory_usage(device, global_step, "after_backward") + + # Gradient clipping + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) + + # Optimizer step + optim.step() + optim.zero_grad(set_to_none=True) + + # Clear gradients more aggressively + for param in model.parameters(): + if param.grad is not None: + param.grad.detach_() + param.grad = None + + # Update EMA if enabled + if ema_model is not None: + try: + with torch.no_grad(): + # Get parameters from the correct model structure + main_model_params = get_model_parameters(model) + for param, ema_param in zip(main_model_params, ema_model.parameters()): + ema_param.data.mul_(config.ema_decay).add_(param.data, alpha=1 - config.ema_decay) + except Exception as e: + logging.warning(f"Failed to update EMA model: {e}") + # Continue training without EMA update + + # Collect stats + if is_main: + infos.append({ + "loss": loss.item(), + "learning_rate": optim.param_groups[0]['lr'], + "grad_norm": 
float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm, + }) + + if is_main and (global_step % config.log_interval == 0): + elapsed = time.time() - start_time + + # Average stats over log interval + avg_loss = sum(info["loss"] for info in infos) / len(infos) + avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) + + avg_grad_norm = None + if any('grad_norm' in info for info in infos): + vals = [info['grad_norm'] for info in infos if 'grad_norm' in info and info['grad_norm'] is not None] + if len(vals) > 0: + avg_grad_norm = sum(vals) / len(vals) + logging.info(f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s" if avg_grad_norm is not None else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s") + + # Log to wandb + if config.wandb_enabled and len(infos) > 0: + log_payload = { + "loss": avg_loss, + "learning_rate": avg_lr, + "step": global_step, + "time_per_step": elapsed / config.log_interval, + } + if avg_grad_norm is not None: + log_payload["grad_norm"] = avg_grad_norm + wandb.log(log_payload, step=global_step) + + start_time = time.time() + infos = [] # Reset stats collection + + # Save checkpoint using the new mechanism + save_checkpoint(model, optim, global_step, config, is_main, ema_model) + + global_step += 1 + + # Update progress bar + if pbar is not None: + pbar.update(1) + pbar.set_postfix({ + 'loss': f'{loss.item():.4f}', + 'lr': f'{optim.param_groups[0]["lr"]:.2e}', + 'step': global_step + }) + + # Close progress bar + if pbar is not None: + pbar.close() + + # Finish wandb run + if is_main and config.wandb_enabled: + wandb.finish() + + cleanup_ddp() + + +def main(): + init_logging() + config = _config.cli() + train_loop(config) + + +if __name__ == "__main__": + main() diff --git a/scripts/train_single_example.py b/scripts/train_single_example.py new file mode 100644 index 0000000..b919ff7 --- /dev/null +++ b/scripts/train_single_example.py 
@@ -0,0 +1,664 @@ +#!/usr/bin/env python3 +""" +Train on a single example for debugging JAX vs PyTorch comparison. + +This script creates a deterministic dataset with one example and trains on it +to help debug differences between JAX and PyTorch implementations. + +Usage examples: + +# Test pi05_droid model +python scripts/train_single_example.py \ + --model_name pi05_droid \ + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid_pytorch + +# Test pi0_droid model +python scripts/train_single_example.py \ + --model_name pi0_droid \ + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi0_droid \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi0_droid_pytorch + +# Test pi0_aloha_sim model +python scripts/train_single_example.py \ + --model_name pi0_aloha_sim \ + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi0_aloha_sim \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi0_aloha_sim_pytorch + +# Test pi05_libero model with pickle file (FASTEST for debugging) +python scripts/train_single_example.py \ + --model_name pi05_libero \ + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_libero \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_libero_pytorch \ + --load_pickle ./libero_sample.pkl + +# Test pi05_libero model with small dataset (RECOMMENDED for first-time setup) +python scripts/train_single_example.py \ + --model_name pi05_libero \ + --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_libero \ + --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_libero_pytorch \ + +Data Loading Options (in order of speed): +1. 
--load_pickle: Load from pickle file (instant, fastest for repeated debugging) + + +Setup Workflow: +1. First time: Run save_libero_sample.py to create pickle file (takes 10-30 minutes) +2. Subsequent runs: Use --load_pickle for instant loading (takes ~1 second) + +This script: +- Loads a real example from the pi05_libero dataset using JAX data loader +- Uses the same noise and time values for both JAX and PyTorch +- Disables preprocessing for fair comparison +- Compares losses between implementations +- Tests forward pass, backward pass, and another forward pass +- Provides detailed analysis of differences +""" + +import argparse +import logging +import numpy as np +import torch +import jax +import jax.numpy as jnp +import flax.nnx as nnx +import flax +import safetensors.torch  # load_model lives in the torch submodule, which `import safetensors` alone does not load +from unittest.mock import patch +import time +import signal +import sys +import optax + +import openpi.models.model as _model +from openpi.models.pi0_config import Pi0Config +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch +import openpi.training.config +from openpi.training.data_loader import create_data_loader +import openpi.training.optimizer +from openpi.shared import nnx_utils + + +def setup_logging(): + """Setup logging for debugging.""" + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' + ) + + +def load_sample_from_pickle(pickle_path: str): + """Load a sample from a pickle file saved by save_libero_sample.py.""" + print(f"🔄 Loading sample from pickle file: {pickle_path}") + start_time = time.time() + + try: + import pickle + with open(pickle_path, 'rb') as f: + sample_data = pickle.load(f) + + elapsed_time = time.time() - start_time + print(f"✅ Sample loaded successfully in {elapsed_time:.2f} seconds") + + # Extract the data in the same format as create_fixed_example + observation_data = sample_data['observation'] + + # Return in the same format as create_fixed_example + example = { + "image": observation_data.get("image", {}), + 
"image_mask": observation_data.get("image_mask", {}), + "state": observation_data.get("state", np.array([])), + "actions": observation_data.get("actions", np.array([])), + "tokenized_prompt": observation_data.get("tokenized_prompt", np.array([])), + "tokenized_prompt_mask": observation_data.get("tokenized_prompt_mask", np.array([])), + } + + print(f" - Image keys: {list(example['image'].keys()) if example['image'] else 'None'}") + print(f" - State shape: {example['state'].shape}") + print(f" - Actions shape: {example['actions'].shape}") + + return example + + except Exception as e: + elapsed_time = time.time() - start_time + print(f"❌ Failed to load pickle file after {elapsed_time:.2f} seconds: {e}") + import traceback + traceback.print_exc() + raise + + +def create_fixed_example(model_config): + """Create a fixed example for debugging.""" + np.random.seed(42) + + batch_size = 1 + action_dim = model_config.action_dim + action_horizon = model_config.action_horizon + image_size = 224 + max_token_len = model_config.max_token_len + + # Create fixed images + images = {} + for key in _model.IMAGE_KEYS: + img = np.zeros((batch_size, image_size, image_size, 3), dtype=np.float32) + + # Simple gradient pattern + for i in range(image_size): + for j in range(image_size): + val = (i + j) / (2 * image_size) * 2 - 1 + img[0, i, j, :] = [val, val * 0.5, val * 0.25] + + images[key] = img + + # Create fixed state and actions + state = np.random.randn(batch_size, action_dim).astype(np.float32) * 0.1 + actions = np.random.randn(batch_size, action_horizon, action_dim).astype(np.float32) * 0.1 + + # Create fixed language tokens + tokenized_prompt = np.random.randint(0, 1000, (batch_size, max_token_len), dtype=np.int32) + tokenized_prompt_mask = np.ones((batch_size, max_token_len), dtype=bool) + + # Create image masks + image_masks = {key: np.ones(batch_size, dtype=bool) for key in _model.IMAGE_KEYS} + + return { + "image": images, + "image_mask": image_masks, + "state": state, + 
"actions": actions, + "tokenized_prompt": tokenized_prompt, + "tokenized_prompt_mask": tokenized_prompt_mask, + } + + +def create_fixed_noise_and_time(batch_size, action_horizon, action_dim): + """Create fixed noise and time values for deterministic comparison.""" + np.random.seed(42) # Use same seed for consistency + + # Create fixed noise + noise = np.random.randn(batch_size, action_horizon, action_dim).astype(np.float32) * 0.1 + + # Create fixed time values (beta distribution like in the models) + time_beta = np.random.beta(1.5, 1.0, batch_size).astype(np.float32) + time = time_beta * 0.999 + 0.001 + + return noise, time + + +def mock_preprocess_observation(rng, observation, **kwargs): + """Mock function that returns observation unchanged to disable preprocessing.""" + return observation + + +def mock_preprocess_observation_pytorch(observation, **kwargs): + """Mock function that returns observation unchanged to disable preprocessing.""" + return observation + + +def test_pytorch_single_example(noise, time, model_name, pytorch_checkpoint_dir, load_pickle=None): + """Test PyTorch training on single example.""" + print("\n=== Testing PyTorch on Single Example ===") + + # Create model using the training config + train_config = openpi.training.config.get_config(model_name) + model = PI0Pytorch(train_config.model) # Use train_config.model instead of train_config + + # Load pre-trained weights + pytorch_checkpoint_dir = pytorch_checkpoint_dir + "/model.safetensors" + print(f"Loading PyTorch weights from: {pytorch_checkpoint_dir}") + + safetensors.torch.load_model(model, pytorch_checkpoint_dir) + + # Load data based on arguments + if load_pickle: + print(f"🎯 Loading sample from pickle file: {load_pickle}") + example = load_sample_from_pickle(load_pickle) + else: + print("🎯 Using fixed example for testing") + example = create_fixed_example(train_config.model) + + # Convert to PyTorch tensors + pytorch_example = {} + for key, value in example.items(): + if key == "image": 
+ # Convert channels-last [B, H, W, C] to channels-first [B, C, H, W] for PyTorch + pytorch_example[key] = {} + for k, v in value.items(): + # v is [B, H, W, C], convert to [B, C, H, W] + v_tensor = torch.from_numpy(v) + v_tensor = v_tensor.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + pytorch_example[key][k] = v_tensor + elif key == "image_mask": + pytorch_example[key] = {k: torch.from_numpy(v) for k, v in value.items()} + else: + pytorch_example[key] = torch.from_numpy(value) + + # Convert noise and time to PyTorch tensors + noise_tensor = torch.from_numpy(noise) + time_tensor = torch.from_numpy(time) + + # Create observation + observation_torch = _model.Observation.from_dict(pytorch_example) + actions_torch = pytorch_example["actions"] + + print(f"Observation state shape: {observation_torch.state.shape}") + print(f"Observation state dtype: {observation_torch.state.dtype}") + print(f"Actions shape: {actions_torch.shape}") + print(f"Actions dtype: {actions_torch.dtype}") + print(f"Noise shape: {noise_tensor.shape}, dtype: {noise_tensor.dtype}") + print(f"Time shape: {time_tensor.shape}, dtype: {time_tensor.dtype}") + + # Setup optimizer from config + print(f"Setting up optimizer from config: {type(train_config.optimizer).__name__}") + # Use the exact same optimizer creation code as in train_pytorch.py + warmup_steps = train_config.lr_schedule.warmup_steps + peak_lr = train_config.lr_schedule.peak_lr + decay_steps = train_config.lr_schedule.decay_steps + end_lr = train_config.lr_schedule.decay_lr + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=peak_lr, + betas=(train_config.optimizer.b1, train_config.optimizer.b2), + eps=train_config.optimizer.eps, + weight_decay=train_config.optimizer.weight_decay + ) + print(f"✅ Optimizer created: {type(optimizer).__name__}") + print(f" - Learning rate: {peak_lr}") + print(f" - Betas: ({train_config.optimizer.b1}, {train_config.optimizer.b2})") + print(f" - Epsilon: {train_config.optimizer.eps}") + print(f" - 
Weight decay: {train_config.optimizer.weight_decay}") + + # Define learning rate schedule function (same as in train_pytorch.py) + def lr_schedule(step: int): + if step < warmup_steps: + # Match JAX behavior: start from peak_lr / (warmup_steps + 1) + init_lr = peak_lr / (warmup_steps + 1) + return init_lr + (peak_lr - init_lr) * step / warmup_steps + # cosine decay + progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) + cos = 0.5 * (1 + np.cos(np.pi * progress)) + return end_lr + (peak_lr - end_lr) * cos + + # Test forward pass, backward pass, and another forward pass + model.train() # Set to training mode for backward pass + + try: + # Mock preprocess_observation_pytorch to avoid augmentations + with patch('openpi.models_pytorch.preprocessing_pytorch.preprocess_observation_pytorch', side_effect=mock_preprocess_observation_pytorch): + # First forward pass + print("🔄 First forward pass...") + losses_1 = model(observation_torch, actions_torch, noise=noise_tensor, time=time_tensor) + loss_1 = losses_1.mean() + print(f"First forward pass successful! 
Loss: {loss_1.item():.6f}") + + # Backward pass + print("🔄 Backward pass...") + loss_1.backward() + print("Backward pass successful!") + + # Check gradients + total_grad_norm = 0 + param_count = 0 + for param in model.parameters(): + if param.grad is not None: + total_grad_norm += param.grad.norm().item() ** 2 + param_count += 1 + total_grad_norm = total_grad_norm ** 0.5 + print(f"Gradient norm: {total_grad_norm:.6f} (from {param_count} parameters)") + + # Gradient clipping (same as in train_pytorch.py); clip_grad_norm_ returns the total norm measured before clipping + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=train_config.optimizer.clip_gradient_norm) + print(f" - Gradient clipping applied, total norm before clipping: {grad_norm:.6f}") + + # Optimizer step to update parameters + print("🔄 Optimizer step...") + # Update learning rate using the schedule (same as in train_pytorch.py) + current_lr = lr_schedule(0) # Use step 0 for this test + for pg in optimizer.param_groups: + pg["lr"] = current_lr + print(f" - Updated learning rate to: {current_lr:.2e}") + + optimizer.step() + print("Optimizer step successful!") + + # Clear gradients for next iteration + optimizer.zero_grad() + + # Second forward pass + print("🔄 Second forward pass...") + losses_2 = model(observation_torch, actions_torch, noise=noise_tensor, time=time_tensor) + loss_2 = losses_2.mean() + print(f"Second forward pass successful! 
Loss: {loss_2.item():.6f}") + + # Compare losses + loss_diff = abs(loss_1.item() - loss_2.item()) + print(f"Loss difference between passes: {loss_diff:.8f}") + + return True, losses_1, losses_2 + + except Exception as e: + print(f"PyTorch forward/backward pass failed: {e}") + return False, None, None + + +def test_jax_single_example(noise, time, model_name, jax_checkpoint_dir, load_pickle=None): + """Test JAX training on single example.""" + print("\n=== Testing JAX on Single Example ===") + + # Create model using the training config + train_config = openpi.training.config.get_config(model_name) + rng = jax.random.key(42) + model = train_config.model.create(rng) # Use train_config.model instead of config.create + + jax_checkpoint_dir = jax_checkpoint_dir + '/params/' + # Load pre-trained weights + print(f"Loading JAX weights from: {jax_checkpoint_dir}") + params = _model.restore_params(jax_checkpoint_dir, dtype=jnp.bfloat16) + params_to_use = params + print("✅ JAX weights loaded successfully!") + + # Apply the params to the model using NNX state management + import flax.nnx as nnx + graphdef, model_state = nnx.split(model) + + # Now try to load parameters + try: + model_state.replace_by_pure_dict(params_to_use) + model = nnx.merge(graphdef, model_state) + print("✅ Parameters loaded successfully!") + except Exception as e: + print(f"❌ Parameter loading failed: {e}") + print("🔄 Continuing with random initialization...") + model = nnx.merge(graphdef, model_state) + + # Load data based on arguments + if load_pickle: + print(f"🎯 Loading sample from pickle file: {load_pickle}") + example = load_sample_from_pickle(load_pickle) + else: + print("🎯 Using fixed example for testing") + example = create_fixed_example(train_config.model) + + # Convert to JAX arrays + jax_example = {} + for key, value in example.items(): + if key == "image": + jax_example[key] = {k: jnp.array(v) for k, v in value.items()} + elif key == "image_mask": + jax_example[key] = {k: jnp.array(v) for k, v 
in value.items()} + else: + jax_example[key] = jnp.array(value) + + # Convert noise and time to JAX arrays + noise_jax = jnp.array(noise) + time_jax = jnp.array(time) + + # Create observation + observation = _model.Observation.from_dict(jax_example) + actions = jax_example["actions"] + + print(f"Observation state shape: {observation.state.shape}") + print(f"Observation state dtype: {observation.state.dtype}") + print(f"Actions shape: {actions.shape}") + print(f"Actions dtype: {actions.dtype}") + print(f"Noise shape: {noise_jax.shape}, dtype: {noise_jax.dtype}") + print(f"Time shape: {time_jax.shape}, dtype: {time_jax.dtype}") + + # Get learning rate from config + lr = train_config.lr_schedule.peak_lr if hasattr(train_config.lr_schedule, 'peak_lr') else 1e-4 + print(f"Using learning rate from config: {lr}") + + # Create optimizer using the exact same code as in train.py + print("Setting up JAX optimizer from config...") + tx = openpi.training.optimizer.create_optimizer(train_config.optimizer, train_config.lr_schedule, weight_decay_mask=None) + print(f"✅ JAX optimizer created: {type(tx).__name__}") + + # Initialize optimizer state + # Get trainable parameters from the model using NNX + params = nnx.state(model) + trainable_params = params.filter(train_config.trainable_filter) + opt_state = tx.init(trainable_params) + print(f"✅ Optimizer state initialized") + + # Test forward pass, backward pass, and another forward pass + try: + # Use mock to disable preprocessing + with patch('openpi.models.model.preprocess_observation', side_effect=mock_preprocess_observation): + # JIT compile the compute_loss method for memory efficiency + print("🔄 JIT compiling compute_loss method...") + jitted_compute_loss = nnx_utils.module_jit(model.compute_loss) + + # First forward pass + print("🔄 First forward pass...") + losses_1 = jitted_compute_loss(rng, observation, actions, train=True, noise=noise_jax, time=time_jax) + loss_1 = losses_1.mean() + print(f"First forward pass successful! 
Loss: {loss_1.item():.6f}") + + # Use the same approach as in train.py for gradient computation and parameter updates + print("🔄 Computing gradients and updating parameters...") + + # Define loss function for gradient computation; it must go through the `model` argument so gradients reach its parameters + def loss_fn(model, rng, observation, actions): + chunked_loss = model.compute_loss(rng, observation, actions, train=True, noise=noise_jax, time=time_jax) + return jnp.mean(chunked_loss) + + # Filter out frozen params and compute gradients (same as train.py) + diff_state = nnx.DiffState(0, train_config.trainable_filter) + loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, rng, observation, actions) + + print("Gradients computed successfully!") + + # Check gradient norms using optax.global_norm + try: + grad_norm = optax.global_norm(grads) + print(f"Gradient norm: {grad_norm.item():.6f}") + except Exception as e: + print(f"Could not compute gradient norm: {e}") + print("Continuing without gradient norm...") + + # Update parameters using optimizer (same as train.py) + print("🔄 Updating parameters with optimizer...") + params = nnx.state(model).filter(train_config.trainable_filter) + updates, new_opt_state = tx.update(grads, opt_state, params) + new_params = optax.apply_updates(params, updates) + + # Update the model in place (same as train.py) + nnx.update(model, new_params) + opt_state = new_opt_state + print("Parameter update successful!") + + # Second forward pass (re-wrap with module_jit so the updated parameters are picked up) + print("🔄 Second forward pass...") + losses_2 = nnx_utils.module_jit(model.compute_loss)(rng, observation, actions, train=True, noise=noise_jax, time=time_jax) + loss_2 = losses_2.mean() + print(f"Second forward pass successful! 
Loss: {loss_2.item():.6f}") + + # Compare losses + loss_diff = abs(loss_1.item() - loss_2.item()) + print(f"Loss difference between passes: {loss_diff:.8f}") + + return True, losses_1, losses_2 + + except Exception as e: + print(f"JAX forward/backward pass failed: {e}") + import traceback + traceback.print_exc() + return False, None, None + + +def compare_losses(pytorch_loss_1, pytorch_loss_2, jax_loss_1, jax_loss_2): + """Compare losses and compute relative differences.""" + if pytorch_loss_1 is None or jax_loss_1 is None: + return + + print("\n" + "=" * 70) + print("📊 LOSS COMPARISON") + print("=" * 70) + + print(f"PyTorch first pass loss: {pytorch_loss_1}") + print(f"PyTorch second pass loss: {pytorch_loss_2}") + print(f"JAX first pass loss: {jax_loss_1}") + print(f"JAX second pass loss: {jax_loss_2}") + + # Additional tensor analysis if both are tensors + pytorch_loss_1 = pytorch_loss_1.to(torch.float32) + pytorch_loss_2 = pytorch_loss_2.to(torch.float32) + jax_loss_1 = jax_loss_1.astype(jnp.float32) + jax_loss_2 = jax_loss_2.astype(jnp.float32) + + if hasattr(pytorch_loss_1, 'shape') and hasattr(jax_loss_1, 'shape'): + print(f"\n📐 Tensor Analysis:") + + # Check if shapes match + if pytorch_loss_1.shape == jax_loss_1.shape: + print(f"✅ Tensor shapes match: {pytorch_loss_1.shape}") + + # Element-wise comparison + if hasattr(pytorch_loss_1, 'flatten') and hasattr(jax_loss_1, 'flatten'): + # Convert to numpy for element-wise analysis + try: + pytorch_flat_1 = pytorch_loss_1.detach().cpu().numpy().flatten() + pytorch_flat_2 = pytorch_loss_2.detach().cpu().numpy().flatten() + jax_flat_1 = jax_loss_1.flatten() + jax_flat_2 = jax_loss_2.flatten() + + # Element-wise differences between implementations + element_diff_1 = np.abs(pytorch_flat_1 - jax_flat_1) + element_diff_2 = np.abs(pytorch_flat_2 - jax_flat_2) + + max_element_diff_1 = np.max(element_diff_1) + mean_element_diff_1 = np.mean(element_diff_1) + max_element_diff_2 = np.max(element_diff_2) + 
mean_element_diff_2 = np.mean(element_diff_2) + + print(f" First pass - Max element-wise difference: {max_element_diff_1:.8f}") + print(f" First pass - Mean element-wise difference: {mean_element_diff_1:.8f}") + print(f" Second pass - Max element-wise difference: {max_element_diff_2:.8f}") + print(f" Second pass - Mean element-wise difference: {mean_element_diff_2:.8f}") + + # Element-wise relative differences + # Avoid division by zero by adding small epsilon + epsilon = 1e-12 + pytorch_flat_1_safe = pytorch_flat_1 + epsilon + jax_flat_1_safe = jax_flat_1 + epsilon + pytorch_flat_2_safe = pytorch_flat_2 + epsilon + jax_flat_2_safe = jax_flat_2 + epsilon + + # Compute relative differences for each element + rel_diff_pytorch_1 = (element_diff_1 / np.abs(pytorch_flat_1_safe)) * 100 + rel_diff_jax_1 = (element_diff_1 / np.abs(jax_flat_1_safe)) * 100 + rel_diff_pytorch_2 = (element_diff_2 / np.abs(pytorch_flat_2_safe)) * 100 + rel_diff_jax_2 = (element_diff_2 / np.abs(jax_flat_2_safe)) * 100 + + # Compute mean of relative differences + mean_rel_diff_pytorch_1 = np.mean(rel_diff_pytorch_1) + mean_rel_diff_jax_1 = np.mean(rel_diff_jax_1) + mean_rel_diff_pytorch_2 = np.mean(rel_diff_pytorch_2) + mean_rel_diff_jax_2 = np.mean(rel_diff_jax_2) + + print(f" First pass - Mean relative difference (w.r.t. PyTorch): {mean_rel_diff_pytorch_1:.4f}%") + print(f" First pass - Mean relative difference (w.r.t. JAX): {mean_rel_diff_jax_1:.4f}%") + print(f" Second pass - Mean relative difference (w.r.t. PyTorch): {mean_rel_diff_pytorch_2:.4f}%") + print(f" Second pass - Mean relative difference (w.r.t. 
JAX): {mean_rel_diff_jax_2:.4f}%") + + # Count elements with significant differences + significant_threshold = 1e-4 + significant_count_1 = np.sum(element_diff_1 > significant_threshold) + significant_count_2 = np.sum(element_diff_2 > significant_threshold) + total_elements = len(element_diff_1) + significant_percentage_1 = (significant_count_1 / total_elements) * 100 + significant_percentage_2 = (significant_count_2 / total_elements) * 100 + + print(f" First pass - Elements with diff > {significant_threshold}: {significant_count_1}/{total_elements} ({significant_percentage_1:.2f}%)") + print(f" Second pass - Elements with diff > {significant_threshold}: {significant_count_2}/{total_elements} ({significant_percentage_2:.2f}%)") + + except Exception as e: + print(f" ⚠️ Could not perform element-wise analysis: {e}") + else: + print(f"❌ Tensor shapes don't match: PyTorch {pytorch_loss_1.shape} vs JAX {jax_loss_1.shape}") + + +def main(): + """Main function to test both implementations.""" + parser = argparse.ArgumentParser(description="Train on a single example for JAX vs PyTorch comparison") + parser.add_argument("--model_name", type=str, default="pi05_libero", + choices=["pi0_aloha_sim", "pi0_aloha_towel", "pi0_base", "pi05_droid", "pi0_droid", "pi0_libero", "pi05_libero"], + help="Model name to use") + parser.add_argument("--jax_checkpoint_dir", type=str, required=True, + help="Directory containing JAX model checkpoints") + parser.add_argument("--pytorch_checkpoint_dir", type=str, required=True, + help="Directory containing PyTorch model checkpoints") + parser.add_argument("--load_pickle", type=str, default=None, + help="Load sample from pickle file (fastest option)") + args = parser.parse_args() + + setup_logging() + + print("🚀 Testing Single Example Training for JAX vs PyTorch Comparison") + print("=" * 70) + print(f"📁 Model: {args.model_name}") + print(f"📁 JAX checkpoint: {args.jax_checkpoint_dir}") + print(f"📁 PyTorch checkpoint: {args.pytorch_checkpoint_dir}") 
+ + print("🚫 Preprocessing disabled: Image augmentations and resizing are bypassed for fair comparison...") + + # Get model configuration + train_config = openpi.training.config.get_config(args.model_name) + model_config = train_config.model + + # Generate fixed noise and time + noise, time = create_fixed_noise_and_time( + batch_size=1, + action_horizon=model_config.action_horizon, + action_dim=model_config.action_dim + ) + + # Test JAX + jax_success, jax_losses_1, jax_losses_2 = test_jax_single_example( + noise, time, args.model_name, args.jax_checkpoint_dir, args.load_pickle + ) + + # Clear JAX memory + jax.clear_caches() + + # Test PyTorch + pytorch_success, pytorch_losses_1, pytorch_losses_2 = test_pytorch_single_example( + noise, time, args.model_name, args.pytorch_checkpoint_dir, args.load_pickle + ) + torch.cuda.empty_cache() + + + + # Compare losses + if pytorch_success and jax_success: + compare_losses(pytorch_losses_1, pytorch_losses_2, jax_losses_1, jax_losses_2) + + # Summary + print("\n" + "=" * 70) + print("📊 SUMMARY") + print("=" * 70) + + if pytorch_success and jax_success: + print("✅ Both JAX and PyTorch implementations work on the single example!") + print("✅ Forward pass, backward pass, and second forward pass completed successfully!") + print("🔍 Loss comparison completed above.") + elif pytorch_success: + print("❌ PyTorch works but JAX failed. Check JAX implementation.") + elif jax_success: + print("❌ JAX works but PyTorch failed. Check PyTorch implementation.") + else: + print("❌ Both implementations failed. Check the error messages above.") + + print("\n💡 Next steps:") + print("1. Run this script to verify both implementations work") + print("2. Analyze the loss comparison results above") + print("3. Check if the backward pass gradients are reasonable") + print("4. If losses differ significantly, investigate the differences") + print("5. Check if the noise and time handling is consistent between implementations") + print("6. 
Use the same example in full training runs") + print("7. Note: Preprocessing (image augmentations) is disabled for this comparison") + + +if __name__ == "__main__": + main() diff --git a/src/openpi/models/model.py b/src/openpi/models/model.py index f6ee554..e5b879b 100644 --- a/src/openpi/models/model.py +++ b/src/openpi/models/model.py @@ -4,7 +4,7 @@ import enum import logging import pathlib -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Union import augmax from flax import nnx @@ -24,7 +24,8 @@ logger = logging.getLogger("openpi") -ArrayT = TypeVar("ArrayT", at.Array, jax.ShapeDtypeStruct) +# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays) +ArrayT = TypeVar("ArrayT", bound=Union[jax.Array, torch.Tensor, np.ndarray]) class ModelType(enum.Enum): diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 6c6de37..e6d7f37 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -2,12 +2,14 @@ import torch from torch import nn from transformers import GemmaForCausalLM, PaliGemmaForConditionalGeneration +from transformers.models.gemma import modeling_gemma from transformers.models.auto import CONFIG_MAPPING +from typing import Literal class PaliGemmaWithExpertModel(nn.Module): - def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): + def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False], precision: Literal["bfloat16", "float32"] = "bfloat16"): super().__init__() vlm_config_hf = CONFIG_MAPPING["paligemma"]() @@ -47,10 +49,16 @@ def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False]): self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) self.gemma_expert.model.embed_tokens = None - self.to_bfloat16_for_selected_params() + self.to_bfloat16_for_selected_params(precision) - def to_bfloat16_for_selected_params(self): - self = 
self.to(dtype=torch.bfloat16) + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): + if precision == "bfloat16": + self = self.to(dtype=torch.bfloat16) + elif precision == "float32": + self = self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") params_to_keep_float32 = [ "vision_tower.vision_model.embeddings.patch_embedding.weight", @@ -78,9 +86,9 @@ def forward( past_key_values: list[torch.FloatTensor] | Cache | None = None, inputs_embeds: list[torch.FloatTensor] = None, use_cache: bool | None = None, - adarms_cond: list[torch.Tensor] | None = None, + adarms_cond: list[torch.Tensor] = [None, None], ): - if inputs_embeds[0] is not None: + if inputs_embeds[1] is None: prefix_output = self.paligemma.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, @@ -91,11 +99,8 @@ def forward( ) prefix_past_key_values = prefix_output.past_key_values prefix_output = prefix_output.last_hidden_state - else: - prefix_output = None - prefix_past_key_values = None - - if inputs_embeds[1] is not None: + suffix_output = None + elif inputs_embeds[0] is None: suffix_output = self.gemma_expert.model.forward( inputs_embeds=inputs_embeds[1], attention_mask=attention_mask, @@ -105,7 +110,145 @@ def forward( adarms_cond=adarms_cond[1] if adarms_cond is not None else None, ) suffix_output = suffix_output.last_hidden_state + prefix_output = None + prefix_past_key_values = None else: - suffix_output = None + models = [self.paligemma.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers + + # Check if gradient checkpointing is enabled for any of the models + use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, 'gradient_checkpointing') and + self.gemma_expert.model.gradient_checkpointing and + self.training + ) or ( + hasattr(self, 'gradient_checkpointing') and + self.gradient_checkpointing and + 
self.training + ) + + # Force enable gradient checkpointing if we're in training mode and the model supports it + if self.training and hasattr(self.gemma_expert.model, 'gradient_checkpointing'): + if not self.gemma_expert.model.gradient_checkpointing: + print("Forcing gradient checkpointing to be enabled for Gemma expert model") + self.gemma_expert.model.gradient_checkpointing = True + use_gradient_checkpointing = True + + # Debug gradient checkpointing status + if hasattr(self, '_debug_gc_printed') and not self._debug_gc_printed: + print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}") + print(f"Model training mode: {self.training}") + print(f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}") + if hasattr(self.gemma_expert.model, 'gradient_checkpointing'): + print(f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}") + self._debug_gc_printed = True + + # Define the complete layer computation function for gradient checkpointing + def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond): + models = [self.paligemma.language_model, self.gemma_expert.model] + + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) + gates.append(gate) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # 
Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + + dummy_tensor = torch.zeros(query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype) + cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1) + + batch_size = query_states.shape[0] + scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling + + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + self.paligemma.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, attention_mask, scaling + ) + # Get head_dim from the current layer, not from the model + head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) + + # Process layer outputs + outputs_embeds = [] + start_pos = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end_pos = start_pos + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) + + # first residual + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) + after_first_residual = out_emb.clone() + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + if layer.mlp.up_proj.weight.dtype == torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) + + out_emb = layer.mlp(out_emb) + # second residual + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) + 
outputs_embeds.append(out_emb) + start_pos = end_pos + + return outputs_embeds + + # Process all layers with gradient checkpointing if enabled + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_layer_complete, + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, + use_reentrant=False, + preserve_rng_state=False + ) + else: + inputs_embeds = compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond) + + # Old code removed - now using compute_layer_complete function above + + # final norm + # Define final norm computation function for gradient checkpointing + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, adarms_cond, + use_reentrant=False, + preserve_rng_state=False + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None return [prefix_output, suffix_output], prefix_past_key_values \ No newline at end of file diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 9a6c0ea..dad7aae 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -1,4 +1,5 @@ import math +import logging import torch from torch import Tensor @@ -7,6 +8,7 @@ import openpi.models.gemma as _gemma from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel +import openpi.models_pytorch.preprocessing_pytorch as _preprocessing def 
get_safe_dtype(target_dtype, device_type): @@ -42,9 +44,11 @@ def create_sinusoidal_pos_embedding( def sample_beta(alpha, beta, bsize, device): - gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) - gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) - return gamma1 / (gamma1 + gamma2) + alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) + beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + dist = torch.distributions.Beta(alpha_t, beta_t) + samples = dist.sample((bsize,)) + return samples def make_att_2d_masks(pad_masks, att_masks): @@ -89,7 +93,7 @@ def __init__(self, config): paligemma_config = _gemma.get_config(config.paligemma_variant) action_expert_config = _gemma.get_config(config.action_expert_variant) - self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_config, action_expert_config, use_adarms=[False, True] if self.pi05 else [False, False]) + self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_config, action_expert_config, use_adarms=[False, True] if self.pi05 else [False, False], precision=config.dtype) self.action_in_proj = nn.Linear(32, action_expert_config.width) self.action_out_proj = nn.Linear(action_expert_config.width, 32) @@ -104,6 +108,62 @@ def __init__(self, config): torch.set_float32_matmul_precision('high') self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + try: + from transformers.models.siglip import check + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError("TransformersReplace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`.") + except ImportError: + raise ValueError("TransformersReplace is not installed correctly. 
Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`.") + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + + logging.info("Enabled gradient checkpointing for PI0Pytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + + logging.info("Disabled gradient checkpointing for PI0Pytorch model") + + def is_gradient_checkpointing_enabled(self): + """Check if gradient checkpointing is enabled.""" + return self.gradient_checkpointing_enabled + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + else: + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) + + def _preprocess_observation(self, observation, train=True): + """Helper method to preprocess observation.""" + observation = 
_preprocessing.preprocess_observation_pytorch(observation, train=train) + return ( + list(observation.images.values()), + list(observation.image_masks.values()), + observation.tokenized_prompt, + observation.tokenized_prompt_mask, + observation.state + ) def sample_noise(self, shape, device): noise = torch.normal( @@ -130,11 +190,12 @@ def embed_prefix( pad_masks = [] att_masks = [] - for ( - img, - img_mask, - ) in zip(images, img_masks): - img_emb = self.paligemma_with_expert.embed_image(img) + # Process images + for img, img_mask in zip(images, img_masks): + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) bsize, num_img_embs = img_emb.shape[:2] img_mask = img_mask[:, None].expand(bsize, num_img_embs) @@ -145,10 +206,13 @@ def embed_prefix( # Create attention masks so that image tokens attend to each other att_masks += [0] * num_img_embs - lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) - - lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) + # Process language tokens + def lang_embed_func(lang_tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) embs.append(lang_emb) pad_masks.append(lang_masks) @@ -163,7 +227,6 @@ def embed_prefix( # Get batch size from the first dimension of the concatenated tensors bsize = pad_masks.shape[0] - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) return embs, pad_masks, att_masks @@ -176,7 +239,11 @@ def embed_suffix(self, state, noisy_actions, timestep): if not self.pi05: # Embed state - state_emb = self.state_proj(state) + def state_proj_func(state): + return self.state_proj(state) + + state_emb = self._apply_checkpoint(state_proj_func, state) + embs.append(state_emb[:, None, :]) bsize = 
state_emb.shape[0] dtype = state_emb.dtype @@ -195,21 +262,34 @@ def embed_suffix(self, state, noisy_actions, timestep): time_emb = time_emb.type(dtype=timestep.dtype) # Fuse timestep + action information using an MLP - action_emb = self.action_in_proj(noisy_actions) + def action_proj_func(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) if not self.pi05: time_emb = time_emb[:, None, :].expand_as(action_emb) action_time_emb = torch.cat([action_emb, time_emb], dim=2) - action_time_emb = self.action_time_mlp_in(action_time_emb) - action_time_emb = F.silu(action_time_emb) # swish == silu - action_time_emb = self.action_time_mlp_out(action_time_emb) + + # Apply MLP layers + def mlp_func(action_time_emb): + x = self.action_time_mlp_in(action_time_emb) + x = F.silu(x) # swish == silu + x = self.action_time_mlp_out(x) + return x + + action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) adarms_cond = None else: # time MLP (for adaRMS) - time_emb = self.time_mlp_in(time_emb) - time_emb = F.silu(time_emb) # swish == silu - time_emb = self.time_mlp_out(time_emb) - time_emb = F.silu(time_emb) + def time_mlp_func(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) # swish == silu + x = self.time_mlp_out(x) + x = F.silu(x) + return x + + time_emb = self._apply_checkpoint(time_mlp_func, time_emb) action_time_emb = action_emb adarms_cond = time_emb @@ -230,8 +310,10 @@ def embed_suffix(self, state, noisy_actions, timestep): return embs, pad_masks, att_masks, adarms_cond - def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None) -> Tensor: + def forward(self, observation, actions, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True) + if noise is None: 
noise = self.sample_noise(actions.shape, actions.device) @@ -244,6 +326,9 @@ def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, no prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) + if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) @@ -251,22 +336,33 @@ def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, no att_2d_masks = make_att_2d_masks(pad_masks, att_masks) position_ids = torch.cumsum(pad_masks, dim=1) - 1 - # Add head dimension to attention mask: [B, seq_len, seq_len] -> [B, 1, seq_len, seq_len] - att_2d_masks_4d = att_2d_masks[:, None, :, :] - - (_, suffix_out), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_masks_4d, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - fill_kv_cache=False, - adarms_cond=[None, adarms_cond] + # Prepare attention masks + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) + + # Apply gradient checkpointing if enabled + def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond] + ) + return suffix_out + + suffix_out = self._apply_checkpoint( + forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond ) + suffix_out = suffix_out[:, 
-self.config.action_horizon :] suffix_out = suffix_out.to(dtype=torch.float32) - v_t = self.action_out_proj(suffix_out) + # Apply gradient checkpointing to final action projection if enabled + def action_out_proj_func(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) losses = F.mse_loss(u_t, v_t, reduction="none") return losses @@ -279,20 +375,14 @@ def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tenso actions_shape = (bsize, self.config.action_horizon, self.config.action_dim) noise = self.sample_noise(actions_shape, device) - images = list(observation.images.values()) - img_masks = list(observation.image_masks.values()) - lang_tokens = observation.tokenized_prompt - lang_masks = observation.tokenized_prompt_mask - state = observation.state + images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 # Compute image and language key value cache - # Add head dimension to attention mask: [B, seq_len, seq_len] -> [B, 1, seq_len, seq_len] - prefix_att_2d_masks_4d = prefix_att_2d_masks[:, None, :, :] - prefix_att_2d_masks_4d = torch.where(prefix_att_2d_masks_4d, 0.0, -2.3819763e38) + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" _, past_key_values = self.paligemma_with_expert.forward( @@ -347,9 +437,8 @@ def denoise_step( prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - # Add head dimension to attention mask: [B, seq_len, seq_len] -> [B, 1, seq_len, seq_len] - 
full_att_2d_masks_4d = full_att_2d_masks[:, None, :, :] - full_att_2d_masks_4d = torch.where(full_att_2d_masks_4d, 0.0, -2.3819763e38) + # Prepare attention masks + full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" outputs_embeds, _ = self.paligemma_with_expert.forward( diff --git a/src/openpi/models_pytorch/preprocessing_pytorch.py b/src/openpi/models_pytorch/preprocessing_pytorch.py new file mode 100644 index 0000000..d67d24b --- /dev/null +++ b/src/openpi/models_pytorch/preprocessing_pytorch.py @@ -0,0 +1,171 @@ +import logging +from collections.abc import Sequence +import torch + +from openpi.shared import image_tools + +logger = logging.getLogger("openpi") + +# Constants moved from model.py +IMAGE_KEYS = ( + "base_0_rgb", + "left_wrist_0_rgb", + "right_wrist_0_rgb", +) + +IMAGE_RESOLUTION = (224, 224) + +def preprocess_observation_pytorch( + observation, + *, + train: bool = False, + image_keys: Sequence[str] = IMAGE_KEYS, + image_resolution: tuple[int, int] = IMAGE_RESOLUTION, +): + """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations. + + This function avoids complex type annotations that can cause torch.compile issues. 
+ """ + if not set(image_keys).issubset(observation.images): + raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") + + batch_shape = observation.state.shape[:-1] + + out_images = {} + for key in image_keys: + image = observation.images[key] + + # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats + # Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + image = image.permute(0, 2, 3, 1) + + if image.shape[1:3] != image_resolution: + logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") + image = image_tools.resize_with_pad_torch(image, *image_resolution) + + if train: + # Convert from [-1, 1] to [0, 1] for PyTorch augmentations + image = image / 2.0 + 0.5 + + # Apply PyTorch-based augmentations + if "wrist" not in key: + # Geometric augmentations for non-wrist cameras + height, width = image.shape[1:3] + + # Random crop and resize + crop_height = int(height * 0.95) + crop_width = int(width * 0.95) + + # Random crop + max_h = height - crop_height + max_w = width - crop_width + if max_h > 0 and max_w > 0: + # Use tensor operations instead of .item() for torch.compile compatibility + start_h = torch.randint(0, max_h + 1, (1,), device=image.device) + start_w = torch.randint(0, max_w + 1, (1,), device=image.device) + image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :] + + # Resize back to original size + image = torch.nn.functional.interpolate( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + size=(height, width), + mode='bilinear', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Random rotation (small angles) + # Use tensor operations instead of .item() for torch.compile compatibility + angle = torch.rand(1, device=image.device) * 10 - 
5 # Random angle between -5 and 5 degrees + if torch.abs(angle) > 0.1: # Only rotate if angle is significant + # Convert to radians + angle_rad = angle * torch.pi / 180.0 + + # Create rotation matrix + cos_a = torch.cos(angle_rad) + sin_a = torch.sin(angle_rad) + + # Apply rotation using grid_sample + grid_x = torch.linspace(-1, 1, width, device=image.device) + grid_y = torch.linspace(-1, 1, height, device=image.device) + + # Create meshgrid + grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij') + + # Expand to batch dimension + grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) + grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) + + # Apply rotation transformation + grid_x_rot = grid_x * cos_a - grid_y * sin_a + grid_y_rot = grid_x * sin_a + grid_y * cos_a + + # Stack and reshape for grid_sample + grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) + + image = torch.nn.functional.grid_sample( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + grid, + mode='bilinear', + padding_mode='zeros', + align_corners=False + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Color augmentations for all cameras + # Random brightness + # Use tensor operations instead of .item() for torch.compile compatibility + brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3 + image = image * brightness_factor + + # Random contrast + # Use tensor operations instead of .item() for torch.compile compatibility + contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4 + mean = image.mean(dim=[1, 2, 3], keepdim=True) + image = (image - mean) * contrast_factor + mean + + # Random saturation (convert to HSV, modify S, convert back) + # For simplicity, we'll just apply a random scaling to the color channels + # Use tensor operations instead of .item() for torch.compile compatibility + saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # 
Random factor between 0.5 and 1.5 + gray = image.mean(dim=-1, keepdim=True) + image = gray + (image - gray) * saturation_factor + + # Clamp values to [0, 1] + image = torch.clamp(image, 0, 1) + + # Back to [-1, 1] + image = image * 2.0 - 1.0 + + # Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + out_images[key] = image + + # obtain mask + out_masks = {} + for key in out_images: + if key not in observation.image_masks: + # do not mask by default + out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) + else: + out_masks[key] = observation.image_masks[key] + + # Create a simple object with the required attributes instead of using the complex Observation class + class SimpleProcessedObservation: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + return SimpleProcessedObservation( + images=out_images, + image_masks=out_masks, + state=observation.state, + tokenized_prompt=observation.tokenized_prompt, + tokenized_prompt_mask=observation.tokenized_prompt_mask, + token_ar_mask=observation.token_ar_mask, + token_loss_mask=observation.token_loss_mask, + ) diff --git a/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py b/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py index 3ad389a..d052977 100644 --- a/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py @@ -457,6 +457,10 @@ def forward( adarms_cond: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: + """ + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. 
+ """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -498,7 +502,9 @@ def forward( # embed positions hidden_states = inputs_embeds - hidden_states = hidden_states.to(torch.bfloat16) + # Convert to bfloat16 if the first layer uses bfloat16 + if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.bfloat16) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -609,6 +615,9 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. + Example: ```python @@ -713,6 +722,9 @@ def forward( Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. """ transformer_outputs: BaseModelOutputWithPast = self.model( @@ -809,6 +821,9 @@ def forward( Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. 
         """
 
         outputs: BaseModelOutputWithPast = self.model(
diff --git a/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py b/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py
new file mode 100644
index 0000000..4bb3c96
--- /dev/null
+++ b/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py
@@ -0,0 +1,4 @@
+import transformers
+
+def check_whether_transformers_replace_is_installed_correctly():
+    return transformers.__version__ == "4.53.2"
\ No newline at end of file
diff --git a/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py b/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py
index d7dc902..3ea8acd 100644
--- a/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py
+++ b/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py
@@ -773,7 +773,9 @@ def forward(
         )
 
         hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
-        hidden_states = hidden_states.to(torch.bfloat16)
+        # Convert to bfloat16 if the encoder uses bfloat16
+        if len(self.encoder.layers) > 0 and self.encoder.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(torch.bfloat16)
 
         encoder_outputs: BaseModelOutput = self.encoder(
             inputs_embeds=hidden_states,
diff --git a/src/openpi/policies/policy_config.py b/src/openpi/policies/policy_config.py
index f0601f4..3e8a2b7 100644
--- a/src/openpi/policies/policy_config.py
+++ b/src/openpi/policies/policy_config.py
@@ -52,6 +52,7 @@ def create_trained_policy(
     logging.info("Loading model...")
     if is_pytorch:
         model = train_config.model.load_pytorch(train_config, weight_path)
+        model.paligemma_with_expert.to_bfloat16_for_selected_params('bfloat16')
     else:
         model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
     data_config = train_config.data.create(train_config.assets_dirs,
train_config.model) diff --git a/src/openpi/shared/image_tools.py b/src/openpi/shared/image_tools.py index 4d63e1c..78f0e17 100644 --- a/src/openpi/shared/image_tools.py +++ b/src/openpi/shared/image_tools.py @@ -2,6 +2,8 @@ import jax import jax.numpy as jnp +import torch +import torch.nn.functional as F import openpi.shared.array_typing as at @@ -48,3 +50,80 @@ def resize_with_pad( if not has_batch_dim: padded_images = padded_images[0] return padded_images + + +def resize_with_pad_torch( + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. + + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) + + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + # Convert to channels-first for torch operations + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = 
torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode='constant', + value=constant_value + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + if batch_size == 1 and images.shape[0] == 1: + padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added + + return padded_images \ No newline at end of file diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index a4caa3c..64635a7 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -6,7 +6,7 @@ import difflib import logging import pathlib -from typing import Any, Protocol, TypeAlias +from typing import Any, Protocol, TypeAlias, Literal import etils.epath as epath import flax.nnx as nnx @@ -379,7 +379,7 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig ) data_transforms = _transforms.Group( - inputs=[droid_policy.DroidInputs(action_dim=model_config.action_dim, model_type=model_config.model_type)], + inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)], outputs=[droid_policy.DroidOutputs()], ) @@ -461,6 +461,12 @@ class TrainConfig: # A weight loader can optionally load (possibly partial) weights from disk after the model is initialized. 
weight_loader: weight_loaders.WeightLoader = dataclasses.field(default_factory=weight_loaders.NoOpWeightLoader) + # Optional path to a PyTorch checkpoint to load weights from. + pytorch_weight_path: str | None = None + + # Precision for PyTorch training. + pytorch_training_precision: Literal["bfloat16", "float32"] = "bfloat16" + lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule) optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW) ema_decay: float | None = 0.99 @@ -724,7 +730,7 @@ def __post_init__(self) -> None: base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), - batch_size=256, + batch_size=16, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, @@ -734,8 +740,10 @@ def __post_init__(self) -> None: optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, weight_loader=weight_loaders.CheckpointWeightLoader( - "gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/params" + "/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base/params" ), + pytorch_weight_path="/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch_float32", + pytorch_training_precision="float32", num_train_steps=30_000, ), # @@ -838,28 +846,74 @@ def __post_init__(self) -> None: num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally ), TrainConfig( - # This config is for fine-tuning pi05-DROID on a custom (smaller) DROID dataset. - # Here, we use LeRobot data format (like for all other fine-tuning examples) - # To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py - name="pi05_droid_finetune", + # This config is for fine-tuning pi05 on the *full* DROID dataset. + # We use RLDS data loading to make training on this large dataset tractable. 
+ # For fine-tuning on your own DROID dataset, see below. + name="pi05_full_droid_finetune", model=pi0_config.Pi0Config( pi05=True, - action_dim=32, # pi05 is trained with 32-dim actions + action_dim=32, action_horizon=16, ), - data=LeRobotDROIDDataConfig( - # Replace with your custom DROID LeRobot dataset repo id. - repo_id="your_hf_username/my_droid_dataset", - base_config=DataConfig(prompt_from_task=True), + data=RLDSDroidDataConfig( + repo_id="droid", + # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). + rlds_data_dir="/mnt/pi-data/kevin", + action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, assets=AssetsConfig( - # Important: reuse the original DROID norm stats during fine-tuning! - assets_dir="gs://openpi-assets-preview/checkpoints/pi05_droid/assets", + assets_dir="gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/assets/", asset_id="droid", ), ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets-preview/checkpoints/pi05_droid/params"), - num_train_steps=20_000, - batch_size=32, + weight_loader=weight_loaders.CheckpointWeightLoader( + "gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/params" + ), + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=1_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + num_train_steps=100_000, + batch_size=256, + log_interval=100, + save_interval=5000, + keep_period=10_000, + num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally + ), + TrainConfig( + # This config is for fine-tuning pi05 on the *full* DROID dataset. + # We use RLDS data loading to make training on this large dataset tractable. + # For fine-tuning on your own DROID dataset, see below. 
+        name="pi05_full_droid_finetune_pytorch",
+        model=pi0_config.Pi0Config(
+            pi05=True,
+            action_dim=32,
+            action_horizon=16,
+        ),
+        data=RLDSDroidDataConfig(
+            repo_id="droid",
+            # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).
+            rlds_data_dir="/mnt/pi-data/kevin",
+            action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
+            assets=AssetsConfig(
+                assets_dir="gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/assets/",
+                asset_id="droid",
+            ),
+        ),
+        weight_loader='/home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_base_pytorch2',
+        lr_schedule=_optimizer.CosineDecaySchedule(
+            warmup_steps=1_000,
+            peak_lr=5e-5,
+            decay_steps=1_000_000,
+            decay_lr=5e-5,
+        ),
+        num_train_steps=100_000,
+        batch_size=256,
+        log_interval=100,
+        save_interval=5000,
+        keep_period=10_000,
+        num_workers=0,  # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally
     ),
     #
     # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
@@ -932,4 +986,4 @@ def get_config(config_name: str) -> TrainConfig:
         closest_str = f" Did you mean '{closest[0]}'? " if closest else ""
         raise ValueError(f"Config '{config_name}' not found.{closest_str}")

-    return _CONFIGS_DICT[config_name]
\ No newline at end of file
+    return _CONFIGS_DICT[config_name]
diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py
index 1f97f27..6c1e916 100644
--- a/src/openpi/training/data_loader.py
+++ b/src/openpi/training/data_loader.py
@@ -1,4 +1,5 @@
 from collections.abc import Iterator, Sequence
+from typing import Literal
 import multiprocessing
 import os
 import typing
@@ -7,6 +8,7 @@
 import jax
 import jax.numpy as jnp
 import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
+import logging
 import numpy as np
 import torch

@@ -225,9 +227,20 @@ def create_data_loader(
     shuffle: bool = False,
     num_batches: int | None = None,
     skip_norm_stats: bool = False,
+    framework: Literal["jax", "pytorch"],
 ) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
-    """Create a data loader for training."""
+    """Create a data loader for training.
+
+    Args:
+        config: The training configuration.
+        sharding: The sharding to use for the data loader (JAX only).
+        shuffle: Whether to shuffle the data.
+        num_batches: Determines the number of batches to return.
+        skip_norm_stats: Whether to skip data normalization.
+        framework: The framework to use ("jax" or "pytorch").
+    """
     data_config = config.data.create(config.assets_dirs, config.model)
+    logging.info(f"data_config: {data_config}")

     if data_config.rlds_data_dir is not None:
         return create_rlds_data_loader(
@@ -238,6 +251,7 @@ def create_data_loader(
             shuffle=shuffle,
             num_batches=num_batches,
             skip_norm_stats=skip_norm_stats,
+            framework=framework,
         )
     return create_torch_data_loader(
         data_config,
@@ -250,6 +264,7 @@ def create_data_loader(
         num_workers=config.num_workers,
         seed=config.seed,
         skip_norm_stats=skip_norm_stats,
+        framework=framework,
     )


@@ -265,6 +280,7 @@ def create_torch_data_loader(
     num_batches: int | None = None,
     num_workers: int = 0,
     seed: int = 0,
+    framework: str = "jax",
 ) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
     """Create a data loader for training.

@@ -286,14 +302,35 @@ def create_torch_data_loader(
     dataset = create_torch_dataset(data_config, action_horizon, model_config)
     dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)

+    # Use TorchDataLoader for both frameworks
+    # For PyTorch DDP, create DistributedSampler and divide batch size by world size
+    # For JAX, divide by process count
+    sampler = None
+    if framework == "pytorch":
+        if torch.distributed.is_initialized():
+            sampler = torch.utils.data.distributed.DistributedSampler(
+                dataset,
+                num_replicas=torch.distributed.get_world_size(),
+                rank=torch.distributed.get_rank(),
+                shuffle=shuffle,
+                drop_last=True,
+            )
+            local_batch_size = batch_size // torch.distributed.get_world_size()
+        else:
+            local_batch_size = batch_size
+    else:
+        local_batch_size = batch_size // jax.process_count()
+
     data_loader = TorchDataLoader(
         dataset,
-        local_batch_size=batch_size // jax.process_count(),
-        sharding=sharding,
-        shuffle=shuffle,
+        local_batch_size=local_batch_size,
+        sharding=None if framework == "pytorch" else sharding,
+        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
+        sampler=sampler,
         num_batches=num_batches,
         num_workers=num_workers,
         seed=seed,
+        framework=framework,
     )

     return DataLoaderImpl(data_config, data_loader)
@@ -308,6 +345,7 @@ def create_rlds_data_loader(
     skip_norm_stats: bool = False,
     shuffle: bool = False,
     num_batches: int | None = None,
+    framework: str = "jax",
 ) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
     """Create an RLDS data loader for training.

@@ -325,19 +363,24 @@ def create_rlds_data_loader(
             number of batches in the dataset, the data loader will loop over the dataset.
             If not provided, will iterate over the dataset indefinitely.
     """
-    dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
-    dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
-
-    data_loader = RLDSDataLoader(
-        dataset,
-        sharding=sharding,
-        num_batches=num_batches,
-    )
+    if framework == "pytorch":
+        raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
+    else:
+        dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
+        dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
+
+        data_loader = RLDSDataLoader(
+            dataset,
+            sharding=sharding,
+            num_batches=num_batches,
+        )

     return DataLoaderImpl(data_config, data_loader)


 class TorchDataLoader:
+    """Torch data loader implementation."""
+
     def __init__(
         self,
         dataset,
@@ -345,9 +388,11 @@ def __init__(
         *,
         sharding: jax.sharding.Sharding | None = None,
         shuffle: bool = False,
+        sampler: torch.utils.data.Sampler | None = None,
         num_batches: int | None = None,
         num_workers: int = 0,
         seed: int = 0,
+        framework: str = "jax",
     ):
         """Create a PyTorch data loader.

@@ -370,14 +415,14 @@ def __init__(
         if len(dataset) < local_batch_size:
             raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")

-        if sharding is None:
-            # Use data parallel sharding by default.
-            sharding = jax.sharding.NamedSharding(
+        # Store sharding - None for PyTorch, JAX sharding for JAX
+        self._sharding = sharding
+        if sharding is None and framework == "jax":
+            # Use data parallel sharding by default for JAX only.
+            self._sharding = jax.sharding.NamedSharding(
                 jax.sharding.Mesh(jax.devices(), ("B",)),
                 jax.sharding.PartitionSpec("B"),
             )
-
-        self._sharding = sharding
         self._num_batches = num_batches

         mp_context = None
@@ -389,7 +434,8 @@ def __init__(
         self._data_loader = torch.utils.data.DataLoader(
             typing.cast(torch.utils.data.Dataset, dataset),
             batch_size=local_batch_size,
-            shuffle=shuffle,
+            shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
+            sampler=sampler,
             num_workers=num_workers,
             multiprocessing_context=mp_context,
             persistent_workers=num_workers > 0,
@@ -415,14 +461,18 @@ def __iter__(self):
                 except StopIteration:
                     break  # We've exhausted the dataset. Create a new iterator and start over.
                 num_items += 1
-                yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
+                # For JAX, convert to sharded arrays; for PyTorch, return torch tensors
+                if self._sharding is not None:
+                    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
+                else:
+                    yield jax.tree.map(torch.as_tensor, batch)


 def _collate_fn(items):
     """Collate the batch elements into batched numpy arrays."""
     # Make sure to convert to numpy arrays before stacking since some of the incoming elements
     # may be JAX arrays.
- return jax.tree.map(lambda *x: np.stack(np.asarray(x), axis=0), *items) + return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items) def _worker_init_fn(worker_id: int) -> None: diff --git a/uv.lock b/uv.lock index 1ad8e94..13bd8fc 100644 --- a/uv.lock +++ b/uv.lock @@ -1771,24 +1771,33 @@ wheels = [ [package.optional-dependencies] with-cuda = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cublas-cu12", version = "12.9.0.13", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-cupti-cu12", version = "12.9.19", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = 
"nvidia-cuda-nvcc-cu12" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-runtime-cu12", version = "12.9.37", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, { name = "nvidia-cudnn-cu12", version = "9.10.1.4", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cudnn-cu12", version = "9.10.2.21", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform 
!= 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cufft-cu12", version = "11.3.3.83", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cufft-cu12", version = "11.4.0.6", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cusolver-cu12", version = "11.7.3.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cusolver-cu12", version = "11.7.4.40", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", version = "12.5.9.5", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 
'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-nccl-cu12", version = "2.26.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nccl-cu12", version = "2.26.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, { name = "nvidia-nccl-cu12", version = "2.26.5", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nccl-cu12", version = "2.27.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", version = "12.9.41", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, ] @@ -2608,22 +2617,31 @@ name = "nvidia-cublas-cu12" version = "12.6.4.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", 
"python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb", size = 393138322 }, - { url = "https://files.pythonhosted.org/packages/97/0d/f1f0cadbf69d5b9ef2e4f744c9466cb0a850741d08350736dfdb4aa89569/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668", size = 390794615 }, { url = "https://files.pythonhosted.org/packages/84/f7/985e9bdbe3e0ac9298fcc8cfa51a392862a46a0ffaccbbd56939b62a9c83/nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8", size = 434535301 }, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124 }, + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921 }, +] + [[package]] name = "nvidia-cublas-cu12" version = "12.9.0.13" @@ -2646,24 +2664,31 @@ name = "nvidia-cuda-cupti-cu12" version = "12.6.80" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/8b/2f6230cb715646c3a9425636e513227ce5c93c4d65823a734f4bb86d43c3/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:166ee35a3ff1587f2490364f90eeeb8da06cd867bd5b701bf7f9a02b78bc63fc", size = 8236764 }, - { url = 
"https://files.pythonhosted.org/packages/25/0f/acb326ac8fd26e13c799e0b4f3b2751543e1834f04d62e729485872198d4/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.whl", hash = "sha256:358b4a1d35370353d52e12f0a7d1769fc01ff74a191689d3870b2123156184c4", size = 8236756 }, - { url = "https://files.pythonhosted.org/packages/49/60/7b6497946d74bcf1de852a21824d63baad12cd417db4195fc1bfe59db953/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6768bad6cab4f19e8292125e5f1ac8aa7d1718704012a0e3272a6f61c4bce132", size = 8917980 }, - { url = "https://files.pythonhosted.org/packages/a5/24/120ee57b218d9952c379d1e026c4479c9ece9997a4fb46303611ee48f038/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73", size = 8917972 }, { url = "https://files.pythonhosted.org/packages/1c/81/7796f096afaf726796b1b648f3bc80cafc61fe7f77f44a483c89e6c5ef34/nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a", size = 5724175 }, ] +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318 }, + { url = 
"https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621 }, +] + [[package]] name = "nvidia-cuda-cupti-cu12" version = "12.9.19" @@ -2693,10 +2718,10 @@ wheels = [ [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.6.77" +version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/75/2e/46030320b5a80661e88039f59060d1790298b4718944a65a7f2aeda3d9e9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53", size = 23650380 }, + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029 }, ] [[package]] @@ -2704,24 +2729,31 @@ name = "nvidia-cuda-runtime-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", 
"python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/ea/590b2ac00d772a8abd1c387a92b46486d2679ca6622fd25c18ff76265663/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6116fad3e049e04791c0256a9778c16237837c08b27ed8c8401e2e45de8d60cd", size = 908052 }, - { url = "https://files.pythonhosted.org/packages/b7/3d/159023799677126e20c8fd580cca09eeb28d5c5a624adc7f793b9aa8bbfa/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d461264ecb429c84c8879a7153499ddc7b19b5f8d84c204307491989a365588e", size = 908040 }, - { url = "https://files.pythonhosted.org/packages/e1/23/e717c5ac26d26cf39a27fbc076240fad2e3b817e5889d671b67f4f9f49c5/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ba3b56a4f896141e25e19ab287cd71e52a6a0f4b29d0d31609f60e3b4d5219b7", size = 897690 }, - { url = "https://files.pythonhosted.org/packages/f0/62/65c05e161eeddbafeca24dc461f47de550d9fa8a7e04eb213e32b55cfd99/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8", size = 897678 }, { url = "https://files.pythonhosted.org/packages/fa/76/4c80fa138333cc975743fd0687a745fccb30d167f906f13c1c7f9a85e5ea/nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f", size = 891773 }, ] +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 
'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265 }, + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765 }, +] + [[package]] name = "nvidia-cuda-runtime-cu12" version = "12.9.37" @@ -2744,22 +2776,17 @@ name = "nvidia-cudnn-cu12" version = "9.5.1.17" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform 
== 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/99/93/a201a12d3ec1caa8c6ac34c1c2f9eeb696b886f0c36ff23c638b46603bd0/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9fd4584468533c61873e5fda8ca41bac3a38bcb2d12350830c69b0a96a7e4def", size = 570523509 }, - { url = "https://files.pythonhosted.org/packages/2a/78/4535c9c7f859a64781e43c969a3a7e84c54634e319a996d43ef32ce46f83/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2", size = 570988386 }, { url = "https://files.pythonhosted.org/packages/b6/b2/3f60d15f037fa5419d9d7f788b100ef33ea913ae5315c87ca6d6fa606c35/nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8", size = 565440743 }, ] @@ -2783,32 +2810,59 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/ec/79464a7371a028d1f443b8516b55cb2f70bb91bd3b2f2a831d707c003ccf/nvidia_cudnn_cu12-9.10.1.4-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:df73c4dab84df2c54f0a40e6427cde26e8d80feeffef02d749ee42d7da3c8204", size = 706752133 }, ] +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = "https://pypi.org/simple" }, marker = 
"platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878 }, + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467 }, +] + [[package]] name = "nvidia-cufft-cu12" version = "11.3.0.4" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = 
"https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/1f/37/c50d2b2f2c07e146776389e3080f4faf70bcc4fa6e19d65bb54ca174ebc3/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d16079550df460376455cba121db6564089176d9bac9e4f360493ca4741b22a6", size = 200164144 }, - { url = "https://files.pythonhosted.org/packages/ce/f5/188566814b7339e893f8d210d3a5332352b1409815908dad6a363dcceac1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8510990de9f96c803a051822618d42bf6cb8f069ff3f48d93a8486efdacb48fb", size = 200164135 }, - { url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632 }, - { url = "https://files.pythonhosted.org/packages/60/de/99ec247a07ea40c969d904fc14f3a356b3e2a704121675b75c366b694ee1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca", size = 200221622 }, { url = "https://files.pythonhosted.org/packages/b4/38/36fd800cec8f6e89b7c1576edaaf8076e69ec631644cdbc1b5f2e2b5a9df/nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464", size = 199356881 }, ] +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] 
+dependencies = [ + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211 }, + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695 }, +] + [[package]] name = "nvidia-cufft-cu12" version = "11.4.0.6" @@ -2831,19 +2885,18 @@ wheels = [ [[package]] name = "nvidia-cufile-cu12" -version = "1.11.1.6" +version = "1.13.1.3" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/66/cc9876340ac68ae71b15c743ddb13f8b30d5244af344ec8322b449e35426/nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159", size = 1142103 }, + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834 }, ] [[package]] name = "nvidia-curand-cu12" -version = "10.3.7.77" +version = "10.3.9.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/73/1b/44a01c4e70933637c93e6e1a8063d1e998b50213a6b65ac5a9169c47e98e/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf", size = 56279010 }, - { url = "https://files.pythonhosted.org/packages/4a/aa/2c7ff0b5ee02eaef890c0ce7d4f74bc30901871c5e45dee1ae6d0083cd80/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:99f1a32f1ac2bd134897fc7a203f779303261268a65762a623bf30cc9fe79117", size = 56279000 }, + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976 }, ] [[package]] @@ -2851,29 +2904,41 @@ name = "nvidia-cusolver-cu12" version = "11.7.1.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = 
"https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/93/17/dbe1aa865e4fdc7b6d4d0dd308fdd5aaab60f939abfc0ea1954eac4fb113/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0ce237ef60acde1efc457335a2ddadfd7610b892d94efee7b776c64bb1cac9e0", size = 157833628 }, - { url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790 }, - { url = "https://files.pythonhosted.org/packages/9f/81/baba53585da791d043c10084cf9553e074548408e04ae884cfe9193bd484/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl", hash = 
"sha256:6cf28f17f64107a0c4d7802be5ff5537b2130bfc112f25d5a30df227058ca0e6", size = 158229780 }, - { url = "https://files.pythonhosted.org/packages/7c/5f/07d0ba3b7f19be5a5ec32a8679fc9384cfd9fc6c869825e93be9f28d6690/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e", size = 157833630 }, { url = "https://files.pythonhosted.org/packages/d4/53/fff50a0808df7113d77e3bbc7c2b7eaed6f57d5eb80fbe93ead2aea1e09a/nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7", size = 149287877 }, ] +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841 }, + { url = 
"https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905 }, +] + [[package]] name = "nvidia-cusolver-cu12" version = "11.7.4.40" @@ -2901,27 +2966,37 @@ name = "nvidia-cusparse-cu12" version = "12.5.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] wheels = [ - { url = 
"https://files.pythonhosted.org/packages/eb/eb/6681efd0aa7df96b4f8067b3ce7246833dd36830bb4cec8896182773db7d/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d25b62fb18751758fe3c93a4a08eff08effedfe4edf1c6bb5afd0890fe88f887", size = 216451147 }, - { url = "https://files.pythonhosted.org/packages/d3/56/3af21e43014eb40134dea004e8d0f1ef19d9596a39e4d497d5a7de01669f/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7aa32fa5470cf754f72d1116c7cbc300b4e638d3ae5304cfa4a638a5b87161b1", size = 216451135 }, - { url = "https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367 }, - { url = "https://files.pythonhosted.org/packages/43/ac/64c4316ba163e8217a99680c7605f779accffc6a4bcd0c778c12948d3707/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f", size = 216561357 }, { url = "https://files.pythonhosted.org/packages/45/ef/876ad8e4260e1128e6d4aac803d9d51baf3791ebdb4a9b8d9b8db032b4b0/nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20", size = 213712630 }, ] +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = 
"https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129 }, + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466 }, +] + [[package]] name = "nvidia-cusparse-cu12" version = "12.5.9.5" @@ -2944,10 +3019,10 @@ wheels = [ [[package]] name = "nvidia-cusparselt-cu12" -version = "0.6.3" +version = "0.7.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796 }, + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691 }, ] [[package]] @@ -2964,20 +3039,13 @@ name = "nvidia-nccl-cu12" version = "2.26.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - 
"python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] -wheels = [ - { url = "https://files.pythonhosted.org/packages/69/5b/ca2f213f637305633814ae8c36b153220e40a07ea001966dcd87391f3acb/nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c196e95e832ad30fbbb50381eb3cbd1fadd5675e587a548563993609af19522", size = 291671495 }, - { url = "https://files.pythonhosted.org/packages/67/ca/f42388aed0fddd64ade7493dbba36e1f534d4e6fdbdd355c6a90030ae028/nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6", size = 201319755 }, -] [[package]] name = "nvidia-nccl-cu12" @@ -2996,27 +3064,50 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/fb/ec4ac065d9b0d56f72eaf1d9b0df601e33da28197b32ca351dc05b342611/nvidia_nccl_cu12-2.26.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea5ed3e053c735f16809bee7111deac62ac35b10128a8c102960a0462ce16cbe", size = 318069637 }, ] +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", 
+] +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/7b/8354b784cf73b0ba51e566b4baba3ddd44fe8288a3d39ef1e06cd5417226/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9ddf1a245abc36c550870f26d537a9b6087fb2e2e3d6e0ef03374c6fd19d984f", size = 322397768 }, + { url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134 }, +] + [[package]] name = "nvidia-nvjitlink-cu12" version = "12.6.85" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux'", "python_full_version < '3.12' and sys_platform == 'emscripten'", ] wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971 }, - { url = 
"https://files.pythonhosted.org/packages/31/db/dc71113d441f208cdfe7ae10d4983884e13f464a6252450693365e166dcf/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41", size = 19270338 }, { url = "https://files.pythonhosted.org/packages/89/76/93c1467b1387387440a4d25102d86b7794535449b689f8e2dc22c1c8ff7f/nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c", size = 161908572 }, ] +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836 }, + { url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204 }, +] + [[package]] name = "nvidia-nvjitlink-cu12" version = "12.9.41" @@ -3036,11 +3127,10 @@ wheels = [ [[package]] name = "nvidia-nvtx-cu12" -version = "12.6.77" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b90bed3df379fa79afbd21be8e04a0314336b8ae16768b58f2d34cb1d04cd7d2", size = 89276 }, - { url = "https://files.pythonhosted.org/packages/9e/4e/0d0c945463719429b7bd21dece907ad0bde437a2ff12b9b12fee94722ab0/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1", size = 89265 }, + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954 }, ] [[package]] @@ -3180,7 +3270,7 @@ requires-dist = [ { name = "polars", specifier = ">=1.30.0" }, { name = "rich", specifier = ">=14.0.0" }, { name = "sentencepiece", specifier = ">=0.2.0" }, - { name = "torch", specifier = ">=2.7.0" }, + { name = "torch", specifier = ">=2.7.1" }, { name = "tqdm-loggable", specifier = ">=0.2" }, { name = "transformers", specifier = "==4.53.2" }, { name = "treescope", specifier = ">=0.1.7" }, @@ -4852,26 +4942,26 @@ wheels = [ [[package]] name = "torch" -version = "2.7.0" +version = "2.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", version = "12.8.4.1", source = { registry = 
"https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", version = "12.8.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", version = "9.10.2.21", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", version = "11.3.3.83", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and 
sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", version = "11.7.3.90", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", version = "12.5.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", version = "2.26.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", version = "2.27.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.8.93", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, @@ -4879,22 +4969,22 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/40/da/7378d16cc636697f2a94f791cb496939b60fb8580ddbbef22367db2c2274/torch-2.7.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2b7813e904757b125faf1a9a3154e1d50381d539ced34da1992f52440567c156", size = 99159397 }, - { url = "https://files.pythonhosted.org/packages/0e/6b/87fcddd34df9f53880fa1f0c23af7b6b96c935856473faf3914323588c40/torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:fd5cfbb4c3bbadd57ad1b27d56a28008f8d8753733411a140fcfb84d7f933a25", size 
= 865183681 }, - { url = "https://files.pythonhosted.org/packages/13/85/6c1092d4b06c3db1ed23d4106488750917156af0b24ab0a2d9951830b0e9/torch-2.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:58df8d5c2eeb81305760282b5069ea4442791a6bbf0c74d9069b7b3304ff8a37", size = 212520100 }, - { url = "https://files.pythonhosted.org/packages/aa/3f/85b56f7e2abcfa558c5fbf7b11eb02d78a4a63e6aeee2bbae3bb552abea5/torch-2.7.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:0a8d43caa342b9986101ec5feb5bbf1d86570b5caa01e9cb426378311258fdde", size = 68569377 }, - { url = "https://files.pythonhosted.org/packages/aa/5e/ac759f4c0ab7c01feffa777bd68b43d2ac61560a9770eeac074b450f81d4/torch-2.7.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:36a6368c7ace41ad1c0f69f18056020b6a5ca47bedaca9a2f3b578f5a104c26c", size = 99013250 }, - { url = "https://files.pythonhosted.org/packages/9c/58/2d245b6f1ef61cf11dfc4aceeaacbb40fea706ccebac3f863890c720ab73/torch-2.7.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:15aab3e31c16feb12ae0a88dba3434a458874636f360c567caa6a91f6bfba481", size = 865042157 }, - { url = "https://files.pythonhosted.org/packages/44/80/b353c024e6b624cd9ce1d66dcb9d24e0294680f95b369f19280e241a0159/torch-2.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f56d4b2510934e072bab3ab8987e00e60e1262fb238176168f5e0c43a1320c6d", size = 212482262 }, - { url = "https://files.pythonhosted.org/packages/ee/8d/b2939e5254be932db1a34b2bd099070c509e8887e0c5a90c498a917e4032/torch-2.7.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:30b7688a87239a7de83f269333651d8e582afffce6f591fff08c046f7787296e", size = 68574294 }, - { url = "https://files.pythonhosted.org/packages/14/24/720ea9a66c29151b315ea6ba6f404650834af57a26b2a04af23ec246b2d5/torch-2.7.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:868ccdc11798535b5727509480cd1d86d74220cfdc42842c4617338c1109a205", size = 99015553 }, - { url = 
"https://files.pythonhosted.org/packages/4b/27/285a8cf12bd7cd71f9f211a968516b07dcffed3ef0be585c6e823675ab91/torch-2.7.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b52347118116cf3dff2ab5a3c3dd97c719eb924ac658ca2a7335652076df708", size = 865046389 }, - { url = "https://files.pythonhosted.org/packages/74/c8/2ab2b6eadc45554af8768ae99668c5a8a8552e2012c7238ded7e9e4395e1/torch-2.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:434cf3b378340efc87c758f250e884f34460624c0523fe5c9b518d205c91dd1b", size = 212490304 }, - { url = "https://files.pythonhosted.org/packages/28/fd/74ba6fde80e2b9eef4237fe668ffae302c76f0e4221759949a632ca13afa/torch-2.7.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:edad98dddd82220465b106506bb91ee5ce32bd075cddbcf2b443dfaa2cbd83bf", size = 68856166 }, - { url = "https://files.pythonhosted.org/packages/cb/b4/8df3f9fe6bdf59e56a0e538592c308d18638eb5f5dc4b08d02abb173c9f0/torch-2.7.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a885fc25afefb6e6eb18a7d1e8bfa01cc153e92271d980a49243b250d5ab6d9", size = 99091348 }, - { url = "https://files.pythonhosted.org/packages/9d/f5/0bd30e9da04c3036614aa1b935a9f7e505a9e4f1f731b15e165faf8a4c74/torch-2.7.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:176300ff5bc11a5f5b0784e40bde9e10a35c4ae9609beed96b4aeb46a27f5fae", size = 865104023 }, - { url = "https://files.pythonhosted.org/packages/d1/b7/2235d0c3012c596df1c8d39a3f4afc1ee1b6e318d469eda4c8bb68566448/torch-2.7.0-cp313-cp313t-win_amd64.whl", hash = "sha256:d0ca446a93f474985d81dc866fcc8dccefb9460a29a456f79d99c29a78a66993", size = 212750916 }, - { url = "https://files.pythonhosted.org/packages/90/48/7e6477cf40d48cc0a61fa0d41ee9582b9a316b12772fcac17bc1a40178e7/torch-2.7.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:27f5007bdf45f7bb7af7f11d1828d5c2487e030690afb3d89a651fd7036a390e", size = 68575074 }, + { url = 
"https://files.pythonhosted.org/packages/8f/c4/3e7a3887eba14e815e614db70b3b529112d1513d9dae6f4d43e373360b7f/torch-2.8.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:220a06fd7af8b653c35d359dfe1aaf32f65aa85befa342629f716acb134b9710", size = 102073391 }, + { url = "https://files.pythonhosted.org/packages/5a/63/4fdc45a0304536e75a5e1b1bbfb1b56dd0e2743c48ee83ca729f7ce44162/torch-2.8.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c12fa219f51a933d5f80eeb3a7a5d0cbe9168c0a14bbb4055f1979431660879b", size = 888063640 }, + { url = "https://files.pythonhosted.org/packages/84/57/2f64161769610cf6b1c5ed782bd8a780e18a3c9d48931319f2887fa9d0b1/torch-2.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:8c7ef765e27551b2fbfc0f41bcf270e1292d9bf79f8e0724848b1682be6e80aa", size = 241366752 }, + { url = "https://files.pythonhosted.org/packages/a4/5e/05a5c46085d9b97e928f3f037081d3d2b87fb4b4195030fc099aaec5effc/torch-2.8.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:5ae0524688fb6707c57a530c2325e13bb0090b745ba7b4a2cd6a3ce262572916", size = 73621174 }, + { url = "https://files.pythonhosted.org/packages/49/0c/2fd4df0d83a495bb5e54dca4474c4ec5f9c62db185421563deeb5dabf609/torch-2.8.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e2fab4153768d433f8ed9279c8133a114a034a61e77a3a104dcdf54388838705", size = 101906089 }, + { url = "https://files.pythonhosted.org/packages/99/a8/6acf48d48838fb8fe480597d98a0668c2beb02ee4755cc136de92a0a956f/torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2aca0939fb7e4d842561febbd4ffda67a8e958ff725c1c27e244e85e982173c", size = 887913624 }, + { url = "https://files.pythonhosted.org/packages/af/8a/5c87f08e3abd825c7dfecef5a0f1d9aa5df5dd0e3fd1fa2f490a8e512402/torch-2.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:2f4ac52f0130275d7517b03a33d2493bab3693c83dcfadf4f81688ea82147d2e", size = 241326087 }, + { url = 
"https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478 }, + { url = "https://files.pythonhosted.org/packages/10/4e/469ced5a0603245d6a19a556e9053300033f9c5baccf43a3d25ba73e189e/torch-2.8.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b2f96814e0345f5a5aed9bf9734efa913678ed19caf6dc2cddb7930672d6128", size = 101936856 }, + { url = "https://files.pythonhosted.org/packages/16/82/3948e54c01b2109238357c6f86242e6ecbf0c63a1af46906772902f82057/torch-2.8.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:65616ca8ec6f43245e1f5f296603e33923f4c30f93d65e103d9e50c25b35150b", size = 887922844 }, + { url = "https://files.pythonhosted.org/packages/e3/54/941ea0a860f2717d86a811adf0c2cd01b3983bdd460d0803053c4e0b8649/torch-2.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:659df54119ae03e83a800addc125856effda88b016dfc54d9f65215c3975be16", size = 241330968 }, + { url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128 }, + { url = "https://files.pythonhosted.org/packages/15/0e/8a800e093b7f7430dbaefa80075aee9158ec22e4c4fc3c1a66e4fb96cb4f/torch-2.8.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:83c13411a26fac3d101fe8035a6b0476ae606deb8688e904e796a3534c197def", size = 102020139 }, + { url = "https://files.pythonhosted.org/packages/4a/15/5e488ca0bc6162c86a33b58642bc577c84ded17c7b72d97e49b5833e2d73/torch-2.8.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:8f0a9d617a66509ded240add3754e462430a6c1fc5589f86c17b433dd808f97a", size = 887990692 }, + { url = 
"https://files.pythonhosted.org/packages/b4/a8/6a04e4b54472fc5dba7ca2341ab219e529f3c07b6941059fbf18dccac31f/torch-2.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a7242b86f42be98ac674b88a4988643b9bc6145437ec8f048fea23f72feb5eca", size = 241603453 }, + { url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395 }, ] [[package]] @@ -4914,7 +5004,7 @@ wheels = [ [[package]] name = "torchvision" -version = "0.22.0" +version = "0.23.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, @@ -4922,22 +5012,22 @@ dependencies = [ { name = "torch" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/43/28bc858b022f6337326d75f4027d2073aad5432328f01ee1236d847f1b82/torchvision-0.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:191ea28321fc262d8aa1a7fe79c41ff2848864bf382f9f6ea45c41dde8313792", size = 1947828 }, - { url = "https://files.pythonhosted.org/packages/7e/71/ce9a303b94e64fe25d534593522ffc76848c4e64c11e4cbe9f6b8d537210/torchvision-0.22.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6c5620e10ffe388eb6f4744962106ed7cf1508d26e6fdfa0c10522d3249aea24", size = 2514016 }, - { url = "https://files.pythonhosted.org/packages/09/42/6908bff012a1dcc4fc515e52339652d7f488e208986542765c02ea775c2f/torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ce292701c77c64dd3935e3e31c722c3b8b176a75f76dc09b804342efc1db5494", size = 7447546 }, - { url = "https://files.pythonhosted.org/packages/e4/cf/8f9305cc0ea26badbbb3558ecae54c04a245429f03168f7fad502f8a5b25/torchvision-0.22.0-cp311-cp311-win_amd64.whl", hash = "sha256:e4017b5685dbab4250df58084f07d95e677b2f3ed6c2e507a1afb8eb23b580ca", size = 1716472 }, - { url = 
"https://files.pythonhosted.org/packages/cb/ea/887d1d61cf4431a46280972de665f350af1898ce5006cd046326e5d0a2f2/torchvision-0.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:31c3165418fe21c3d81fe3459e51077c2f948801b8933ed18169f54652796a0f", size = 1947826 }, - { url = "https://files.pythonhosted.org/packages/72/ef/21f8b6122e13ae045b8e49658029c695fd774cd21083b3fa5c3f9c5d3e35/torchvision-0.22.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8f116bc82e0c076e70ba7776e611ed392b9666aa443662e687808b08993d26af", size = 2514571 }, - { url = "https://files.pythonhosted.org/packages/7c/48/5f7617f6c60d135f86277c53f9d5682dfa4e66f4697f505f1530e8b69fb1/torchvision-0.22.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ce4dc334ebd508de2c534817c9388e928bc2500cf981906ae8d6e2ca3bf4727a", size = 7446522 }, - { url = "https://files.pythonhosted.org/packages/99/94/a015e93955f5d3a68689cc7c385a3cfcd2d62b84655d18b61f32fb04eb67/torchvision-0.22.0-cp312-cp312-win_amd64.whl", hash = "sha256:24b8c9255c209ca419cc7174906da2791c8b557b75c23496663ec7d73b55bebf", size = 1716664 }, - { url = "https://files.pythonhosted.org/packages/e1/2a/9b34685599dcb341d12fc2730055155623db7a619d2415a8d31f17050952/torchvision-0.22.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ece17995857dd328485c9c027c0b20ffc52db232e30c84ff6c95ab77201112c5", size = 1947823 }, - { url = "https://files.pythonhosted.org/packages/77/77/88f64879483d66daf84f1d1c4d5c31ebb08e640411139042a258d5f7dbfe/torchvision-0.22.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:471c6dd75bb984c6ebe4f60322894a290bf3d4b195e769d80754f3689cd7f238", size = 2471592 }, - { url = "https://files.pythonhosted.org/packages/f7/82/2f813eaae7c1fae1f9d9e7829578f5a91f39ef48d6c1c588a8900533dd3d/torchvision-0.22.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:2b839ac0610a38f56bef115ee5b9eaca5f9c2da3c3569a68cc62dbcc179c157f", size = 7446333 }, - { url = 
"https://files.pythonhosted.org/packages/58/19/ca7a4f8907a56351dfe6ae0a708f4e6b3569b5c61d282e3e7f61cf42a4ce/torchvision-0.22.0-cp313-cp313-win_amd64.whl", hash = "sha256:4ada1c08b2f761443cd65b7c7b4aec9e2fc28f75b0d4e1b1ebc9d3953ebccc4d", size = 1716693 }, - { url = "https://files.pythonhosted.org/packages/6f/a7/f43e9c8d13118b4ffbaebea664c9338ab20fa115a908125afd2238ff16e7/torchvision-0.22.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:cdc96daa4658b47ce9384154c86ed1e70cba9d972a19f5de6e33f8f94a626790", size = 2137621 }, - { url = "https://files.pythonhosted.org/packages/6a/9a/2b59f5758ba7e3f23bc84e16947493bbce97392ec6d18efba7bdf0a3b10e/torchvision-0.22.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:753d3c84eeadd5979a33b3b73a25ecd0aa4af44d6b45ed2c70d44f5e0ac68312", size = 2476555 }, - { url = "https://files.pythonhosted.org/packages/7d/40/a7bc2ab9b1e56d10a7fd9ae83191bb425fa308caa23d148f1c568006e02c/torchvision-0.22.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:b30e3ed29e4a61f7499bca50f57d8ebd23dfc52b14608efa17a534a55ee59a03", size = 7617924 }, - { url = "https://files.pythonhosted.org/packages/c1/7b/30d423bdb2546250d719d7821aaf9058cc093d165565b245b159c788a9dd/torchvision-0.22.0-cp313-cp313t-win_amd64.whl", hash = "sha256:e5d680162694fac4c8a374954e261ddfb4eb0ce103287b0f693e4e9c579ef957", size = 1638621 }, + { url = "https://files.pythonhosted.org/packages/f0/d7/15d3d7bd8d0239211b21673d1bac7bc345a4ad904a8e25bb3fd8a9cf1fbc/torchvision-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:49aa20e21f0c2bd458c71d7b449776cbd5f16693dd5807195a820612b8a229b7", size = 1856884 }, + { url = "https://files.pythonhosted.org/packages/dd/14/7b44fe766b7d11e064c539d92a172fa9689a53b69029e24f2f1f51e7dc56/torchvision-0.23.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:01dc33ee24c79148aee7cdbcf34ae8a3c9da1674a591e781577b716d233b1fa6", size = 2395543 }, + { url = 
"https://files.pythonhosted.org/packages/79/9c/fcb09aff941c8147d9e6aa6c8f67412a05622b0c750bcf796be4c85a58d4/torchvision-0.23.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:35c27941831b653f5101edfe62c03d196c13f32139310519e8228f35eae0e96a", size = 8628388 }, + { url = "https://files.pythonhosted.org/packages/93/40/3415d890eb357b25a8e0a215d32365a88ecc75a283f75c4e919024b22d97/torchvision-0.23.0-cp311-cp311-win_amd64.whl", hash = "sha256:09bfde260e7963a15b80c9e442faa9f021c7e7f877ac0a36ca6561b367185013", size = 1600741 }, + { url = "https://files.pythonhosted.org/packages/df/1d/0ea0b34bde92a86d42620f29baa6dcbb5c2fc85990316df5cb8f7abb8ea2/torchvision-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e0e2c04a91403e8dd3af9756c6a024a1d9c0ed9c0d592a8314ded8f4fe30d440", size = 1856885 }, + { url = "https://files.pythonhosted.org/packages/e2/00/2f6454decc0cd67158c7890364e446aad4b91797087a57a78e72e1a8f8bc/torchvision-0.23.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6dd7c4d329a0e03157803031bc856220c6155ef08c26d4f5bbac938acecf0948", size = 2396614 }, + { url = "https://files.pythonhosted.org/packages/e4/b5/3e580dcbc16f39a324f3dd71b90edbf02a42548ad44d2b4893cc92b1194b/torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4e7d31c43bc7cbecbb1a5652ac0106b436aa66e26437585fc2c4b2cf04d6014c", size = 8627108 }, + { url = "https://files.pythonhosted.org/packages/82/c1/c2fe6d61e110a8d0de2f94276899a2324a8f1e6aee559eb6b4629ab27466/torchvision-0.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:a2e45272abe7b8bf0d06c405e78521b5757be1bd0ed7e5cd78120f7fdd4cbf35", size = 1600723 }, + { url = "https://files.pythonhosted.org/packages/91/37/45a5b9407a7900f71d61b2b2f62db4b7c632debca397f205fdcacb502780/torchvision-0.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1c37e325e09a184b730c3ef51424f383ec5745378dc0eca244520aca29722600", size = 1856886 }, + { url = 
"https://files.pythonhosted.org/packages/ac/da/a06c60fc84fc849377cf035d3b3e9a1c896d52dbad493b963c0f1cdd74d0/torchvision-0.23.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2f7fd6c15f3697e80627b77934f77705f3bc0e98278b989b2655de01f6903e1d", size = 2353112 }, + { url = "https://files.pythonhosted.org/packages/a0/27/5ce65ba5c9d3b7d2ccdd79892ab86a2f87ac2ca6638f04bb0280321f1a9c/torchvision-0.23.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:a76fafe113b2977be3a21bf78f115438c1f88631d7a87203acb3dd6ae55889e6", size = 8627658 }, + { url = "https://files.pythonhosted.org/packages/1f/e4/028a27b60aa578a2fa99d9d7334ff1871bb17008693ea055a2fdee96da0d/torchvision-0.23.0-cp313-cp313-win_amd64.whl", hash = "sha256:07d069cb29691ff566e3b7f11f20d91044f079e1dbdc9d72e0655899a9b06938", size = 1600749 }, + { url = "https://files.pythonhosted.org/packages/05/35/72f91ad9ac7c19a849dedf083d347dc1123f0adeb401f53974f84f1d04c8/torchvision-0.23.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:2df618e1143805a7673aaf82cb5720dd9112d4e771983156aaf2ffff692eebf9", size = 2047192 }, + { url = "https://files.pythonhosted.org/packages/1d/9d/406cea60a9eb9882145bcd62a184ee61e823e8e1d550cdc3c3ea866a9445/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a3299d2b1d5a7aed2d3b6ffb69c672ca8830671967eb1cee1497bacd82fe47b", size = 2359295 }, + { url = "https://files.pythonhosted.org/packages/2b/f4/34662f71a70fa1e59de99772142f22257ca750de05ccb400b8d2e3809c1d/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:76bc4c0b63d5114aa81281390f8472a12a6a35ce9906e67ea6044e5af4cab60c", size = 8800474 }, + { url = "https://files.pythonhosted.org/packages/6e/f5/b5a2d841a8d228b5dbda6d524704408e19e7ca6b7bb0f24490e081da1fa1/torchvision-0.23.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b9e2dabf0da9c8aa9ea241afb63a8f3e98489e706b22ac3f30416a1be377153b", size = 1527667 }, ] [[package]] @@ -5039,16 +5129,16 @@ wheels = [ [[package]] name = "triton" -version = 
"3.3.0" +version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "setuptools", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/c5/4874a81131cc9e934d88377fbc9d24319ae1fb540f3333b4e9c696ebc607/triton-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3161a2bf073d6b22c4e2f33f951f3e5e3001462b2570e6df9cd57565bdec2984", size = 156528461 }, - { url = "https://files.pythonhosted.org/packages/11/53/ce18470914ab6cfbec9384ee565d23c4d1c55f0548160b1c7b33000b11fd/triton-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b68c778f6c4218403a6bd01be7484f6dc9e20fe2083d22dd8aef33e3b87a10a3", size = 156504509 }, - { url = "https://files.pythonhosted.org/packages/7d/74/4bf2702b65e93accaa20397b74da46fb7a0356452c1bb94dbabaf0582930/triton-3.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47bc87ad66fa4ef17968299acacecaab71ce40a238890acc6ad197c3abe2b8f1", size = 156516468 }, - { url = "https://files.pythonhosted.org/packages/0a/93/f28a696fa750b9b608baa236f8225dd3290e5aff27433b06143adc025961/triton-3.3.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce4700fc14032af1e049005ae94ba908e71cd6c2df682239aed08e49bc71b742", size = 156580729 }, + { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138 }, + { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068 }, + { url = 
"https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223 }, + { url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780 }, ] [[package]]