numpyro.deterministic not showing in mcmc.print_summary


From Bad posterior geometry and how to deal with it — NumPyro documentation:

In _rep_hs_model1 above we used numpyro.deterministic to define scaled_betas. We note that using this primitive is not strictly necessary; however, it has the consequence that scaled_betas will appear in the trace and will thus appear in the summary reported by mcmc.print_summary(). In other words we could also have written:

I ran the following (the same code as the tutorial, except that I swapped the r_hat printing for a print_summary call):

from functools import partial

import numpy as np

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import summary

from numpyro.infer import MCMC, NUTS

assert numpyro.__version__.startswith("0.8.0")

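# NB: replace cpu by gpu to run this notebook on gpu
numpyro.set_platform("cpu")

def run_inference(
    model, num_warmup=1000, num_samples=1000, max_tree_depth=10, dense_mass=False
):
    kernel = NUTS(model, max_tree_depth=max_tree_depth, dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0))
    summary_dict = summary(mcmc.get_samples(), group_by_chain=False)
    # swapped in for the tutorial's r_hat printing loop
    mcmc.print_summary()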

# In this reparameterized model none of the parameters of the distributions
# explicitly depend on other parameters. This model is exactly equivalent
# to _unrep_hs_model but is expressed in a different coordinate system.
def _rep_hs_model1(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
    unscaled_betas = numpyro.sample(
        "unscaled_betas", dist.Normal(scale=jnp.ones(X.shape[1]))
    )
    scaled_betas = numpyro.deterministic("betas", tau * lambdas * unscaled_betas)
    mean_function = jnp.dot(X, scaled_betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)

# create fake dataset
X = np.random.RandomState(0).randn(100, 500)
Y = X[:, 0]

run_inference(partial(_rep_hs_model1, X, Y))

And I do not see the betas in the trace, just the unscaled betas.

Am I doing something wrong?

You need to pass exclude_deterministic=False to print_summary; see the docs.

Note that you created a summary_dict object but then didn't print or inspect it.

Ahhhhh, that’s it. Thanks!