Hi,
From Bad posterior geometry and how to deal with it — NumPyro documentation:
In
_rep_hs_model1
above we used ``numpyro.deterministic<http://num.pyro.ai/en/stable/primitives.html?highlight=deterministic#numpyro.primitives.deterministic>
__ to definescaled_betas
. We note that using this primitive is not strictly necessary; however, it has the consequence thatscaled_betas
will appear in the trace and will thus appear in the summary reported bymcmc.print_summary()
. In other words we could also have written:
I ran the following (all same code, but I swapped out printing r_hats for a print_summary statement:
from functools import partial
import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import summary
from numpyro.infer import MCMC, NUTS
assert numpyro.__version__.startswith("0.8.0")
# NB: replace cpu by gpu to run this notebook on gpu
numpyro.set_platform("cpu")
def run_inference(
model, num_warmup=1000, num_samples=1000, max_tree_depth=10, dense_mass=False
):
kernel = NUTS(model, max_tree_depth=max_tree_depth, dense_mass=dense_mass)
mcmc = MCMC(
kernel,
num_warmup=num_warmup,
num_samples=num_samples,
num_chains=1,
progress_bar=False,
)
mcmc.run(random.PRNGKey(0))
summary_dict = summary(mcmc.get_samples(), group_by_chain=False)
mcmc.print_summary(0.95)
# In this reparameterized model none of the parameters of the distributions
# explicitly depend on other parameters. This model is exactly equivalent
# to _unrep_hs_model but is expressed in a different coordinate system.
def _rep_hs_model1(X, Y):
lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
unscaled_betas = numpyro.sample(
"unscaled_betas", dist.Normal(scale=jnp.ones(X.shape[1]))
)
scaled_betas = numpyro.deterministic("betas", tau * lambdas * unscaled_betas)
mean_function = jnp.dot(X, scaled_betas)
numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)
# create fake dataset
X = np.random.RandomState(0).randn(100, 500)
Y = X[:, 0]
run_inference(partial(_rep_hs_model1, X, Y))
And I do not see the betas in the trace, just the unscaled betas.
Am I doing something wrong?