Getting 'nan' values with NeuTraReparam

I am using NeutraReparam on the following model:

import numpyro
import numpyro.distributions as dist

def model(county_idx=county, log_radon=log_radon, floor=floor_measure, J=J, N=N):
    sigma_y = numpyro.sample("sigma_y", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))
    sigma_beta = numpyro.sample("sigma_beta", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))
    sigma_alpha = numpyro.sample("sigma_alpha", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))

    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0.0, 10))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0.0, 10))

    with numpyro.plate("J", J):
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))
        beta = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta))
    # print(alpha.shape)

    mu = alpha[county_idx] + beta[county_idx] * floor
    with numpyro.plate("N", N):
        numpyro.sample("obs", dist.Normal(mu, sigma_y), obs=log_radon)    

I am using the following code to implement NeuTraReparam, based on this example:

from jax import random
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal, AutoIAFNormal
from numpyro.infer.reparam import NeuTraReparam

svi_model = model
guide = AutoIAFNormal(svi_model)
svi = SVI(svi_model, guide, numpyro.optim.Adam(9e-3), Trace_ELBO(100))
svi_result =, 5)
loss = svi_result.losses
params = svi_result.params
neutra = NeuTraReparam(guide, params)
neutra_model = neutra.reparam(svi_model)
nuts_kernel = NUTS(neutra_model)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=500,
    num_samples=1000,
)

But I am getting 'nan' values in the params, due to which I am unable to proceed with the reparameterisation. Please help me figure out the potential issues in my code.

Hard to say, but running SVI for only 5 steps is not sufficient. Also, your learning rate may be too high.


Lowering the learning rate seems to do the trick. Thanks!

However, with the following model I am still getting nan values from NeuTraReparam, even with learning rates as low as 1e-11:

def gp_log_scales(log_sd, log_lengthscale, log_eigenvalues):
    return log_sd + .25 * jnp.log(2*jnp.pi) + .5 * log_lengthscale + jnp.exp(2*log_lengthscale) * log_eigenvalues

def gp_scales(log_sd, log_lengthscale, log_eigenvalues):
    return jnp.exp(gp_log_scales(log_sd, log_lengthscale, log_eigenvalues))

def homoskedastic_gp_reparam(n_observations=n_observations, X=X, y=y,
                             n_functions=n_functions, boundary_factor=boundary_factor,
                             log_eigenvalues=log_eigenvalues, likelihood=True):
    mean_sd_raw = numpyro.sample("mean_sd_raw", dist.Normal(0,1))
    mean_sd = jnp.exp(mean_sd_raw)

    mean_lengthscale_raw = numpyro.sample("mean_lengthscale_raw", dist.Normal(0,1))
    mean_lengthscale = jnp.exp(mean_lengthscale_raw)

    mean_beta_scales = gp_scales(jnp.log(mean_sd), jnp.log(mean_lengthscale), log_eigenvalues)

    mean_beta = numpyro.sample("mean_beta", dist.Normal(0, mean_beta_scales))

    mean_intercept = numpyro.sample("mean_intercept", dist.Normal(0,1))

    lsd_intercept_raw = numpyro.sample("lsd_intercept_raw", dist.Normal(-2,1))

    y_mean = mean_intercept + X @ mean_beta
    y_sd = jnp.exp(lsd_intercept_raw)
    if likelihood:
        with numpyro.plate("n_observations", n_observations):
            numpyro.sample("y", dist.Normal(y_mean, y_sd), obs=y)
    # return homoskedastic_gp_reparam

Please suggest possible issues with this code, or modifications that would make it work.

Also, for the model where lowering the learning rate worked, sampling is very slow (only 275 warmup samples obtained in 17 minutes). What are some possible ways to speed it up?

Reparameterizing a model with a learned neural network has been explored in some research papers, but it's hardly proven as a robust methodology that always works, so I'd generally advise against putting too much time in that direction. It's probably better to try to get good performance from your model using some of the suggestions in this tutorial; e.g. you might reparameterize mean_beta or use 64-bit precision. Unfortunately, there are no magic tricks that always make Bayesian inference easy.


Thank you for sharing the tutorial. It seems very relevant to what I am currently working on.