I am using NeutraReparam on the following model:
def model(county_idx = county, log_radon = log_radon, floor = floor_measure, J=J, N=N):
sigma_y = numpyro.sample("sigma_y", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))
sigma_beta = numpyro.sample("sigma_beta", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))
sigma_alpha = numpyro.sample("sigma_alpha", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))
mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0.0, 10))
mu_beta = numpyro.sample("mu_beta", dist.Normal(0.0, 10))
with numpyro.plate("J", J):
alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))
beta = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta))
# print(alpha.shape)
mu = alpha[county_idx] + beta[county_idx] * floor
with numpyro.plate("N", N):
numpyro.sample("obs", dist.Normal(mu, sigma_y), obs=log_radon)
I am using the following code to implement NeutraReparam based on this example:
from numpyro.infer.autoguide import AutoBNAFNormal, AutoIAFNormal
from numpyro.infer.reparam import NeuTraReparam
svi_model = model
guide = AutoIAFNormal(svi_model)
svi = SVI(svi_model, guide, numpyro.optim.Adam(9e-3), Trace_ELBO(100))
svi_result = svi.run(rng_key, 5)
loss = svi_result.losses
params = svi_result.params
neutra = NeuTraReparam(guide, params)
neutra_model = neutra.reparam(svi_model)
nuts_kernel = NUTS(neutra_model)
mcmc = MCMC(
nuts_kernel,
num_warmup = 500,
num_samples = 1000
)
mcmc.run(rng_key)
mcmc.print_summary()
But, I am getting ‘nan’ values in the params due to which I am unable to go ahead with reparameterisation. Please help me figure out the potential issues in my code.