I am trying, without success, to fit a model with a discrete observation y = z > 0
where z = Normal(x, sigma)
is a latent variable and x the input.
By trying it on a simple example, I get the following error:
~/anaconda3/envs/pytorch_env/lib/python3.7/site-packages/numpyro/infer/mcmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
513 if not_jax_tracer(is_valid):
514 if device_get(~np.all(is_valid)):
--> 515 raise RuntimeError("Cannot find valid initial parameters. "
516 "Please check your model again.")
517
RuntimeError: Cannot find valid initial parameters. Please check your model again.
Here is the code I used:
import jax
import numpy
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
#generate data
n = 1000
x = numpy.random.normal(size=n)
C = 2
sigma = 3
z = x + C + sigma * numpy.random.normal(size=n)
y = z > 0
def model_example_delta(x, y=None):
C = numpyro.sample('C', dist.Normal(0, 5))
sigma = numpyro.sample('tau', dist.HalfNormal(5.))
with numpyro.plate('data', len(y)):
z = numpyro.sample('z', dist.Normal(C + x, sigma))
numpyro.sample('y', dist.Delta(z > 0), obs=y)
nuts_kernel = NUTS(model_example_delta)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, x=x, y=y)
mcmc.print_summary()
I also tried a code with continuous observations instead, and I get the same error.
#generate data
n = 1000
x = numpy.random.normal(size=n)
C = 2
sigma = 3
y = x + C + sigma * numpy.random.normal(size=n)
def model_example_delta(x, y=None):
C = numpyro.sample('C', dist.Normal(0, 5))
sigma = numpyro.sample('tau', dist.HalfNormal(5.))
with numpyro.plate('data', len(y)):
#numpyro.sample('y', dist.Normal(C + x, sigma), obs=y) #this would work, but not the 2 lines below
z = numpyro.sample('z', dist.Normal(C + x, sigma))
numpyro.sample('y', dist.Delta(z), obs=y)
nuts_kernel = NUTS(model_example_delta)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, x=x, y=y)
mcmc.print_summary()
I’m a bit confused, as this example uses Delta and runs normally on my computer.
I saw something potentially related in the Pyro documentation of the Delta distribution, but I’m not sure it’s linked to my problem.
Thanks for the help!