Finetuning2#3

Open
yaolug wants to merge 33 commits into pi05-pytorch from finetuning2

yaolug commented Aug 22, 2025

Owner

No description provided.

yaolug force-pushed the pi05-pytorch branch 2 times, most recently from f83909e to e5d936c on August 26, 2025 07:03
Comment thread scripts/train_pytorch.py Outdated
# Reuse existing dataset + transforms pipeline
data_conf = config.data.create(config.assets_dirs, config.model)
dataset = _data.create_torch_dataset(data_conf, config.model.action_horizon, config.model)
print(f"data_conf: {data_conf}")

Contributor

nit: don't use print

Owner Author

Done
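
For readers following the thread, here is a minimal sketch of the kind of change the nit asks for. It assumes Python's standard `logging` module is the intended replacement; the thread does not say which logger the PR ended up using. The `log_data_config` helper and the dict passed to it are hypothetical stand-ins for the real config object.

```python
import logging

# Module-level logger named after the module that defines it.
logger = logging.getLogger(__name__)


def log_data_config(data_conf: object) -> None:
    """Report the data config through logging instead of print()."""
    # Lazy %-style formatting: the message is only built if INFO is enabled.
    logger.info("data_conf: %s", data_conf)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Hypothetical stand-in for the object returned by config.data.create(...).
    log_data_config({"repo_id": "example/dataset"})
```

Routing the message through a logger lets the caller control verbosity and destination without touching the training script.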

Comment thread scripts/train_pytorch.py Outdated

# Parse additional command line arguments for memory optimization
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--resume", action="store_true", default=False,

Contributor

the config already has resume/overwrite flags
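
To illustrate the reviewer's point, here is a sketch of driving resume/overwrite from the config object rather than from a second argparse parser. `TrainConfig` and `resolve_checkpoint_mode` are hypothetical stand-ins written for this example; the real openpi config fields and the exact precedence between the two flags may differ.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class TrainConfig:
    """Hypothetical stand-in for the training config (not the real openpi class)."""

    resume: bool = False
    overwrite: bool = False


def resolve_checkpoint_mode(config: TrainConfig, checkpoint_exists: bool) -> str:
    """Decide how to treat the checkpoint directory using only config flags."""
    if config.resume and config.overwrite:
        raise ValueError("resume and overwrite are mutually exclusive")
    if not checkpoint_exists:
        return "fresh"  # Nothing on disk yet, so start a new run.
    if config.overwrite:
        return "overwrite"
    if config.resume:
        return "resume"
    raise FileExistsError("checkpoint directory exists; set resume or overwrite")
```

With the decision in one place, the training script needs no extra `--resume` argument of its own.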

Comment thread scripts/train_pytorch.py Outdated
return result


def _tree_map_multi(func, batch_list):

Contributor

I honestly think it's easier to just use JAX here as well lol

jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *batch_list)

Owner Author

Done
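
As a quick check of the suggested one-liner, here is a small self-contained run on two nested batches. The dictionary keys and array shapes are made up for the example. `jax.tree.map` is available in recent JAX releases; on older versions `jax.tree_util.tree_map` takes the same arguments.

```python
import jax
import numpy as np

# Two per-example batches with the same nested structure (hypothetical keys).
batch_list = [
    {"image": {"base": np.zeros((2, 2))}, "state": np.array([1.0, 2.0, 3.0])},
    {"image": {"base": np.ones((2, 2))}, "state": np.array([4.0, 5.0, 6.0])},
]

# Stack matching leaves along a new leading axis; the result stays NumPy.
stacked = jax.tree.map(
    lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *batch_list
)

print(stacked["image"]["base"].shape)  # (2, 2, 2)
print(stacked["state"].shape)  # (2, 3)
```

The tree structure is taken from the first batch, so every batch in the list must share that structure.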

Comment thread scripts/train_pytorch.py Outdated
# Use full batch size since we removed gradient accumulation
effective_batch_size = config.batch_size // (torch.distributed.get_world_size() if use_ddp else 1)

loader = torch.utils.data.DataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=True, drop_last=True, collate_fn=collate_to_numpy)

Contributor

maybe dumb question, but why not use the existing openpi dataloader? we can make the JAX-specific things optional (namely, the jax.make_array_from_process_local_data), and add the necessary PyTorch specific things (e.g., custom sampler). other than that, the implementations look fairly similar, and I think it would make things easier to maintain going forward if they were shared.

Owner Author

Done
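
One way to picture the shared-loader idea from this thread is a framework-agnostic batch iterator with an optional conversion hook, so the JAX-specific or PyTorch-specific step is passed in rather than hard-coded. This is only a sketch under that assumption: `iterate_batches` and its `finalize` parameter are hypothetical names, not part of the openpi dataloader.

```python
from collections.abc import Callable, Iterator, Sequence
from typing import Any, Optional


def iterate_batches(
    dataset: Sequence[Any],
    batch_size: int,
    *,
    finalize: Optional[Callable[[list], Any]] = None,
) -> Iterator[Any]:
    """Yield fixed-size batches, dropping the last incomplete one.

    `finalize` is the framework-specific hook: a JAX caller could build a
    sharded array there, a PyTorch caller could convert to tensors.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Stop before any batch that would be shorter than batch_size.
    for start in range(0, len(dataset) - batch_size + 1, batch_size):
        batch = list(dataset[start : start + batch_size])
        yield finalize(batch) if finalize is not None else batch


if __name__ == "__main__":
    for batch in iterate_batches(range(7), 3):
        print(batch)
```

Keeping the framework step behind a single callable means the sampling and batching logic is written and tested once.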

Comment thread src/openpi/models/model.py Outdated
)


def preprocess_observation_pytorch(

Contributor

why not put this in the models_pytorch directory somewhere?

Owner Author

Done

Comment thread scripts/train_single_example.py Outdated
return False, None


def compare_losses(pytorch_loss, jax_loss):

Contributor

I realize this is all AI-generated but this is a crazy amount of unnecessary code... this whole function could be replaced with np.testing.assert_allclose. I'm fine with having this file but maybe not in the top-level scripts/ directory, would prefer if it was in examples/compare_jax_pytorch.py or something like that.

Owner Author

Will not release this file
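
For reference, the replacement suggested in this thread could look roughly like the sketch below. The helper name `assert_losses_match` and the tolerance values are assumptions chosen for illustration, not values taken from the PR.

```python
import numpy as np


def assert_losses_match(pytorch_loss, jax_loss, rtol=1e-4, atol=1e-6):
    """Raise AssertionError if the two losses differ beyond the tolerances."""
    # Tolerances are illustrative assumptions; tune them per model and dtype.
    np.testing.assert_allclose(
        np.asarray(pytorch_loss), np.asarray(jax_loss), rtol=rtol, atol=atol
    )


if __name__ == "__main__":
    # Passes silently: the difference is well inside the tolerances.
    assert_losses_match(0.123456, 0.123457)
```

On a mismatch, `assert_allclose` raises an `AssertionError` whose message reports the maximum absolute and relative differences, which covers most of what a hand-written comparison function would print.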
