While working through Statistical Rethinking, I ran a prior predictive simulation, but the results did not match the textbook example (see below).
I also tried a few synthetic examples, and they give unintuitive results as well (see further down).
Examples
Example from Statistical Rethinking
Code
import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx
from mcx import distributions as dist
from mcx import sample_joint
@mcx.model
def model():
    μ <~ dist.Normal(178, 20)
    σ <~ dist.Uniform(0, 50)
    h <~ dist.Normal(μ, σ)
    return h

rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=model,
    model_args=(),
    num_samples=10_000
)
fig, axes = plt.subplots(2, 2, figsize=(7, 5), dpi=128)
axes = axes.reshape(-1)
sns.kdeplot(prior_predictive["μ"], ax=axes[0])
sns.kdeplot(prior_predictive["σ"], ax=axes[1])
sns.kdeplot(prior_predictive["h"], ax=axes[2])
plt.tight_layout()
Result

Expected
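For comparison, here is a minimal sketch of the same prior predictive simulation written directly with jax.random, bypassing mcx entirely. It assumes the intended semantics is plain ancestral sampling (draw μ and σ, then draw h from Normal(μ, σ) using the same row's values); the function name rethinking_prior_predictive is hypothetical. Since σ ~ Uniform(0, 50) has E[σ²] = 2500/3, the marginal standard deviation of h should be roughly sqrt(20² + 2500/3) ≈ 35.1 around a mean of 178.

```python
import jax
import jax.numpy as jnp


def rethinking_prior_predictive(rng_key, num_samples):
    """Hypothetical reference sampler: mu ~ Normal(178, 20), sigma ~ Uniform(0, 50), h ~ Normal(mu, sigma)."""
    key_mu, key_sigma, key_h = jax.random.split(rng_key, 3)
    mu = 178.0 + 20.0 * jax.random.normal(key_mu, (num_samples,))
    sigma = jax.random.uniform(key_sigma, (num_samples,), minval=0.0, maxval=50.0)
    # Each h is drawn with the mu and sigma from the same sample index (ancestral sampling).
    h = mu + sigma * jax.random.normal(key_h, (num_samples,))
    return {"μ": mu, "σ": sigma, "h": h}


samples = rethinking_prior_predictive(jax.random.PRNGKey(0), 10_000)
print(float(jnp.mean(samples["h"])), float(jnp.std(samples["h"])))
```

The resulting dictionary has the same keys as prior_predictive above, so it can be passed to the same sns.kdeplot calls for a side-by-side comparison.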

Synthetic example 1
In this example I sample an offset from Uniform(0, 1).
Then I sample the outcome from Uniform(12 - offset, 12 + offset).
So I expect my samples to lie in the range [11, 13],
but I get samples in the range [-15, 15].
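The expected range can be made precise by marginalizing out the offset, assuming the model means offset ~ Uniform(0, 1) and outcome | offset ~ Uniform(12 - offset, 12 + offset). Writing o for the offset and x for the outcome:

```latex
p(x) = \int_0^1 \frac{\mathbf{1}\{|x - 12| < o\}}{2o}\,do
     = \int_{|x - 12|}^{1} \frac{do}{2o}
     = -\tfrac{1}{2}\ln|x - 12|, \qquad 0 < |x - 12| < 1,

p(x) = 0 \quad \text{for } |x - 12| \ge 1,

P(|X - 12| < a) = \int_{-a}^{a} -\tfrac{1}{2}\ln|t|\,dt = a - a\ln a, \qquad 0 < a \le 1.
```

So the outcome should be supported on [11, 13], peaked at 12, with about 0.5 - 0.5 ln 0.5 ≈ 0.847 of the mass within 0.5 of the center.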
Code
import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx
from mcx import distributions as dist
from mcx import sample_joint
@mcx.model
def example_1():
    center = 12
    offset <~ dist.Uniform(0, 1)
    low = (center - offset)
    high = (center + offset)
    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=example_1,
    model_args=(),
    num_samples=10_000
)
ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");
Result

Synthetic example 2
This is the same example as above, but the center variable is passed as an argument instead of being hardcoded. The results are different, although still not in the range [11, 13].
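As a reference point, here is a minimal plain-JAX sketch of examples 1 and 2, with center as an argument; calling it with 12 covers the hardcoded case too, since center is just a constant. It assumes Uniform(low, high) means the continuous uniform on [low, high) and draws the outcome by rescaling a standard uniform draw, low + (high - low) * u, which is one way to sample it and not necessarily how mcx does it. The function name offset_uniform_prior_predictive is hypothetical.

```python
import jax
import jax.numpy as jnp


def offset_uniform_prior_predictive(rng_key, center, num_samples):
    """Hypothetical reference sampler: offset ~ Uniform(0, 1), outcome ~ Uniform(center - offset, center + offset)."""
    key_offset, key_outcome = jax.random.split(rng_key)
    offset = jax.random.uniform(key_offset, (num_samples,))
    # Per-sample bounds: each outcome uses the offset drawn at the same index.
    low = center - offset
    high = center + offset
    # Rescale a standard uniform draw into [low, high).
    u = jax.random.uniform(key_outcome, (num_samples,))
    outcome = low + (high - low) * u
    return {"offset": offset, "outcome": outcome}


samples = offset_uniform_prior_predictive(jax.random.PRNGKey(0), 12, 10_000)
print(float(samples["outcome"].min()), float(samples["outcome"].max()))
```

Every outcome should fall in [11, 13]. Because |outcome - 12| is the product of two independent Uniform(0, 1) draws (the offset and |2u - 1|), the fraction of samples within 0.5 of the center should be close to 0.5 - 0.5 ln 0.5 ≈ 0.847.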
Code
import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx
from mcx import distributions as dist
from mcx import sample_joint
@mcx.model
def example_2(center):
    offset <~ dist.Uniform(0, 1)
    low = (center - offset)
    high = (center + offset)
    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=example_2,
    model_args=(12, ),
    num_samples=10_000
)
ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");
Result

Expectation
For examples 1 and 2, here's what I'd expect to get:

Environment
Linux-5.8.0-44-generic-x86_64-with-glibc2.10
Python 3.8.5 (default, Sep 4 2020, 07:30:14)
[GCC 7.3.0]
JAX 0.2.8
NetworkX 2.5
JAXlib 0.1.58
mcx 2a2b94801e68d94d86826863eeee80f0b84c390d