Unexpected init_strategy behaviour for CircularReparam()

Hey hey!

In short:

The median initialisation strategy does not seem to work for variables that are reparameterised with CircularReparam.

Longer explanation:

I am having problems with one of my mixture models (MCMC). I have some (clock) time data that I would like to fit using VonMises distributions. As NumPyro suggested, I reparameterised the prior variables that I draw from a VonMises distribution with CircularReparam. These are the location parameters loc_1 and loc_2.

My mixture model has 2 clusters and I am fairly sure where the cluster centres are, so I can formulate quite a specific prior. I also chose the init_to_median initialisation strategy so that my chains start in the right place and each cluster converges to its intended position.

Since I am interested in the warmup trace, I call mcmc.warmup() with collect_warmup=True before running the sampler, so that the warmup samples are saved.

Looking at those traces, I realised that the first samples are nowhere near the medians of my priors. After some days of debugging, I found that removing the CircularReparam from the location parameters of the VonMises distributions restores the intended behaviour: the first samples lie close to the prior medians.
Note: I increased num_samples of the init strategy to make sure this is not just bad luck, and I tried multiple PRNGKeys.

Question

Is this intended behaviour? Am I doing something wrong? Is it a bug, and can we fix it? My problem is that the clusters converge to different means and I lose the cluster identities. Is there maybe another way to ensure that?

Minimal example

Data
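```python
import jax
import numpy as np
from jax.random import PRNGKey
from numpy import pi
from numpyro.distributions import VonMises

key1, key2 = jax.random.split(PRNGKey(42))

# Two clusters of angles: one around 0, one around pi/2
data1 = VonMises(0, 2).sample(key1, (1000,))
data2 = VonMises(pi / 2, 2).sample(key2, (1000,))

data = np.concatenate([data1, data2])
np.random.shuffle(data)  # in-place shuffle of the NumPy array
```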
Model
```python
import jax.numpy as jnp
import numpyro
from numpy import pi
from numpyro import sample
from numpyro.distributions import (
    Categorical,
    Dirichlet,
    InverseGamma,
    MixtureGeneral,
    VonMises,
)
from numpyro.infer.reparam import CircularReparam


@numpyro.handlers.reparam(
    config={"loc_1": CircularReparam(), "loc_2": CircularReparam()}
)
def model(data=None, num_samples=10000):
    if data is not None:
        num_samples = len(data)

    # Prior distributions
    weights = sample("weights", Dirichlet(jnp.ones(2)))
    loc_1 = sample("loc_1", VonMises(0, 10))  # reparameterised
    loc_2 = sample("loc_2", VonMises(pi / 2, 10))  # reparameterised

    conc_1 = sample("conc_1", InverseGamma(2, 5))
    conc_2 = sample("conc_2", InverseGamma(2, 5))

    # Mixture components
    dist_1 = VonMises(loc_1, conc_1)
    dist_2 = VonMises(loc_2, conc_2)
    dists = [dist_1, dist_2]

    # Two-component mixture with learned weights
    mix = MixtureGeneral(Categorical(probs=weights), dists)

    # Likelihood of the observed angles
    with numpyro.plate("plate", num_samples):
        sample("data", mix, obs=data)
```
Sampler and warmup call
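```python
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS, init_to_median

# Initialise each latent site at the median of 1000 prior draws
kernel = NUTS(model, init_strategy=init_to_median(num_samples=1000))
mcmc = MCMC(kernel, num_warmup=100, num_samples=10, num_chains=2)
# Run only the warmup phase and keep its samples
mcmc.warmup(PRNGKey(40), data=data, collect_warmup=True)

# After warmup(collect_warmup=True), get_samples() returns the warmup samples
loc_1 = mcmc.get_samples(group_by_chain=True)["loc_1"]
loc_2 = mcmc.get_samples(group_by_chain=True)["loc_2"]
```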
Results

First warmup sample with reparameterisation:
loc_1 (expected: 0):
chain 1: 1.95
chain 2: -1.69

loc_2 (expected: pi/2):
chain 1: -1.0
chain 2: 0.32

First warmup sample without reparameterisation:
loc_1 (expected: 0):
chain 1: 0.008
chain 2: 0.009

loc_2 (expected: pi/2):
chain 1: 1.574
chain 2: 1.576
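Note that all four first values with the reparameterisation lie in (-2, 2). One possible explanation, not confirmed here and based on a reading of the NumPyro source that may differ between versions, goes as follows. CircularReparam replaces `loc_1` by an unconstrained auxiliary site with an improper flat prior that cannot be sampled. init_to_median then catches the resulting NotImplementedError and falls back to init_to_uniform, which draws from Uniform(-2, 2) on the unconstrained scale. Finally, the value is wrapped onto [-pi, pi). The plain-Python sketch below mimics that chain. `ImproperPriorSketch`, `ConcentratedPriorSketch` and the `*_sketch` functions are hypothetical stand-ins, not NumPyro API, and the Gaussian prior only loosely imitates VonMises(0, 10).

```python
import math
import random
import statistics


class ImproperPriorSketch:
    """Hypothetical stand-in for a flat prior on the real line: it cannot be sampled."""

    def sample(self, rng, n):
        raise NotImplementedError("cannot sample from an improper prior")


class ConcentratedPriorSketch:
    """Hypothetical stand-in for a proper prior concentrated around `loc`."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def sample(self, rng, n):
        return [rng.gauss(self.loc, self.scale) for _ in range(n)]


def init_to_uniform_sketch(rng, radius=2.0):
    # Uniform draw in (-radius, radius) on the unconstrained scale.
    return rng.uniform(-radius, radius)


def init_to_median_sketch(prior, rng, num_samples=1000):
    # Median of prior draws; fall back to a uniform draw if the prior
    # cannot be sampled (the suspected path for the reparameterised site).
    try:
        return statistics.median(prior.sample(rng, num_samples))
    except NotImplementedError:
        return init_to_uniform_sketch(rng)


def wrap_angle(x):
    # Map a real number onto [-pi, pi); values already in (-pi, pi) are
    # unchanged up to floating-point rounding.
    return (x + math.pi) % (2 * math.pi) - math.pi


rng = random.Random(0)

# Without reparameterisation: the median of prior draws, close to loc = 0.
direct = init_to_median_sketch(ConcentratedPriorSketch(0.0, 0.3), rng)

# With reparameterisation: fallback draw in (-2, 2), then wrapped.
reparam = wrap_angle(init_to_median_sketch(ImproperPriorSketch(), rng))

print(f"direct: {direct:.3f}, reparameterised: {reparam:.3f}")
```

Under this hypothesis, the reparameterised starting values above (1.95, -1.69, -1.0, 0.32) would simply be uniform draws in (-2, 2). Because |x| < 2 < pi, the wrapping leaves them unchanged. That is consistent with the results, though the sketch does not prove the mechanism.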

It seems like a bug in CircularReparam. Could you file a GitHub issue? You can just draw a sample from the reparameterised model using the seed handler and check whether the sample is reasonable.

Thanks @fehiepsi for the quick feedback!
I think the seed handler is a faster way to reproduce the error, right?
Unfortunately, I am not familiar with effect handlers in general, but I ran:

```python
from numpyro import handlers

y = handlers.seed(model, rng_seed=1)()
```

I received a NotImplementedError.

Then I tried:

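```python
import numpyro
from numpyro import handlers, sample
from numpyro.distributions import VonMises
from numpyro.infer.reparam import CircularReparam


@numpyro.handlers.reparam(config={"loc_1": CircularReparam()})
def tst():
    loc_1 = sample("loc_1", VonMises(0, 100))
    return loc_1


y = handlers.seed(tst, rng_seed=1)()
```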

Again I received a NotImplementedError. The last calls in the stack trace were:

/numpyro/distributions/distribution.py:260; method: sample_with_intermediates
/numpyro/distributions/distribution.py:248; method: sample

I'll go ahead and create a GitHub issue with the example I provided above; I think I can even slim it down.
If I made a mistake with the seed handler, let me know and I'll try again.

EDIT:
Github Issue