potential_energy returns NaN values

Using potential_energy (from NumPyro's inference utilities, numpyro.infer.util) on my model (see below) returns some NaN values when I pass it samples obtained from NUTS. Why could that be the case?

Are the NUTS samples I am passing in the constrained or the unconstrained space? Perhaps that matters.

The model (a regularized horseshoe logistic regression):

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    def model(*, N, d, scale_icept, scale_global, nu_global, nu_local, slab_scale, slab_df, x=None, y=None):

        # Priors
        beta0_unscaled = numpyro.sample('beta0_unscaled', dist.Normal(0.0, 1.0))
        beta0 = numpyro.deterministic('beta0', beta0_unscaled * scale_icept)

        tau = numpyro.sample('tau', dist.StudentT(nu_global, 0.0, scale_global))

        lambda_ = numpyro.sample('lambda', dist.StudentT(nu_local, 0.0, 1.0).expand([d]))

        c_aux = numpyro.sample('caux', dist.InverseGamma(0.5 * slab_df, 0.5 * slab_df))

        z = numpyro.sample('z', dist.Normal(0.0, 1.0).expand([d]))

        # Data generation process
        c = numpyro.deterministic('c', slab_scale * jnp.sqrt(c_aux))

        lambda_tilde = numpyro.deterministic('lambda_tilde',
                                             jnp.sqrt(c ** 2 * lambda_ ** 2 / (c ** 2 + tau ** 2 * lambda_ ** 2)))

        beta = numpyro.deterministic('beta', z * lambda_tilde * tau)

        f = beta0 + jnp.matmul(x, beta) 

        # f = jnp.reshape(f, (-1, 1))

        with numpyro.plate('data', N):
            # Likelihood
            numpyro.sample('obs', dist.Bernoulli(logits=f), obs=y)

Thanks for the help.

Hi @Nicola, you can enable the jax_debug_nans flag for debugging. potential_energy expects unconstrained samples, whereas mcmc.get_samples() returns samples in the constrained space, so use unconstrain_fn to convert the NUTS samples to the unconstrained domain first. But I think it's better to use log_density instead of potential_energy.

Hi, it may be better to use log_density instead of potential_energy for a more accurate evaluation of the log probability density function.

Thanks a lot, both of you!

However, in my application I have to use samples in the unconstrained space.

The reason is that I am going to do importance sampling on this model, and the proposed samples will come from a mixture of Gaussians, so they lie in the unconstrained space.

I have to evaluate these samples with potential_energy and/or log_density to compute the IS weights. But as I understand it, I can only use potential_energy, because my samples are unconstrained?

And why is log_density better?

I was generating MCMC samples with NUTS only for debugging: to check that passing NUTS samples (after applying unconstrain_fn) to potential_energy returns reasonable values, or at least not NaNs.

Yet I keep getting some NaNs, even after applying unconstrain_fn before passing the samples. Is that weird?

Could it be because NUTS is not working well on this model? It seems strange to me: a log PDF may be very negative, but I am not sure why it would be NaN.
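The thread does not say what caused the NaNs, so the following is only one generic mechanism worth ruling out: an intermediate overflow. With very heavy-tailed priors (nu_local = 0.1 in the reproduction code), lambda_ can take extreme values; if lambda_ ** 2 overflows to inf, the lambda_tilde expression from the model evaluates inf / inf = NaN even though its limit c / tau is finite. The sketch uses float32 so the overflow is easy to trigger (with numpyro.enable_x64() the threshold is much higher, around 1.3e154), and lambda_tilde_stable is a hypothetical rewrite, not code from the thread.

```python
import jax.numpy as jnp


def lambda_tilde_naive(c, tau, lam):
    # Same expression as in the model; lam**2 overflows to inf for large |lam|.
    return jnp.sqrt(c**2 * lam**2 / (c**2 + tau**2 * lam**2))


def lambda_tilde_stable(c, tau, lam):
    # Algebraically equal for nonzero lam: divide numerator and denominator by lam**2.
    return c / jnp.sqrt(c**2 / lam**2 + tau**2)


c = jnp.float32(1.0)
tau = jnp.float32(0.01)
lam = jnp.float32(1e20)  # lam**2 = 1e40 exceeds the float32 range and becomes inf

print(lambda_tilde_naive(c, tau, lam))   # nan (inf / inf)
print(lambda_tilde_stable(c, tau, lam))  # close to 100.0, the limit c / tau
```

Both versions agree for moderate values of lam; the rewrite only changes how the extreme cases are evaluated.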

As for NUTS not working well: I get bad r_hat values for some parameters. I am using:

    num_chains = 7  # I have 7 cores
    num_warmup = 20000
    num_samples = 1000
    tree_depth = 30
    dense_mass = True

Even with these settings, the model above gives relatively poor MCMC diagnostics, and that is while using only 3 features (the first 3 columns of x) rather than the full data dimension of 1536.

Here is some code to reproduce this. I am not very experienced with NUTS, so maybe I am missing something obvious.

    import numpyro
    from numpyro.infer import MCMC, NUTS
    from jax import random
    import jax.numpy as jnp  # Use JAX's numpy
    import numpyro.distributions as dist

    numpyro.set_host_device_count(7)
    numpyro.enable_x64()  # Use 64-bit precision

    def run_mcmc(model, model_args, model_hyperparameters, key):
        dimension = model_args['d']
        num_chains = 7
        num_warmup = 20000
        num_samples = 1000
        tree_depth = 30
        dense_mass = True
        find_heuristic_step_size = True

        nuts_kernel = NUTS(model, max_tree_depth=tree_depth, dense_mass=dense_mass, find_heuristic_step_size=find_heuristic_step_size)
        mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
        mcmc.run(key, **model_args)
        mcmc.print_summary()

        samples_mcmc = mcmc.get_samples()

        # temp = ('./results/mcmc_samples_numwarmup_{}_numsamples_{}_nchains_{}_treedepth_{}_dimension_{}_densemass_{}_heuristic_{}_modelhypers_scale_icept:{}_scale_global:{}_nu_global:{}_nu_local:{}_slab_scale:{}_slab_df:{}.npz').format(num_warmup, num_samples, num_chains, tree_depth, dimension, dense_mass, find_heuristic_step_size, model_hyperparameters['scale_icept'], model_hyperparameters['scale_global'], model_hyperparameters['nu_global'], model_hyperparameters['nu_local'], model_hyperparameters['slab_scale'], model_hyperparameters['slab_df'])
        # np.savez(temp, **samples_mcmc)

        # return temp

    if __name__ == '__main__':
        slab_scale = 1.
        scale_icept = 1.
        nu_global = 0.1
        nu_local = 0.1
        slab_df = 0.1
        scale_global = 0.01

        def model(*, N, d, scale_icept, scale_global, nu_global, nu_local, slab_scale, slab_df, x=None, y=None):
            beta0_unscaled = numpyro.sample('beta0_unscaled', dist.Normal(0.0, 1.0))
            beta0 = beta0_unscaled * scale_icept
            tau = numpyro.sample('tau', dist.StudentT(nu_global, 0.0, scale_global))
            lambda_ = numpyro.sample('lambda', dist.StudentT(nu_local, 0.0, 1.0).expand([d]))
            c_aux = numpyro.sample('caux', dist.InverseGamma(0.5 * slab_df, 0.5 * slab_df))
            z = numpyro.sample('z', dist.Normal(0.0, 1.0).expand([d]))
            c = slab_scale * jnp.sqrt(c_aux)
            lambda_tilde = jnp.sqrt(c ** 2 * lambda_ ** 2 / (c ** 2 + tau ** 2 * lambda_ ** 2))
            beta = z * lambda_tilde * tau
            f = beta0 + jnp.matmul(x, beta)

            with numpyro.plate('data', N):
                numpyro.sample('obs', dist.Bernoulli(logits=f), obs=y)

        # The true data to load, in this test we use random data
        # _, _, y, x = load_data_from_mat('./ovarian.mat')

        key = random.PRNGKey(0)
        key, subkey = random.split(key)

        x, y = random.normal(key, shape=(54, 1536)), random.binomial(subkey, 1, 0.5, shape=(54, ))

        assert scale_global > 0

        model_args = {
            'N': x.shape[0],
            'd': x.shape[1],
            'y': y,
            'x': x,
            'scale_icept': scale_icept,
            'scale_global': scale_global,
            'nu_global': nu_global,
            'nu_local': nu_local,
            'slab_scale': slab_scale,
            'slab_df': slab_df,
        }

        model_hyperparameters = {
            'scale_icept': scale_icept,
            'scale_global': scale_global,
            'nu_global': nu_global,
            'nu_local': nu_local,
            'slab_scale': slab_scale,
            'slab_df': slab_df,
        }

        key, subkey = random.split(key)

        run_mcmc(model, model_args, model_hyperparameters, subkey)

Could you include the code that uses potential_energy? Your code only shows the MCMC loop. I would find the sample that produces the NaN, check that it looks reasonable after applying unconstrain_fn, and then call potential_energy on that sample. It is easier to debug if you post the model here together with the NaN-producing sample. My strategy would be to enable the NaN-check flag (jax_debug_nans) or go line by line through the model to see which variable has a NaN log prob.

Thank you! I actually solved that problem; thanks for the advice.

I have another question.
I reparameterized my model as shown below, and now NUTS works very well.
However, now I have many more variables: I introduced w_tau, v_tau, w_lambda and v_lambda as intermediates to sample tau and lambda_. Is there a way to "hide" the intermediate variables without changing the meaning of the model? Or is it simply a consequence of reparameterization that the posterior space becomes higher-dimensional?

    def model(*, N, d, scale_icept, scale_global, nu_global, nu_local, slab_scale, slab_df, x=None, y=None):
        # Priors
        beta0_unscaled = numpyro.sample('beta0_unscaled', dist.Normal(0.0, 1.0))

        # beta0 = numpyro.deterministic('beta0', beta0_unscaled * scale_icept)
        beta0 = beta0_unscaled * scale_icept

        # tau = numpyro.sample('tau', dist.StudentT(nu_global, 0.0, scale_global))
        # Scale mixture: v * sqrt(w) with w ~ InverseGamma(nu/2, nu/2) is StudentT(nu) (Cauchy when nu = 1)
        w_tau = numpyro.sample('w_tau', dist.InverseGamma(nu_global / 2.0, nu_global / 2.0))
        v_tau = numpyro.sample('v_tau', dist.Normal(0.0, 1.0))
        tau = numpyro.deterministic('tau', scale_global * v_tau * jnp.sqrt(w_tau))

        # lambda_ = numpyro.sample('lambda', dist.StudentT(nu_local, 0.0, 1.0).expand([d]))
        w_lambda = numpyro.sample('w_lambda', dist.InverseGamma(nu_local / 2.0, nu_local / 2.0).expand([d]))
        v_lambda = numpyro.sample('v_lambda', dist.Normal(0.0, 1.0).expand([d]))
        lambda_ = numpyro.deterministic('lambda', v_lambda * jnp.sqrt(w_lambda))

        c_aux = numpyro.sample('caux', dist.InverseGamma(0.5 * slab_df, 0.5 * slab_df))

        z = numpyro.sample('z', dist.Normal(0.0, 1.0).expand([d]))

        # Data generation process
        # c = numpyro.deterministic('c', slab_scale * jnp.sqrt(c_aux))
        c = slab_scale * jnp.sqrt(c_aux)

        # lambda_tilde = numpyro.deterministic('lambda_tilde',
        #                                      jnp.sqrt(c ** 2 * lambda_ ** 2 / (c ** 2 + tau ** 2 * lambda_ ** 2)))
        lambda_tilde = jnp.sqrt(c ** 2 * lambda_ ** 2 / (c ** 2 + tau ** 2 * lambda_ ** 2))

        # beta = numpyro.deterministic('beta', z * lambda_tilde * tau)
        beta = z * lambda_tilde * tau

        f = beta0 + jnp.matmul(x, beta)

        # f = jnp.reshape(f, (-1, 1))

        with numpyro.plate('data', N):
            # Likelihood
            numpyro.sample('obs', dist.Bernoulli(logits=f), obs=y)

okay great!