
Combining assert_scalar_positive with jax.vmap #389

@thomaspinder

Is it possible to apply assert_scalar_positive element-wise to a vector, e.g. by mapping it with jax.vmap? The MWE below shows what I'd like to do; however, I can't currently see how to do this with Chex.

```python
import jax
import jax.numpy as jnp
from chex import assert_scalar_positive

x_scalar = 1.0
x_vector = jnp.array([1.0, 1.0])

assert_scalar_positive(x_scalar)  # Works
jax.vmap(assert_scalar_positive)(x_vector)  # What I'd like
```
