I've been exploring this structure to understand this mechanic better:

import jax
import jax.numpy as jnp

import flax.linen as nn


class Block(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray:
        x = nn.Dense(16)(x)
        # BatchNorm uses batch statistics while training, running averages otherwise.
        x = nn.BatchNorm(use_running_average=not training)(x)
        # Dropout is active only while training.
        x = nn.Dropout(0.2, deterministic=not training)(x)
        x = nn.relu(x)
        return x

Once initialized like this:

key = jax.random.PRNGKey(0)
x = jnp.ones((10, 2))
module = Block()
# "params" seeds parameter init; "dropout" seeds the dropout masks.
init_rngs = dict(zip(["params", "dropout"], jax.random.split(key, 2)))

variables = module.init(init_rngs, x, True)

There are a couple of scenarios I tried out:

# 1. Apply sp…

Replies: 3 comments

Answer selected by marcvanzee
This discussion was converted from issue #1382 on April 27, 2022 20:43.