NumPyro - Help Debugging Dirichlet Process Gaussian Mixture Model

I’m attempting to adapt the Pyro DPGMM tutorial (Dirichlet Process Mixture Models in Pyro — Pyro Tutorials 1.8.4 documentation) to NumPyro, but I’m getting very poor results on simple data.

How does one go about debugging why this is happening? I consistently find that the inferred cluster means are too close together, tightly packed around 0, even though the means are drawn from N(0, 6*Identity), which is puzzling. This is with 50k optimization steps (the logs below are from shorter 15k-step runs).

Snippets of my code are below:

    def model(obs):
        # Stick-breaking weights: T - 1 Beta draws define the mixture weights.
        with numpyro.plate('beta_plate', truncation_num_clusters - 1):
            beta = numpyro.sample(
                'beta',
                numpyro.distributions.Beta(1, concentration_param))

        # Cluster means drawn from N(0, 6 * I).
        with numpyro.plate('mean_plate', truncation_num_clusters):
            mean = numpyro.sample(
                'mean',
                numpyro.distributions.MultivariateNormal(
                    jnp.zeros(obs_dim),
                    6.0 * jnp.eye(obs_dim)))

        # Discrete cluster assignments, then the observations.
        with numpyro.plate('data', num_obs):
            z = numpyro.sample(
                'z',
                numpyro.distributions.Categorical(mix_weights(beta=beta)))
            numpyro.sample(
                'obs',
                numpyro.distributions.MultivariateNormal(
                    mean[z],
                    0.3 * jnp.eye(obs_dim)),
                obs=obs)
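
For reference, mix_weights is the stick-breaking transform from the Pyro tutorial; my JAX translation of it looks roughly like this:

    def mix_weights(beta):
        # Stick-breaking: w_k = beta_k * prod_{j<k} (1 - beta_j);
        # the final weight takes whatever stick length remains.
        beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
        return (jnp.pad(beta, (0, 1), constant_values=1.0)
                * jnp.pad(beta1m_cumprod, (1, 0), constant_values=1.0))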

    def guide(obs):
        # Variational parameters for the stick-breaking Betas.
        q_beta_params = numpyro.param(
            'q_beta_params',
            init_value=jax.random.uniform(
                key=jax.random.PRNGKey(0),
                minval=0,
                maxval=2,
                shape=(truncation_num_clusters - 1,)),
            constraint=numpyro.distributions.constraints.positive)

        with numpyro.plate('beta_plate', truncation_num_clusters - 1):
            # Beta(1, kappa), matching the Pyro tutorial's guide; note that
            # concentration1 is the first (alpha) parameter in NumPyro.
            q_beta = numpyro.sample(
                'beta',
                numpyro.distributions.Beta(
                    concentration1=jnp.ones(truncation_num_clusters - 1),
                    concentration0=q_beta_params))

        # Variational means, initialized from the prior over cluster means.
        q_means_params = numpyro.param(
            'q_means_params',
            init_value=jax.random.multivariate_normal(
                key=jax.random.PRNGKey(0),
                mean=jnp.zeros(obs_dim),
                cov=model_params['gaussian_mean_prior_cov_scaling'] * jnp.eye(obs_dim),
                shape=(truncation_num_clusters,)))

        with numpyro.plate('mean_plate', truncation_num_clusters):
            q_mean = numpyro.sample(
                'mean',
                numpyro.distributions.MultivariateNormal(
                    q_means_params,
                    model_params['gaussian_cov_scaling'] * jnp.eye(obs_dim)))

        # Per-datapoint assignment probabilities (one simplex per observation).
        q_z_assignment_params = numpyro.param(
            'q_z_assignment_params',
            init_value=jax.random.dirichlet(
                key=jax.random.PRNGKey(0),
                alpha=jnp.ones(truncation_num_clusters) / truncation_num_clusters,
                shape=(num_obs,)),
            constraint=numpyro.distributions.constraints.simplex)

        with numpyro.plate('data', num_obs):
            # Discrete latent site in the guide.
            q_z = numpyro.sample(
                'z',
                numpyro.distributions.Categorical(probs=q_z_assignment_params))

    optimizer = numpyro.optim.Adam(step_size=learning_rate)
    svi = numpyro.infer.SVI(model,
                            guide,
                            optimizer,
                            loss=numpyro.infer.Trace_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0),
                         num_steps=num_steps,
                         obs=observations,
                         progress_bar=True)
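
One way to see where the fit ends up is to inspect the variational parameters afterwards; svi.run returns an SVIRunResult whose params field holds the fitted numpyro.param values:

    # Inspect the fitted variational parameters from the run result.
    fitted = svi_result.params
    print(fitted['q_means_params'])   # these come out tightly packed around 0
    print(fitted['q_beta_params'])
    # Modal cluster assignment per observation.
    print(fitted['q_z_assignment_params'].argmax(axis=-1))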

Admittedly, the loss is decreasing:

    100%|██████████| 15000/15000 [00:40<00:00, 367.81it/s, init loss: 2987.4985, avg. loss [14251-15000]: 981.9764]
    Finished SVI (15k Steps) concentration_param=0.01
    100%|██████████| 15000/15000 [00:38<00:00, 390.90it/s, init loss: 1824.3541, avg. loss [14251-15000]: 1003.1483]
    Finished SVI (15k Steps) concentration_param=0.26
    100%|██████████| 15000/15000 [00:36<00:00, 407.82it/s, init loss: 4687.4419, avg. loss [14251-15000]: 1009.7147]
    Finished SVI (15k Steps) concentration_param=0.51

Hi @RylanSchaeffer, in NumPyro, performing SVI on models with discrete latent variables is not supported yet (see this and this). After the next release (in which GPU is used by default), we’ll try to address those issues to make SVI a bit more robust.
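
In the meantime, one workaround is to marginalize the assignments out analytically so that no discrete latent site remains, e.g. via MixtureSameFamily (a sketch reusing the names from your snippet; MixtureSameFamily is available in recent NumPyro releases):

    def marginalized_model(obs):
        with numpyro.plate('beta_plate', truncation_num_clusters - 1):
            beta = numpyro.sample(
                'beta', numpyro.distributions.Beta(1, concentration_param))

        with numpyro.plate('mean_plate', truncation_num_clusters):
            mean = numpyro.sample(
                'mean',
                numpyro.distributions.MultivariateNormal(
                    jnp.zeros(obs_dim), 6.0 * jnp.eye(obs_dim)))

        # p(obs) = sum_k w_k * N(obs | mean_k, 0.3 I); no 'z' site remains.
        mixture = numpyro.distributions.MixtureSameFamily(
            numpyro.distributions.Categorical(probs=mix_weights(beta=beta)),
            numpyro.distributions.MultivariateNormal(
                mean, 0.3 * jnp.eye(obs_dim)))
        with numpyro.plate('data', num_obs):
            numpyro.sample('obs', mixture, obs=obs)

The guide then only needs the beta and mean sites; per-point responsibilities can be recovered afterwards from the component log-densities.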

Hi @fehiepsi, thanks for clarifying! Two quick follow-up questions:

  1. Is there a way for me to discern whether the poor performance of SVI is due to my misuse of the library, as opposed to problems with the library itself? (One generic check I know of is to trace the model’s shapes; see the sketch after this list.)

  2. Do you have a best estimate for when the next release will be?
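
For question 1, one generic first check is to trace the model once and print the shape of every sample site, to confirm the plates line up as intended (a sketch; format_shapes is available in recent NumPyro versions):

    from numpyro import handlers
    from numpyro.util import format_shapes  # recent NumPyro versions

    # Run the model once with a fixed seed and record every site.
    exec_trace = handlers.trace(
        handlers.seed(model, jax.random.PRNGKey(0))).get_trace(observations)
    print(format_shapes(exec_trace))  # batch vs. event shape per site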

Good point! We should raise an error if there are discrete latent variables in the model. Let me address that.

> Do you have a best estimate for when the next release will be?

Currently, we are working on infer_discrete for Markov models. I’m not sure how long it will take… probably in a few weeks.