I notice you have a stable branch for multi-GPU testing. Does torch2jax actually work out of the box with what I believe is now the standard JAX multi-GPU paradigm of sharding, i.e.:
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
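
For concreteness, here is a minimal sketch of the sharding pattern I mean. Note that `model` is only a hypothetical pure-JAX stand-in for a function wrapped by torch2jax (no torch2jax API is used here), and the mesh axis name `batch` is my own choice:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over every visible device (also works with a single device).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

# Shard the leading (batch) axis across the mesh; remaining axes are replicated.
sharding = NamedSharding(mesh, P("batch"))


def model(x):
    # Hypothetical stand-in for a function produced by torch2jax.
    return jnp.tanh(x) * 2.0


# Keep the batch size divisible by the number of devices.
n = 8 * len(devices)
x = jax.device_put(
    jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4), sharding
)

# jit propagates the input sharding through the computation.
y = jax.jit(model)(x)
print(y.shape, y.sharding)
```

The question is whether replacing `model` with a torch2jax-wrapped function would still run under `jax.jit` with sharded inputs like `x`, or whether it needs `shard_map` or something else.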