
Do you think there is a problem with doing nnx.Variable(Node())?

@rademacher-p Optimizer was originally written like this as well, e.g. self.opt_state = OptState(opt_state). We eventually ran into an edge case (I don't remember exactly which one), so we switched to wrapping the leaves instead, but you can try it. It does work:

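import dataclasses
import typing as tp

import jax
import jax.numpy as jnp
from flax import nnx

# Register Node as a pytree so JAX can see the array stored inside it.
@jax.tree_util.register_dataclass
@dataclasses.dataclass
class Node:
    var: tp.Any

class A(nnx.Pytree):
    def __init__(self):
        # A single Variable wraps the whole Node, not its individual leaves.
        self.foo = nnx.Variable(Node(jnp.int32(0)))

a = A()
# Build a new A with every leaf incremented (foo now wraps Node(1)).
new_a = jax.tree.map(lambda x: x + 1, a)
# Copy new_a's state back into a, in place.
nnx.update(a, nnx.state(new_a))

print(a)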

Replies: 2 comments, 16 replies

Answer selected by rademacher-p
3 participants