Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
44bc4b2
support finetuning
yaolug Aug 22, 2025
468e5f2
add preprocess
yaolug Aug 22, 2025
eb6ebb1
fixes
yaolug Aug 22, 2025
e0d76fb
fix resume
yaolug Aug 22, 2025
642bb4f
add gradient checkpointing
yaolug Aug 22, 2025
3bd6bde
fix gradient checkpointing
yaolug Aug 22, 2025
5789f0d
batch size 16 working
yaolug Aug 25, 2025
bcce0ed
bs 1024 (2 nodes) working
yaolug Aug 26, 2025
f554bc8
support finetuning
yaolug Aug 22, 2025
b9b7f6a
add preprocess
yaolug Aug 22, 2025
d2aba1c
fixes
yaolug Aug 22, 2025
6fa3f8d
fix resume
yaolug Aug 22, 2025
3c9084a
add gradient checkpointing
yaolug Aug 22, 2025
e09ee98
fix gradient checkpointing
yaolug Aug 22, 2025
615149e
batch size 16 working
yaolug Aug 25, 2025
4fc7766
bs 1024 (2 nodes) working
yaolug Aug 26, 2025
25d8af2
fix merge error
yaolug Aug 26, 2025
44dba74
clean up pytorch finetuning
yaolug Aug 26, 2025
b93f363
try float32
yaolug Aug 27, 2025
6775bdd
reuse jax dataloader
yaolug Aug 27, 2025
5f3aaba
further simplify dataloader
yaolug Aug 27, 2025
bde41be
fix batch size
yaolug Aug 27, 2025
974c45c
fix batch size again
yaolug Aug 27, 2025
a644c95
further clean up
yaolug Aug 27, 2025
db8b5f7
add missing file
yaolug Aug 27, 2025
c51b061
add documentation, check transfoemres_replace
yaolug Aug 28, 2025
ed03aee
change pi0 back
yaolug Aug 28, 2025
b42d3ac
minor changes
yaolug Aug 28, 2025
a30c8d6
further cleanup
yaolug Aug 28, 2025
babea0e
further cleanup
yaolug Aug 28, 2025
f4a5847
fix a bug
yaolug Aug 28, 2025
b8597b6
add check.py
yaolug Aug 28, 2025
cde7856
fix float32
yaolug Aug 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,10 @@ We provide more examples for how to fine-tune and run inference with our models
openpi now provides PyTorch implementations of π₀ and π₀.₅ models alongside the original JAX versions (π₀-FAST is not currently supported in PyTorch). The PyTorch models offer greater deployment flexibility and seamless integration with PyTorch-based ML stacks, while maintaining feature parity with their JAX counterparts.

### Setup
1. Upgrade the transformers library to 4.53.2
1. Upgrade the transformers library to 4.53.2 and torch to 2.7.1
- The required version is already specified in pyproject.toml
- If you set up your environment previously, reinstall it to ensure you have transformers 4.53.2
- You can verify the version with `pip show transformers`
- If you set up your environment previously, reinstall it to ensure you have transformers 4.53.2 and torch 2.7.1
- You can verify the version with `uv pip show transformers` and `uv pip show torch`

2. Apply the transformers library patches
```bash
Expand Down Expand Up @@ -235,6 +235,73 @@ uv run scripts/serve_policy.py policy:checkpoint \
--policy.dir=/path/to/converted/pytorch/checkpoint
```

### Finetuning with PyTorch

To finetune a model in PyTorch:

1. Convert the JAX base model to PyTorch format:
```bash
python examples/convert_jax_model_to_pytorch.py \
--checkpoint_dir /path/to/jax/base/model \
--output_path /path/to/pytorch/base/model
```

2. Specify the converted PyTorch model path in your config using `pytorch_weight_path`

3. Launch training using one of these modes:

```bash
# Single GPU training:
uv run scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>

# Example:
uv run scripts/train_pytorch.py debug --exp_name pytorch_test
uv run scripts/train_pytorch.py debug --exp_name pytorch_test --resume # Resume from latest checkpoint

# Multi-GPU training (single node):
torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>

# Example:
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume

# Multi-Node Training:
torchrun \
--nnodes=<num_nodes> \
--nproc_per_node=<gpus_per_node> \
--node_rank=<rank_of_node> \
--master_addr=<master_ip> \
--master_port=<port> \
scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
```

### Precision Settings

JAX and PyTorch implementations handle precision as follows:

**JAX:**
1. Inference: weights and activations are bfloat16 except for selected layers in float32
2. Training: weights and gradients are float32, activations are bfloat16

