Thanks a lot, both!
However, in my application I have to use samples in the unconstrained space.
The reason is that I am going to do importance sampling on this model, and the proposal samples will come from a mixture of Gaussians, so they lie in an unconstrained space.
I have to evaluate these samples with potential_energy and/or log_density to compute IS weights. But as I understand it, I can only use potential_energy, because the samples are unconstrained?
And why is log_density better?
The reason I was generating MCMC samples with NUTS is for debugging purposes: to check that passing NUTS samples (after applying unconstrain_fn) to potential_energy returns reasonable values… or, at least, not NaNs.
Yet I keep getting some NaNs, even after applying unconstrain_fn before passing the samples. Is that weird?
Could it be because NUTS is not working well on this model? It seems strange to me: a log PDF may be very negative, but I am not sure why it would be NaN.
For example, I get bad r_hat values for some parameters. I am using:
num_chains = 7 # I have 7 cores
num_warmup = 20000
num_samples = 1000
tree_depth = 30
dense_mass = True
Even with these settings, with the model above, and considering only 3 features (the first 3 dimensions of x, whereas the actual data dimension is 1536), I get relatively poor MCMC diagnostics.
I have some code to reproduce this below. I am not very experienced with NUTS, so maybe I am missing something very obvious.
import numpyro
from numpyro.infer import MCMC, NUTS
from jax import random
import jax.numpy as jnp # Use JAX's numpy
import numpyro.distributions as dist
# Expose one XLA CPU device per chain so the 7 chains in run_mcmc actually run
# in parallel; without this NumPyro falls back to drawing them sequentially.
# This must happen before JAX performs any computation, hence at start-up.
numpyro.set_host_device_count(7)
numpyro.enable_x64()  # Use 64-bit precision
def run_mcmc(model, model_args, model_hyperparameters, key, *, num_chains=7,
             num_warmup=20000, num_samples=1000, tree_depth=30,
             dense_mass=True, find_heuristic_step_size=True):
    """Run NUTS on ``model`` and return the posterior samples.

    Args:
        model: NumPyro model; it is called with ``**model_args``.
        model_args: keyword arguments passed to ``model`` (data and
            hyperparameters).
        model_hyperparameters: hyperparameter dict. It is not used by the
            sampler itself; it is kept in the signature so callers can pass it
            along (e.g. for naming saved result files).
        key: JAX PRNG key that drives the sampler.
        num_chains: number of chains. On CPU they only run in parallel if
            ``numpyro.set_host_device_count`` was called at start-up;
            otherwise NumPyro draws them one after another.
        num_warmup: warm-up (adaptation) iterations per chain.
        num_samples: retained draws per chain.
        tree_depth: NUTS ``max_tree_depth``. A single iteration may take up
            to ``2 ** tree_depth`` leapfrog steps, so large values (e.g. 30)
            can be extremely slow. NumPyro's default is 10.
        dense_mass: adapt a dense (full-covariance) mass matrix.
        find_heuristic_step_size: run NUTS's heuristic initial step-size
            search.

    Returns:
        The samples as returned by ``MCMC.get_samples()``: a dict mapping
        each latent site name to its draws. Save them with e.g.
        ``np.savez(path, **samples)`` if needed.
    """
    nuts_kernel = NUTS(
        model,
        max_tree_depth=tree_depth,
        dense_mass=dense_mass,
        find_heuristic_step_size=find_heuristic_step_size,
    )
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
    )
    # The model takes keyword-only arguments, so data/hyperparameters are
    # forwarded as **kwargs; ``key`` seeds the sampler.
    mcmc.run(key, **model_args)
    return mcmc.get_samples()
