I want to define a custom normal-mixture distribution and run inference on it with MCMC.
Here is the distribution:
import jax.numpy as jnp
class mNormal(numpyro.distributions.distribution.Distribution):
    """Element-wise two-component normal mixture weighted by ``pl``.

    Each element follows ``(1 - pl) * Normal(0, s) + pl * Normal(mu, sigma)``
    with fixed component parameters ``s = 1/40``, ``mu = 1`` and
    ``sigma = 40/160``. The batch shape is ``jnp.shape(pl)``; events are scalar.

    ``pl`` is expected to lie in ``[0, 1]``. This is checked only when argument
    validation is enabled.

    Why the previous version produced wrong posteriors:

    - It clipped the mixture *density* to at most ``1 - float32.tiny``, which
      rounds to exactly 1.0, before taking the log.
    - A density is not a probability. ``Normal(0, 1/40)`` peaks at about 15.96,
      so the log density was flattened to 0, with zero gradient, around the modes.
    - NUTS drives its dynamics purely from ``log_prob`` and its gradient, so it
      explored a distorted target.

    The stubbed-out ``sample`` is not what broke inference, but forward
    simulation needs it (e.g. prior predictive draws via
    ``numpyro.infer.Predictive``), so it is implemented here.
    """

    arg_constraints = {"pl": numpyro.distributions.constraints.unit_interval}
    support = numpyro.distributions.constraints.real

    def __init__(self, pl, validate_args=None):
        """Create the mixture.

        Args:
            pl: weight of component 1, ``Normal(mu, sigma)``; its shape becomes
                the batch shape.
            validate_args: forwarded to the NumPyro ``Distribution`` base class.
        """
        self.pl = pl
        # Component 0: Normal(0, s). Component 1: Normal(mu, sigma).
        self.s = 1 / 40
        self.mu = 1.
        self.sigma = 40 / 160
        super().__init__(batch_shape=jnp.shape(pl), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        Each element independently picks component 1 with probability ``pl``,
        otherwise component 0.
        """
        from jax import random  # local import: the snippet only imports jax.numpy

        key_mix, key0, key1 = random.split(key, 3)
        shape = sample_shape + self.batch_shape + self.event_shape
        # True -> component 1, chosen with probability pl.
        from_component1 = random.bernoulli(key_mix, self.pl, shape=shape)
        draw0 = random.normal(key0, shape=shape) * self.s
        draw1 = self.mu + random.normal(key1, shape=shape) * self.sigma
        return jnp.where(from_component1, draw1, draw0)

    @staticmethod
    def _safe_log(weight):
        """Return ``log(weight)``, or ``-inf`` with a zero gradient where ``weight <= 0``."""
        positive = weight > 0
        # Double-where trick: feed log a dummy 1 where weight <= 0, so the
        # backward pass never differentiates log(0). pl underflows to exactly 0
        # far from the pulse, and a naive log would then yield NaN gradients.
        return jnp.where(positive, jnp.log(jnp.where(positive, weight, 1.)), -jnp.inf)

    @numpyro.distributions.util.validate_sample
    def log_prob(self, value):
        """Log density of the mixture at ``value``, computed entirely in log space.

        Evaluates ``log((1 - pl) * N(value; 0, s) + pl * N(value; mu, sigma))``
        with ``logaddexp``. This avoids underflow of either component and does
        not cap the density at 1.
        """
        lp0 = numpyro.distributions.Normal(0., self.s).log_prob(value)
        lp1 = numpyro.distributions.Normal(self.mu, self.sigma).log_prob(value)
        return jnp.logaddexp(
            self._safe_log(1. - self.pl) + lp0,
            self._safe_log(self.pl) + lp1,
        )
Here is the model:
def model(y, mu):
    """Latent arrival time ``t0``, per-bin amplitudes ``A`` and Gaussian residual noise.

    Relies on module-level names defined elsewhere in the original script:
    ``window``, ``Tau``, ``Sigma``, ``Alpha``, ``Co``, ``tlist``, ``AV`` and ``std``.

    Args:
        y: observed waveform; presumably of length ``window`` (confirm).
        mu: multiplier applied to the light-curve weights. The caller passes
            ``jnp.sum(wave)``.

    Returns:
        The value recorded at the observed ``obs`` site, i.e. ``y - AV @ A``.
    """
    t0 = numpyro.sample('t0', numpyro.distributions.Uniform(0., float(window)))
    # NOTE(review): pl is a scaled light-curve value (times mu), yet mNormal uses it
    # as a mixture weight; it is not obviously confined to [0, 1]. Confirm.
    if Tau != 0:
        dt = tlist - t0
        erf_arg = (Alpha * Sigma ** 2 - dt) / (math.sqrt(2.) * Sigma)
        pl = Co * (1. - jax.scipy.special.erf(erf_arg)) * jnp.exp(-Alpha * dt) * mu
    else:
        gaussian_curve = numpyro.distributions.Normal(t0, scale=Sigma)
        pl = jnp.exp(gaussian_curve.log_prob(tlist)) * mu
    A = numpyro.sample('A', mNormal(pl))
    # Scoring the residual under Normal(0, std) is the Gaussian likelihood of y around AV @ A.
    residual = y - jnp.matmul(AV, A)
    with numpyro.plate('observations', window):
        return numpyro.sample('obs', numpyro.distributions.Normal(0., scale=std), obs=residual)
And here is the inference:
# Run NUTS on `model`. `rng_key` and `wave` are assumed to be defined earlier in the script.
nuts_kernel = numpyro.infer.NUTS(model, adapt_step_size=True)
mcmc = numpyro.infer.MCMC(
    nuts_kernel,
    num_warmup=500,
    num_samples=1000,
    jit_model_args=True,
)
mcmc.run(rng_key, y=wave, mu=jnp.sum(wave))
Note: `pl` is a `jax.numpy` array of length `window`.
As you can see, `mNormal.sample` is just `pass`, so sampling from `mNormal` does not work.
Nevertheless, the inference runs without any error, but its result is wrong.
Can anyone explain how to implement a custom distribution properly, or find the bug in my code?