
jax 0.7.0 Breaks Torch2Jax #31

Description

@adam-hartshorne

They are minor fixes, but I thought I would point them out.

  File "/media/adam/shared_folder/PycharmProjects/test/chamfer_distance.py", line 5, in <module>
    from torch2jax import torch2jax
  File "/home/adam/anaconda3/envs/jax_latest/lib/python3.12/site-packages/torch2jax/__init__.py", line 1, in <module>
    from .api import torch2jax, dtype_t2j # noqa: F401
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adam/anaconda3/envs/jax_latest/lib/python3.12/site-packages/torch2jax/api.py", line 12, in <module>
    from jax.util import safe_zip
ImportError: cannot import name 'safe_zip' from 'jax.util' (/home/adam/anaconda3/envs/jax_latest/lib/python3.12/site-packages/jax/util.py)


Traceback (most recent call last):
  File "/media/adam/shared_folder/PycharmProjects/test/model/losses/chamfer_distance.py", line 5, in <module>
    from torch2jax import torch2jax
  File "/home/adam/anaconda3/envs/jax_latest/lib/python3.12/site-packages/torch2jax/__init__.py", line 3, in <module>
    from .dlpack_passing import j2t, t2j, tree_j2t, tree_t2j # noqa: F401
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adam/anaconda3/envs/jax_latest/lib/python3.12/site-packages/torch2jax/dlpack_passing.py", line 15, in <module>
    JAXDevice = jax.lib.xla_extension.Device
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adam/anaconda3/envs/jax_latest/lib/python3.12/site-packages/jax/_src/deprecations.py", line 54, in __getattr__
    raise AttributeError(message)
AttributeError: jax.lib.xla_extension.Device was deprecated in JAX v0.6.0 and removed in JAX v0.7.0; use jax.Device instead.
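The second error message already names the replacement, `jax.Device`. A minimal sketch of how line 15 of torch2jax/dlpack_passing.py might be changed, assuming the project wants to keep working on older JAX releases too (the fallback branch is an assumption and only matters for versions that lack `jax.Device`):

```python
import jax

# Hypothetical replacement for `JAXDevice = jax.lib.xla_extension.Device`.
# jax.Device is the public name suggested by the JAX 0.7.0 error message.
try:
    JAXDevice = jax.Device
except AttributeError:
    # Fallback for old JAX versions that predate the public jax.Device alias.
    JAXDevice = jax.lib.xla_extension.Device
```

Devices returned by `jax.devices()` should then pass an `isinstance(..., JAXDevice)` check on both old and new releases.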


I don't know whether the autodiff changes in JAX 0.7.0 are useful for torch2jax in general, but here are the release notes and the direct-linearize migration guide:
https://github.com/jax-ml/jax/releases/tag/jax-v0.7.0
https://docs.jax.dev/en/latest/direct_linearize_migration.html