def model(*, N, d, scale_icept, scale_global, nu_global, nu_local, slab_scale,
          slab_df, x=None, y=None):
    """Regularized-horseshoe logistic regression.

    Latent sites: ``beta0_unscaled`` (intercept), ``tau`` (global shrinkage),
    ``lambda`` (``d`` local shrinkages), ``caux`` (slab width) and ``z``
    (``d`` raw coefficients). ``obs`` is the Bernoulli likelihood.

    Args:
        N: number of observations (size of the ``data`` plate).
        d: number of features.
        scale_icept: scale of the Normal prior on the intercept.
        scale_global: scale of the half-Student-t prior on ``tau``.
        nu_global: degrees of freedom of the prior on ``tau``.
        nu_local: degrees of freedom of the priors on ``lambda``.
        slab_scale: slab scale; assumed positive.
        slab_df: slab degrees of freedom.
        x: design matrix, presumably of shape (N, d).
        y: binary labels of length N, or None to sample them.
    """
    beta0_unscaled = numpyro.sample('beta0_unscaled', dist.Normal(0.0, 1.0))
    beta0 = beta0_unscaled * scale_icept
    # tau and lambda are scale parameters, so they get half-Student-t
    # (folded) priors. With full Student-t priors the posterior is symmetric
    # under lambda -> -lambda (lambda only enters squared) and under
    # (tau, z) -> (-tau, -z). That makes it multimodal, and NUTS chains stuck
    # in different sign modes give bad r_hat. The implied prior on beta is
    # unchanged.
    tau = numpyro.sample(
        'tau',
        dist.FoldedDistribution(dist.StudentT(nu_global, 0.0, scale_global)),
    )
    lambda_ = numpyro.sample(
        'lambda',
        dist.FoldedDistribution(dist.StudentT(nu_local, 0.0, 1.0).expand([d])),
    )
    c_aux = numpyro.sample('caux', dist.InverseGamma(0.5 * slab_df, 0.5 * slab_df))
    z = numpyro.sample('z', dist.Normal(0.0, 1.0).expand([d]))
    c = slab_scale * jnp.sqrt(c_aux)
    # lambda_tilde = sqrt(c^2 lambda^2 / (c^2 + tau^2 lambda^2)), rewritten
    # as c / hypot(c / lambda, tau). This is algebraically identical for
    # c > 0 but avoids squaring lambda and c. The squared form overflows to
    # inf / inf = NaN for the very large values these heavy-tailed priors
    # produce, which shows up as NaN potential energies.
    lambda_tilde = c / jnp.hypot(c / lambda_, tau)
    beta = z * lambda_tilde * tau
    f = beta0 + jnp.matmul(x, beta)
    with numpyro.plate('data', N):
        numpyro.sample('obs', dist.Bernoulli(logits=f), obs=y)


if __name__ == '__main__':
    # NOTE(review): 0.1 degrees of freedom give priors so heavy-tailed that
    # their means do not exist (a Student-t mean needs df > 1). That alone
    # makes NUTS struggle; consider larger values if the model allows it.
    slab_scale = 1.
    scale_icept = 1.
    nu_global = 0.1
    nu_local = 0.1
    slab_df = 0.1
    scale_global = 0.01
    # Explicit checks rather than assert (asserts are stripped under -O).
    if scale_global <= 0:
        raise ValueError('scale_global must be positive')
    if slab_scale <= 0:
        raise ValueError('slab_scale must be positive')

    model_hyperparameters = {
        'scale_icept': scale_icept,
        'scale_global': scale_global,
        'nu_global': nu_global,
        'nu_local': nu_local,
        'slab_scale': slab_scale,
        'slab_df': slab_df,
    }

    # The true data to load; in this test we use random data.
    # _, _, y, x = load_data_from_mat('./ovarian.mat')
    # One independent key per consumer: a JAX key must not be reused once
    # it has been consumed.
    key_x, key_y, key_mcmc = random.split(random.PRNGKey(0), 3)
    x = random.normal(key_x, shape=(54, 1536))
    y = random.bernoulli(key_y, 0.5, shape=(54,)).astype(jnp.int32)

    model_args = {
        'N': x.shape[0],
        'd': x.shape[1],
        'y': y,
        'x': x,
        **model_hyperparameters,
    }

    samples_mcmc = run_mcmc(model, model_args, model_hyperparameters, key_mcmc)