NumPyro: sampling with a custom distribution does not work

I want to define a custom Normal mixture distribution and run MCMC inference with it.
Here is the distribution:

import jax
import jax.numpy as jnp
import numpyro

class mNormal(numpyro.distributions.distribution.Distribution):
    support = numpyro.distributions.constraints.real

    def __init__(self, pl, validate_args=None):
        # self.pl = numpyro.distributions.util.promote_shapes(pl)
        self.pl = pl
        self.s = 1 / 40
        self.mu = 1.
        self.sigma = 40 / 160
        super(mNormal, self).__init__(batch_shape=jnp.shape(pl), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # assert numpyro.distributions.util.is_prng_key(key)
        # keym, key0, key1 = jax.random.split(key, 3)
        # shape = sample_shape + self.batch_shape + self.event_shape
        # mix = jax.random.bernoulli(keym, self.pl, shape=shape)
        # eps0 = jax.random.normal(key0, shape=shape) * self.s
        # eps1 = self.mu + jax.random.normal(key1, shape=shape) * self.sigma
        # return jnp.where(mix > 0.5, eps1, eps0)
        pass

    @numpyro.distributions.util.validate_sample
    def log_prob(self, value):
        prob0 = numpyro.distributions.Normal(0., self.s).log_prob(value)
        prob1 = numpyro.distributions.Normal(self.mu, self.sigma).log_prob(value)
        return jnp.log(jnp.clip((1 - self.pl) * jnp.exp(prob0) + self.pl * jnp.exp(prob1),
                                jnp.finfo(jnp.float32).tiny, 1 - jnp.finfo(jnp.float32).tiny))

Here is the model:

def model(y, mu):
    t0 = numpyro.sample('t0', numpyro.distributions.Uniform(0., float(window)))
    if Tau == 0:
        light_curve = numpyro.distributions.Normal(t0, scale=Sigma)
        pl = jnp.exp(light_curve.log_prob(tlist)) * mu
    else:
        pl = Co * (1. - jax.scipy.special.erf((Alpha * Sigma ** 2 - (tlist - t0)) / (math.sqrt(2.) * Sigma))) * jnp.exp(-Alpha * (tlist - t0)) * mu

    A = numpyro.sample('A', mNormal(pl))

    with numpyro.plate('observations', window):
        obs = numpyro.sample('obs', numpyro.distributions.Normal(0., scale=std), obs=y-jnp.matmul(AV, A))
    return obs

And here is the inference:

nuts_kernel = numpyro.infer.NUTS(model, adapt_step_size=True)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_samples=1000, num_warmup=500, jit_model_args=True)
mcmc.run(rng_key, y=wave, mu=jnp.sum(wave))

Note: pl is a jax.numpy array of length window.

As you can see, mNormal.sample only contains pass, so it cannot produce any samples.
Yet the inference runs without any error, and the inference result is wrong.

Could anyone explain how to customize a distribution properly, or find the bug in my code?

Can you please be clearer about what you expect?

HMC doesn't use the sample method (except possibly for initialization). The only thing it requires is a differentiable log_prob.

Thank you for your reply!

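My model includes a two-component Normal mixture whose PDF is

$$f(x \mid w, \mu_1, \sigma_1^2, \mu_2, \sigma_2^2) = (1-w)\,\mathcal{N}(x \mid \mu_1, \sigma_1^2) + w\,\mathcal{N}(x \mid \mu_2, \sigma_2^2)$$

where, in mNormal, $w$ = pl, $\mu_1 = 0$, $\sigma_1$ = s = 1/40, $\mu_2$ = mu = 1 and $\sigma_2$ = sigma = 40/160.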
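Working in log space avoids forming the individual component densities at all. Writing $\ell_k = \log \mathcal{N}(x \mid \mu_k, \sigma_k^2)$, a sketch of the standard identity:

```latex
\log f(x) = \operatorname{logaddexp}\bigl(\log(1-w) + \ell_1,\; \log w + \ell_2\bigr),
\qquad
\operatorname{logaddexp}(a, b) = \max(a, b) + \log\bigl(1 + e^{-|a-b|}\bigr),
\qquad
\ell_k = -\frac{(x-\mu_k)^2}{2\sigma_k^2} - \log\sigma_k - \tfrac{1}{2}\log(2\pi).
```

For example, with $\mu_1 = 0$ and $\sigma_1 = 1/40$ (the values used in mNormal), $\ell_1$ at $x = 1$ is $-800 + \log 40 - \tfrac12\log 2\pi \approx -797.2$. Then $e^{\ell_1}$ underflows to 0 even in float64, whose smallest positive value is about $e^{-744}$, whereas the logaddexp form never needs $e^{\ell_1}$ on its own.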

Are you sure you need a custom distribution? You can put the discrete latent variables in the model, and NumPyro will sum them out:

import numpyro
import numpyro.distributions as dist

def model(probs, locs):
    c = numpyro.sample("c", dist.Categorical(probs))
    numpyro.sample("x", dist.Normal(locs[c], 0.5))

Thank you for your advice!

I modified my model, which now looks like this:

def model(y, mu):
    t0 = numpyro.sample('t0', numpyro.distributions.Uniform(0., float(window)))
    if Tau == 0:
        light_curve = numpyro.distributions.Normal(t0, scale=Sigma)
        pl = jnp.exp(light_curve.log_prob(tlist)) * mu
    else:
        pl = Co * (1. - jax.scipy.special.erf((Alpha * Sigma ** 2 - (tlist - t0)) / (math.sqrt(2.) * Sigma))) * jnp.exp(-Alpha * (tlist - t0)) * mu
    pl = numpyro.primitives.deterministic('pl', jnp.stack([(1 - pl), pl], axis=1))
    with numpyro.plate('cp', window):
        c = numpyro.sample('c', numpyro.distributions.Categorical(pl))
    with numpyro.plate('Ap', window):
        A = numpyro.sample('A', numpyro.distributions.Normal(loc=mu01[c], scale=sigma01[c]))
    with numpyro.plate('observations', window):
        obs = numpyro.sample('obs', numpyro.distributions.Normal(jnp.matmul(AV, A), scale=std), obs=y)
    return obs

But during sampling initialization, this error is raised:

RuntimeError: Cannot find valid initial parameters. Please check your model again.

I can specify a reasonable initial value, but the following code seems to have no effect on the error:

initp = init(z={'t0':t0, 'A':A}, potential_energy=None, z_grad=None)
mcmc.run(rng_key, y=wave, mu=mu, init_params=initp)

Additionally, A, pl and obs each have more than 1000 elements, which is large.

After some checking, I suspect the error is caused by pl being very small, so that the gradient vanishes during inference.

I tried jax.config.update('jax_enable_x64', True) but the error remains.
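Whether or not this is what triggers the error here, the exp-then-log pattern in the original log_prob has exactly this failure mode. A sketch, assuming the mNormal parameters (s = 1/40, mu = 1, sigma = 40/160) and an arbitrary test point x = 10 with w = 0.5: both component densities underflow to 0 in float32, so the unclipped log is -inf and the clipped one is a constant with zero gradient, while the logaddexp form stays finite with a usable gradient. Enabling float64 moves the underflow point further out but does not remove it.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

S, MU, SIGMA = 1 / 40, 1.0, 40 / 160  # component parameters from mNormal
TINY = jnp.finfo(jnp.float32).tiny


def mixture_density(x, w):
    # Exponentiate each log density, then mix
    return (1 - w) * jnp.exp(norm.logpdf(x, 0.0, S)) + w * jnp.exp(norm.logpdf(x, MU, SIGMA))


def naive_logp(x, w):
    return jnp.log(mixture_density(x, w))


def clipped_logp(x, w):
    # Same clip as the original log_prob
    return jnp.log(jnp.clip(mixture_density(x, w), TINY, 1 - TINY))


def stable_logp(x, w):
    # Mix in log space; no density is exponentiated on its own
    return jnp.logaddexp(jnp.log1p(-w) + norm.logpdf(x, 0.0, S),
                         jnp.log(w) + norm.logpdf(x, MU, SIGMA))


x, w = jnp.float32(10.0), jnp.float32(0.5)
print(naive_logp(x, w))               # -inf: both densities underflow in float32
print(jax.grad(clipped_logp)(x, w))   # zero: the clip flattens the gradient
print(stable_logp(x, w))              # finite
print(jax.grad(stable_logp)(x, w))    # finite, usable gradient w.r.t. x
```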

I don't understand your model, but presumably you want a single plate of size window, not three distinct plates.