NeuTraReparam throws an error on a model built with numpyro.factor

Upon running the following code:

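import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoIAFNormal
from numpyro.infer.reparam import NeuTraReparam

# Seed is an assumption; the original post does not show how rng_key was created
rng_key = jax.random.PRNGKey(0)

d = 500  # Length of theta
c = jax.random.uniform(rng_key, shape=(d,))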

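def ll_pdf(t, J, tau, mu, theta):
    # Log-density of Normal(mu*(1-t), exp(tau*(1-t))) minus log-density of Normal(mu, exp(tau)), both at theta
    numerator = dist.Normal(jnp.full(J, mu * (1 - t)), jnp.exp(tau * (1 - t)))
    denominator = dist.Normal(jnp.full(J, mu), jnp.exp(tau))
    return numerator.log_prob(theta) - denominator.log_prob(theta)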

def distribution(J=d, c=c):
    mu = numpyro.sample('mu', dist.Normal(0, 1))
    tau = numpyro.sample('tau', dist.Normal(0, 1))
    theta = numpyro.sample('theta', dist.Normal(jnp.full(J, mu), jnp.exp(tau)))
    numpyro.factor('theta_ll', ll_pdf(c, J, tau, mu, theta))

guide = AutoIAFNormal(distribution)
svi = SVI(distribution, guide, optim.Adam(0.01), Trace_ELBO(num_particles=100))

svi_result = svi.run(rng_key, num_steps=1000)

svi_result.params

neutra = NeuTraReparam(guide, svi_result.params)
neutra_model = neutra.reparam(distribution)

nuts_kernel = NUTS(neutra_model)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=500,
    num_samples=1000
)
mcmc.run(rng_key)
mcmc.print_summary(exclude_deterministic=False)

I get NaN values in svi_result.params, and mcmc.run raises the following error:

ValueError: Unit distribution got invalid log_factor parameter.

Please help me debug the error.
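
Since the NaNs already show up in svi_result.params, it may help to confirm which fitted parameters are affected before passing them to NeuTraReparam. The helper below is only a sketch: non_finite_entries is a hypothetical name, it assumes the parameters form a flat dict (as svi_result.params does for this guide, to the best of my knowledge), and the example dict is made-up stand-in data rather than real SVI output. It is written with the standard library only so it runs without JAX, but it should also accept JAX or NumPy arrays, since those expose ravel() and tolist().

```python
import math


def non_finite_entries(params):
    """Return {name: number of NaN/inf entries} for a flat dict of parameters."""
    report = {}
    for name, value in params.items():
        if hasattr(value, "ravel"):
            # JAX and NumPy arrays: flatten, then convert to Python floats
            flat = value.ravel().tolist()
        elif isinstance(value, (list, tuple)):
            flat = list(value)
        else:
            flat = [value]
        bad = sum(1 for x in flat if not math.isfinite(x))
        if bad:
            report[name] = bad
    return report


# Hypothetical stand-in for svi_result.params (names and values are invented)
example_params = {"w": [0.1, float("nan"), 0.3], "b": 0.0}
print(non_finite_entries(example_params))
```

An empty result would suggest the parameters are finite and the problem lies later in the pipeline; a non-empty one points at the SVI fit itself.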

Are you sure ll_pdf is bounded from above?

For example, numpyro.factor("my_factor", -z ** 2) is bounded from above, but
numpyro.factor("my_factor", z ** 2) is not.
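
One way to check this for the ll_pdf in the question is to look at how the factor behaves as theta grows. The sketch below is a scalar, standard-library analogue of ll_pdf (the names normal_log_prob, ll_pdf_scalar and theta_sq_coeff are hypothetical, and t = 0.5 is an arbitrary stand-in for one entry of c). It looks only at the factor term, not at the full joint density of the model.

```python
import math


def normal_log_prob(x, loc, scale):
    # Log-density of Normal(loc, scale) evaluated at x
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def ll_pdf_scalar(t, tau, mu, theta):
    # Scalar analogue of ll_pdf for a single component of theta
    return (normal_log_prob(theta, mu * (1 - t), math.exp(tau * (1 - t)))
            - normal_log_prob(theta, mu, math.exp(tau)))


def theta_sq_coeff(t, tau):
    # Coefficient of theta**2 in ll_pdf_scalar:
    #   -1 / (2 * exp(2 * tau * (1 - t))) + 1 / (2 * exp(2 * tau))
    return 0.5 * (math.exp(-2 * tau) - math.exp(-2 * tau * (1 - t)))


t = 0.5  # hypothetical entry of c, which lies in (0, 1)
for tau in (-1.0, 1.0):
    coeff = theta_sq_coeff(t, tau)
    values = [ll_pdf_scalar(t, tau, 0.0, theta) for theta in (1.0, 10.0, 100.0)]
    print(f"tau={tau:+.1f}  theta^2 coeff={coeff:+.4f}  factor at theta=1,10,100: {values}")
```

If the algebra above is right, the theta**2 coefficient is positive whenever tau < 0 and 0 < t < 1, so in that regime the factor grows without bound in theta. Because tau has a Normal(0, 1) prior, negative values of tau will be visited regularly.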

Is there any possible solution for this if I wish to use the same model?

If ll_pdf really isn’t bounded from above, I don’t think any inference algorithm will give you sensible results, since the distribution isn’t normalizable.