Finetuning2#3
Conversation
f83909e to
e5d936c
Compare
| # Reuse existing dataset + transforms pipeline | ||
| data_conf = config.data.create(config.assets_dirs, config.model) | ||
| dataset = _data.create_torch_dataset(data_conf, config.model.action_horizon, config.model) | ||
| print(f"data_conf: {data_conf}") |
|
|
||
| # Parse additional command line arguments for memory optimization | ||
| parser = argparse.ArgumentParser(add_help=False) | ||
| parser.add_argument("--resume", action="store_true", default=False, |
There was a problem hiding this comment.
the config already has resume/overwrite flags
| return result | ||
|
|
||
|
|
||
| def _tree_map_multi(func, batch_list): |
There was a problem hiding this comment.
I honestly think it's easier to just use JAX here as well lol
jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *batch_list)
| # Use full batch size since we removed gradient accumulation | ||
| effective_batch_size = config.batch_size // (torch.distributed.get_world_size() if use_ddp else 1) | ||
|
|
||
| loader = torch.utils.data.DataLoader(dataset, batch_size=effective_batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=config.num_workers, pin_memory=True, drop_last=True, collate_fn=collate_to_numpy) |
There was a problem hiding this comment.
maybe dumb question, but why not use the existing openpi dataloader? we can make the JAX-specific things optional (namely, the jax.make_array_from_process_local_data), and add the necessary PyTorch specific things (e.g., custom sampler). other than that, the implementations look fairly similar, and I think it would make things easier to maintain going forward if they were shared.
| ) | ||
|
|
||
|
|
||
| def preprocess_observation_pytorch( |
There was a problem hiding this comment.
why not put this in the models_pytorch directory somewhere?
| return False, None | ||
|
|
||
|
|
||
| def compare_losses(pytorch_loss, jax_loss): |
There was a problem hiding this comment.
I realize this is all AI-generated but this is a crazy amount of unnecessary code... this whole function could be replaced with np.testing.assert_allclose. I'm fine with having this file but maybe not in the top-level scripts/ directory, would prefer if it was in examples/compare_jax_pytorch.py or something like that.
There was a problem hiding this comment.
Will not release this file
No description provided.