The API docs for apply() and init() state that the mutable parameter controls which variable collections are mutable, but they don't say exactly how the behavior changes when a variable collection is considered 'mutable'. I suspect it just means that all variables in those collections are returned by the apply or init function (even if those variables haven't actually changed), and that variables in other collections are not returned (even if they otherwise would have changed). If so, please update the docs to state this explicitly. Or are there behavior changes too?
Replies: 3 comments
Been exploring this structure to understand this mechanic better:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray:
        x = nn.Dense(16)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.Dropout(0.2, deterministic=not training)(x)
        x = nn.relu(x)
        return x
```

Once initialized like this:

```python
key = jax.random.PRNGKey(0)
x = jnp.ones((10, 2))
module = Block()
init_rngs = dict(zip(["params", "dropout"], jax.random.split(key, 2)))
variables = module.init(init_rngs, x, True)
```

There are a couple of scenarios I tried out:

```python
from flax.errors import ModifyScopeVariableError

# 1. Apply specifying mutable batch_stats: only "batch_stats" is returned
key, dropout_key = jax.random.split(key, 2)
y, variable_updates = module.apply(variables, x, True, rngs={"dropout": dropout_key}, mutable=["batch_stats"])
# Merge the updates back; {**a, **b} works whether init returned a dict or a FrozenDict
variables = {**variables, **variable_updates}

# 2. Apply setting mutable=True: every collection, including "params", is returned
key, dropout_key = jax.random.split(key, 2)
y, variable_updates = module.apply(variables, x, True, rngs={"dropout": dropout_key}, mutable=True)
variables = {**variables, **variable_updates}

# 3. Apply with no mutable state: yields an error, because BatchNorm in
#    training mode tries to update the now-immutable "batch_stats" collection
key, dropout_key = jax.random.split(key, 2)
try:
    y, variable_updates = module.apply(variables, x, True, rngs={"dropout": dropout_key}, mutable=[])
except ModifyScopeVariableError as e:
    print(e)

# 4. Apply with training=False, no mutable state nor rngs
y = module.apply(variables, x, False)
```

To answer the question, `mutable` doesn't filter the Documentation for

However, a simple example might be more useful.
@billmark I moved your issue to our GitHub Discussions page since @cgarciae gave an answer, so it can immediately serve as documentation!
A while ago I opened #4779 and PR #4783, which adds the different usages of