Upon running the following code:
d = 500 # Specify the desired length
c = jax.random.uniform(rng_key, shape=(d,))
def ll_pdf(t, J, tau, mu, theta):
return dist.Normal(jnp.full(J,mu*(1-t)), jnp.exp(tau*(1-t))).log_prob(theta) - dist.Normal(jnp.full(J,mu), jnp.exp(tau)).log_prob(theta)
def distribution(J=d, c = c):
mu = numpyro.sample('mu', dist.Normal(0, 1))
tau = numpyro.sample('tau', dist.Normal(0,1))
theta = numpyro.sample('theta', dist.Normal(jnp.full(J,mu), jnp.exp(tau)))
numpyro.factor('theta_ll', ll_pdf(c, J, tau, mu, theta))
guide = AutoIAFNormal(distribution)
svi = SVI(distribution, guide, optim.Adam(0.01), Trace_ELBO(100))
svi_result = svi.run(rng_key, num_steps=1000)
svi_result.params
neutra = NeuTraReparam(guide, svi_result.params)
neutra_model = neutra.reparam(distribution)
nuts_kernel = NUTS(neutra_model)
mcmc = MCMC(
nuts_kernel,
num_warmup=500,
num_samples=1000
)
mcmc.run(rng_key)
mcmc.print_summary(exclude_deterministic=False)
I get nan values in svi.params and the following error upon running mcmc.run:
ValueError: Unit distribution got invalid log_factor parameter.
Please help me debug the error.