"Cannot find valid initial parameters" with Delta distribution

I am trying, without success, to fit a model with a discrete observation y = (z > 0), where z ~ Normal(x, sigma) is a latent variable and x is the input.

When I try it on a simple example, I get the following error:

~/anaconda3/envs/pytorch_env/lib/python3.7/site-packages/numpyro/infer/mcmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    513             if not_jax_tracer(is_valid):
    514                 if device_get(~np.all(is_valid)):
--> 515                     raise RuntimeError("Cannot find valid initial parameters. "
    516                                        "Please check your model again.")
    517 

RuntimeError: Cannot find valid initial parameters. Please check your model again.

Here is the code I used:

import jax
import numpy
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

#generate data
n = 1000
x = numpy.random.normal(size=n)
C = 2
sigma = 3
z = x + C + sigma * numpy.random.normal(size=n)
y = z > 0

def model_example_delta(x, y=None):
    C = numpyro.sample('C', dist.Normal(0, 5))
    sigma = numpyro.sample('tau', dist.HalfNormal(5.))
    with numpyro.plate('data', len(x)):  # len(x), since y may be None
        z = numpyro.sample('z', dist.Normal(C + x, sigma))
        numpyro.sample('y', dist.Delta(z > 0), obs=y)

nuts_kernel = NUTS(model_example_delta)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, x=x, y=y)
mcmc.print_summary()
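A minimal NumPy sketch of what the sampler sees at initialization for the model above, assuming each latent z_i starts from a value drawn uniformly in (-2, 2), independently of y_i. This is in the spirit of NumPyro's default uniform initialization, but the exact initializer here is an assumption, and `delta_log_prob` is a hypothetical helper. `Delta(z > 0)` has log-density 0 where the observed value equals `z > 0` and -inf elsewhere, so a single mismatched sign among the 1000 points makes the whole joint log-density -inf. Since `z > 0` is a step function, its gradient with respect to z is zero wherever it is defined, so NUTS gets no signal that could move z toward agreement.

```python
import numpy as np

def delta_log_prob(v, value):
    # Delta(v) puts all its mass on v: log-density 0 where value == v, -inf elsewhere
    return np.where(value == v, 0.0, -np.inf)

# Synthetic data generated as in the question
rng = np.random.default_rng(0)
n = 1000
x = rng.normal(size=n)
y = (x + 2 + 3 * rng.normal(size=n)) > 0

# Assumed initializer: each latent z_i drawn uniformly from (-2, 2),
# independently of the observed y_i
z_init = rng.uniform(-2.0, 2.0, size=n)

# Per-point log-likelihood of the observation site y ~ Delta(z > 0)
per_point = delta_log_prob(z_init > 0, y)
mismatches = int(np.isneginf(per_point).sum())
log_lik = per_point.sum()
print(mismatches, log_lik)
```

Roughly half of the signs disagree, so the summed log-likelihood is -inf, which matches the "Cannot find valid initial parameters" error.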

I also tried a version with continuous observations instead, and I get the same error:

#generate data
n = 1000
x = numpy.random.normal(size=n)
C = 2
sigma = 3
y = x + C + sigma * numpy.random.normal(size=n)

def model_example_delta(x, y=None):
    C = numpyro.sample('C', dist.Normal(0, 5))
    sigma = numpyro.sample('tau', dist.HalfNormal(5.))
    with numpyro.plate('data', len(x)):  # len(x), since y may be None
        # numpyro.sample('y', dist.Normal(C + x, sigma), obs=y)  # this line alone works, but the two lines below do not
        z = numpyro.sample('z', dist.Normal(C + x, sigma))
        numpyro.sample('y', dist.Delta(z), obs=y)

nuts_kernel = NUTS(model_example_delta)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, x=x, y=y)

mcmc.print_summary()

I’m a bit confused, because this example uses Delta and runs fine on my computer.

I saw something potentially related in the Pyro documentation for the Delta distribution, but I’m not sure it is relevant to my problem.

Thanks for the help!

@vincentbt When using Delta, you need to make sure that z = y exactly: the log density of Delta(v) at any value other than v is -inf. There are two use cases for the Delta distribution:

  • declare a deterministic site, sample('v', Delta(v), obs=v), which is equivalent to deterministic('v', v)
  • add a custom log density, sample('v', Delta(log_density=v), obs=0.), which is equivalent to factor('v', v)

So, if I understand correctly, the example model
z ~ Normal(x, sigma)
y = (z > 0)
(where x is the input, z the latent variable, and y the observation)
simply cannot be fitted with NumPyro?

I think that model is equivalent to integrating out z, since P(y = 1) = P(z > 0):
y ~ Bernoulli(1 - Normal(x, sigma).cdf(0))

In general, if y = f(z) where f is an arbitrary deterministic, non-bijective transform, computing p(y) from p(z) seems tricky. I don’t have a general answer for it.
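For the narrower case in this thread, where f only sorts z into intervals (the sign test y = (z > 0) is the two-interval case), p(y = k) is just the probability mass of the preimage interval, which for a Gaussian z is a difference of CDFs. A sketch with hypothetical cut points and illustrative values mu = 2, sigma = 3, checked against Monte Carlo; `normal_cdf` and `interval_probs` are hypothetical helpers:

```python
import math
import numpy as np

def normal_cdf(t, mu, sigma):
    # Phi((t - mu) / sigma) written with math.erf; math.erf(+-inf) = +-1
    return 0.5 * (1.0 + math.erf((t - mu) / (sigma * math.sqrt(2.0))))

def interval_probs(cutpoints, mu, sigma):
    # y = k  <=>  cutpoints[k-1] < z <= cutpoints[k], with -inf/+inf at the ends
    edges = [-math.inf] + list(cutpoints) + [math.inf]
    return [normal_cdf(hi, mu, sigma) - normal_cdf(lo, mu, sigma)
            for lo, hi in zip(edges[:-1], edges[1:])]

mu, sigma = 2.0, 3.0
cutpoints = [-1.0, 0.0, 1.5]  # hypothetical bin edges defining a many-to-one f
exact = interval_probs(cutpoints, mu, sigma)

# Monte Carlo check: simulate z, apply f (bin index), count outcomes
rng = np.random.default_rng(0)
z = rng.normal(mu, sigma, size=200_000)
y = np.searchsorted(cutpoints, z)  # side='left': cutpoints[k-1] < z <= cutpoints[k]
empirical = np.bincount(y, minlength=len(cutpoints) + 1) / z.size
print(np.round(exact, 3), np.round(empirical, 3))
```

With the single cut point 0, the second entry of `interval_probs([0.0], mu, sigma)` is the Bernoulli probability above. This does not cover f with continuous outputs.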