I intended to customize a Normal Mixture distribution and inference with MCMC.
Distribution here:
import jax.numpy as jnp
class mNormal(numpyro.distributions.distribution.Distribution):
support = numpyro.distributions.constraints.real
def __init__(self, pl, validate_args=None):
# self.pl = numpyro.distributions.util.promote_shapes(pl)
self.pl = pl
self.s = 1 / 40
self.mu = 1.
self.sigma = 40 / 160
super(mNormal, self).__init__(batch_shape=jnp.shape(pl), validate_args=validate_args)
def sample(self, key, sample_shape=()):
# assert numpyro.distributions.util.is_prng_key(key)
# keym, key0, key1 = jax.random.split(key, 3)
# shape = sample_shape + self.batch_shape + self.event_shape
# mix = jax.random.bernoulli(keym, self.pl, shape=shape)
# eps0 = jax.random.normal(key0, shape=shape) * self.s
# eps1 = self.mu + jax.random.normal(key1, shape=shape) * self.sigma
# return jnp.where(mix > 0.5, eps1, eps0)
pass
@numpyro.distributions.util.validate_sample
def log_prob(self, value):
prob0 = numpyro.distributions.Normal(0., self.s).log_prob(value)
prob1 = numpyro.distributions.Normal(self.mu, self.sigma).log_prob(value)
return jnp.log(jnp.clip((1 - self.pl) * jnp.exp(prob0) + self.pl * jnp.exp(prob1), jnp.finfo(jnp.float32).tiny, 1-jnp.finfo(jnp.float32).tiny))
Model here:
def model(y, mu):
t0 = numpyro.sample('t0', numpyro.distributions.Uniform(0., float(window)))
if Tau == 0:
light_curve = numpyro.distributions.Normal(t0, scale=Sigma)
pl = jnp.exp(light_curve.log_prob(tlist)) * mu
else:
pl = Co * (1. - jax.scipy.special.erf((Alpha * Sigma ** 2 - (tlist - t0)) / (math.sqrt(2.) * Sigma))) * jnp.exp(-Alpha * (tlist - t0)) * mu
A = numpyro.sample('A', mNormal(pl))
with numpyro.plate('observations', window):
obs = numpyro.sample('obs', numpyro.distributions.Normal(0., scale=std), obs=y-jnp.matmul(AV, A))
return obs
Inference here:
nuts_kernel = numpyro.infer.NUTS(model, adapt_step_size=True)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_samples=1000, num_warmup=500, jit_model_args=True)
mcmc.run(rng_key, y=wave, mu=jnp.sum(wave))
Note: pl
is a window
length jax.numpy.array
As we see, I’ve got a pass
in the distribution mNormal.sample
, which means the mNormal.sample
will not work.
But the inference process can work without error. Therefore, the inference result is wrong.
Anyone please figure out how to customize a distribution properly, or find the bug in my code.