Hey hey!
In short:
The init_to_median initialisation strategy does not seem to work for variables that are reparameterised with CircularReparam.
Longer explanation:
I am having problems with a mixture model of mine that I fit with MCMC. I have some clock-time data that I would like to fit using VonMises distributions. As NumPyro advised, I reparameterised the location parameters that have a VonMises prior (loc_1 and loc_2 in the example) with CircularReparam.
My mixture model has 2 clusters and I am pretty sure where the cluster centres are, so I can formulate a fairly specific prior. In addition, I chose the init_to_median initialisation strategy so that my chains start in the right place and the clusters converge to the intended positions.
Since I am interested in the warmup sample trace, I call mcmc.warmup() with collect_warmup=True before running my model, so that the warmup samples are saved.
Looking at those traces, I realised that the first samples are nowhere near the medians of my priors. After some days of debugging, I found that when I remove the CircularReparam from the location parameters of the VonMises distributions, I get the intended behaviour: the first samples are close to the prior median.
Note: I increased the number of samples used by the init strategy (num_samples=1000) to make sure it’s not just bad luck, and I tried multiple PRNGKeys.
Question
Is this intended behaviour? Am I doing something wrong? Or is it a bug, and can we fix it? My problem is that the clusters converge to different means and I lose the cluster identities. Maybe there is another way to make sure the clusters keep their identities?
Minimal example
Imports
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from jax.random import PRNGKey
from numpy import pi
from numpyro import sample
from numpyro.distributions import Categorical, Dirichlet, InverseGamma, MixtureGeneral, VonMises
from numpyro.infer import MCMC, NUTS, init_to_median
from numpyro.infer.reparam import CircularReparam
Data
# Two clusters of 1000 points each, centred at 0 and pi / 2
key1, key2 = jax.random.split(PRNGKey(42))
data1 = VonMises(0, 2).sample(key1, (1000,))
data2 = VonMises(pi / 2, 2).sample(key2, (1000,))
data = np.concatenate([data1, data2])
np.random.shuffle(data)
Model
@numpyro.handlers.reparam(
    config={"loc_1": CircularReparam(), "loc_2": CircularReparam()}
)
def model(data=None, num_samples=10000):
    if data is not None:
        num_samples = len(data)
    # Prior distributions
    weights = sample('weights', Dirichlet(jnp.ones(2)))
    loc_1 = sample('loc_1', VonMises(0, 10))  # reparameterised
    loc_2 = sample('loc_2', VonMises(pi / 2, 10))  # reparameterised
    conc_1 = sample('conc_1', InverseGamma(2, 5))
    conc_2 = sample('conc_2', InverseGamma(2, 5))
    # Mixture components
    dist_1 = VonMises(loc_1, conc_1)
    dist_2 = VonMises(loc_2, conc_2)
    dists = [dist_1, dist_2]
    # Mixture distribution
    mix = MixtureGeneral(Categorical(probs=weights), dists)
    # Likelihood of the observed data
    with numpyro.plate('plate', num_samples):
        _ = sample("data", mix, obs=data)
Sampler and warmup call
kernel = NUTS(model, init_strategy=init_to_median(num_samples=1000))
mcmc = MCMC(kernel, num_warmup=100, num_samples=10, num_chains=2)
# Run only the warmup phase and keep its samples
mcmc.warmup(PRNGKey(40), data=data, collect_warmup=True)
# After warmup(), get_samples() returns the collected warmup samples, grouped by chain
loc_1 = mcmc.get_samples(group_by_chain=True)['loc_1']
loc_2 = mcmc.get_samples(group_by_chain=True)['loc_2']
Results
First warmup sample with reparameterisation:
loc_1 (expected: 0):
chain 1: 1.95
chain 2: -1.69
loc_2 (expected: pi/2):
chain 1: -1.0
chain 2: 0.32
First warmup sample without reparameterisation:
loc_1 (expected: 0):
chain 1: 0.008
chain 2: 0.009
loc_2 (expected: pi/2):
chain 1: 1.574
chain 2: 1.576
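To put a number on "nowhere near", here is a minimal sketch that computes the shortest angular distance between each first warmup sample and the prior location. It only uses the values listed above. The helper names circular_distance and distances_to_expected are made up for this illustration and are not part of NumPyro.

```python
import math


def circular_distance(a, b):
    """Shortest angular distance between two angles in radians, in [0, pi]."""
    diff = a - b
    # atan2(sin, cos) wraps the difference into (-pi, pi]
    return abs(math.atan2(math.sin(diff), math.cos(diff)))


def distances_to_expected(samples, expected):
    """Map each site name to the per-chain distances from its expected location."""
    return {
        site: [circular_distance(value, expected[site]) for value in values]
        for site, values in samples.items()
    }


# Prior locations of the two clusters
expected = {"loc_1": 0.0, "loc_2": math.pi / 2}

# First warmup sample of each chain, copied from the results above
first_samples = {
    "with reparam": {"loc_1": [1.95, -1.69], "loc_2": [-1.0, 0.32]},
    "without reparam": {"loc_1": [0.008, 0.009], "loc_2": [1.574, 1.576]},
}

for label, samples in first_samples.items():
    for site, dists in distances_to_expected(samples, expected).items():
        print(label, site, [round(d, 3) for d in dists])
```

With the reparameterisation every first sample is more than 1 rad away from its prior location; without it every first sample is within 0.01 rad.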