Conditional independence across features within a mixture model

Hi all, I am new to Pyro and attempting a clustering model with conditional independence across features. Below is a non-functioning piece of code which outlines a simple version of my problem (a mixture of Gaussians with diagonal covariance matrices).

Does anyone have suggestions on how to structure the nesting of the measurement plate within the data plate?

Note that the model I wish to use will be more complex, but this behaviour of conditional independence across features will be present (whereas I could move to a MVN in this toy example circumventing my problem inelegantly).

The current setup gives me a shape mismatch error.


import numpy as np
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import (
    MCMC, 
    NUTS, 
)
from pyro.ops.indexing import Vindex
from pyro.infer.mcmc.util import initialize_model

def model(data, K):

    N, P = data.shape

    # Define the plates
    measurement_plate = pyro.plate("measurement_plate", P, dim=-2)
    data_plate = pyro.plate("data_plate", N, dim=-1)
    cluster_plate = pyro.plate("cluster_plate", K, dim=-1)
   
    with cluster_plate as k:
        concentration = pyro.sample('concentration', dist.Gamma(1.0, 0.25))
        with measurement_plate as p:
            mu = pyro.sample('mu', dist.Normal(0, 0.5))
            sigma = pyro.sample('sigma', dist.Gamma(1, 1))

    # Component weights
    weights = pyro.sample("weights", dist.Dirichlet(concentration / K))

    with data_plate as n:
        # Cluster allocation variable
        z = pyro.sample("z", dist.Categorical(weights))
        mu_rel = Vindex(mu)[: , z]
        sigma_rel = Vindex(sigma)[:, z] 
        with measurement_plate as p:
            # I wish to describe samples based on the cluster specific parameter for that measurement
            # This line is wrong, but hopefully conveys my intention
            pyro.sample("obs", dist.Normal(mu_rel[n, p], sigma_rel[n, p]), obs=data[n, p])

if __name__ == "__main__":
    N = 10
    P = 3
    K = 4

    NUM_SAMPLES = int(5)
    NUM_WARMUP=int(1)
    NUM_CHAINS=int(1)
    USE_JIT=False

    data = torch.rand(N, P)

    init_params, potential_fn, transforms, _ = initialize_model(
        model,
        model_args=(data, K),
        num_chains=NUM_CHAINS,
        jit_compile=USE_JIT,
        skip_jit_warnings=True,
    )
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=NUM_SAMPLES,
        warmup_steps=NUM_WARMUP,
        num_chains=NUM_CHAINS,
        initial_params=init_params,
        transforms=transforms,
    )

    mcmc.run(data)
    samples = mcmc.get_samples()

In numpyro one can do a model like this via:


@config_enumerate
def model(data, K):
    N, P = data.shape

    # Global variables.
    weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
    with numpyro.plate("components", K, dim=-2):
        with numpyro.plate("measurements", P, dim=-1):
            scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))
            locs = numpyro.sample("locs", dist.Normal(0.0, 10.0))

    with numpyro.plate("data", N, dim=-2) as n:
        # Local variables.
        assignment = numpyro.sample("assignment", dist.Categorical(weights))
        with numpyro.plate("measurements", P, dim=-1) as p:
            numpyro.sample("obs", dist.Normal(Vindex(locs)[..., assignment, p], Vindex(scale)[..., assignment, p]), obs=data)