**PyTorch:**
1. Inference: matches JAX - weights and activations are bfloat16 except for selected layers in float32
2. Training: supports either all float32 or the same mixed precision as inference (default) You can change it by setting `pytorch_training_precision` in the config. Per GPU batch size can be 64 with mixed precision but 16 with float32 on a 80GB memory A100 or H100. Further optimizations to reduce memory consumption can be done in the future.

### Validation Results

We have validated the PyTorch implementation against JAX:

1. Converting a JAX-finetuned pi05_libero model (bfloat16 precision) to PyTorch and evaluate:
- JAX checkpoint: 92.0% on Libero-10
- Converted PyTorch checkpoint: 92.2% on Libero-10

2. Finetuning pi05_libero from pi05_base model in PyTorch:
- Using float32: (convert JAX ckpt to pytorch in float32) Loss curves match JAX throughout training
- Using bfloat16: (convert JAX ckpt to pytorch in bfloat16) Loss curves match after 10k steps (in 30k total steps)
- Final performance: 91.4% on Libero-10

3. We comapred inference speed between JAX and PyTorch and they are comparable.

## Troubleshooting

We will collect common issues and their solutions here. If you encounter an issue, please check here first. If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines).
Expand Down
36 changes: 25 additions & 11 deletions examples/convert_jax_model_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

Usage:
# Just inspect keys:
python convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only

# Convert to PyTorch:
python convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output

Example:
# pi0_droid
Expand Down Expand Up @@ -336,8 +338,12 @@ def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str |
}
restore_dtype = dtype_map.get(restore_precision) if restore_precision else None

# Use CPU sharding to avoid GPU memory issues during checkpoint loading
cpu_device = jax.devices('cpu')[0]
cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_device)

# Use repository restore utility to load a pure dict of params (value suffix removed)
params = openpi.models.model.restore_params(params_dir, restore_type=jax.Array, dtype=restore_dtype)
params = openpi.models.model.restore_params(params_dir, restore_type=jax.Array, dtype=restore_dtype, sharding=cpu_sharding)

# get params for PaliGemma
pali_params = params["PaliGemma"]
Expand Down Expand Up @@ -382,7 +388,8 @@ def load_jax_model_and_print_keys(checkpoint_dir: str):
return

item = {params_name: metadata[params_name]}
device = jax.local_devices()[0]
# Use CPU device to avoid GPU memory issues
device = jax.devices('cpu')[0]
sharding = jax.sharding.SingleDeviceSharding(device)

restored = checkpointer.restore(
Expand Down Expand Up @@ -536,6 +543,12 @@ def __init__(self):
action_horizon=10,
pi05=True,
)
elif "pi05_base" in checkpoint_dir:
pi0_config = openpi.models.pi0_config.Pi0Config(
action_dim=32,
action_horizon=50,
pi05=True,
)
else:
pi0_config = openpi.models.pi0_config.Pi0Config(
action_dim=8,
Expand All @@ -549,13 +562,14 @@ def __init__(self):
all_params = {**paligemma_params, **gemma_params, **projection_params}

# Load state dict
try:
pi0_model.load_state_dict(all_params)
except Exception as e:
print(f"Warning: Could not load all parameters: {e}")
print("Continuing with partial load...")
pi0_model.load_state_dict(all_params, strict=False)

pi0_model = pi0_model.to(torch.bfloat16)
if precision == "float32":
pi0_model = pi0_model.to(torch.float32)
elif precision == "bfloat16":
pi0_model = pi0_model.to(torch.bfloat16)
else:
raise ValueError(f"Invalid precision: {precision}")

# Save the converted model using safetensors
os.makedirs(output_path, exist_ok=True)
Expand Down Expand Up @@ -602,7 +616,7 @@ def main():
parser.add_argument(
"--precision",
choices=["float32", "bfloat16", "float16"],
default="float32",
default="bfloat16",
type=str,
help="Precision for model conversion"
)
Expand Down
6 changes: 3 additions & 3 deletions examples/inference_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
This demonstrates the basic usage patterns for both implementations.

pi0_droid
python examples/inference_example.py --model_name pi0_droid --jax_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_droid --pytorch_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
python examples/inference_example.py --model_name pi0_droid --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch

pi0_aloha_sim
python examples/inference_example.py --model_name pi0_aloha_sim --jax_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --pytorch_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
python examples/inference_example.py --model_name pi0_aloha_sim --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch

pi05_droid
python examples/inference_example.py --model_name pi05_droid --jax_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid --pytorch_checkpoint_dir /home/jasonlu/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid_pytorch
python examples/inference_example.py --model_name pi05_droid --jax_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid --pytorch_checkpoint_dir /home/$USER/.cache/openpi/openpi-assets-preview/checkpoints/pi05_droid_pytorch

"""
import argparse
Expand Down
Loading
Loading