Hi all, I am new to Pyro and attempting a clustering model with conditional independence across features. Below is a non-functioning piece of code which outlines a simple version of my problem (a mixture of Gaussians with diagonal covariance matrices).

Does anyone have suggestions on how to structure the nesting of the measurement plate within the data plate?

Note that the model I wish to use will be more complex, but this behaviour of conditional independence across features will be present (whereas I could move to a MVN in this toy example circumventing my problem inelegantly).

The current setup gives me a shape mismatch error.

```
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import (
MCMC,
NUTS,
)
from pyro.ops.indexing import Vindex
from pyro.infer.mcmc.util import initialize_model
def model(data, K):
N, P = data.shape
# Define the plates
measurement_plate = pyro.plate("measurement_plate", P, dim=-2)
data_plate = pyro.plate("data_plate", N, dim=-1)
cluster_plate = pyro.plate("cluster_plate", K, dim=-1)
with cluster_plate as k:
concentration = pyro.sample('concentration', dist.Gamma(1.0, 0.25))
with measurement_plate as p:
mu = pyro.sample('mu', dist.Normal(0, 0.5))
sigma = pyro.sample('sigma', dist.Gamma(1, 1))
# Component weights
weights = pyro.sample("weights", dist.Dirichlet(concentration / K))
with data_plate as n:
# Cluster allocation variable
z = pyro.sample("z", dist.Categorical(weights))
mu_rel = Vindex(mu)[: , z]
sigma_rel = Vindex(sigma)[:, z]
with measurement_plate as p:
# I wish to describe samples based on the cluster specific parameter for that measurement
# This line is wrong, but hopefully conveys my intention
pyro.sample("obs", dist.Normal(mu_rel[n, p], sigma_rel[n, p]), obs=data[n, p])
if __name__ == "__main__":
N = 10
P = 3
K = 4
NUM_SAMPLES = int(5)
NUM_WARMUP=int(1)
NUM_CHAINS=int(1)
USE_JIT=False
data = torch.rand(N, P)
init_params, potential_fn, transforms, _ = initialize_model(
model,
model_args=(data, K),
num_chains=NUM_CHAINS,
jit_compile=USE_JIT,
skip_jit_warnings=True,
)
nuts_kernel = NUTS(potential_fn=potential_fn)
mcmc = MCMC(
nuts_kernel,
num_samples=NUM_SAMPLES,
warmup_steps=NUM_WARMUP,
num_chains=NUM_CHAINS,
initial_params=init_params,
transforms=transforms,
)
mcmc.run(data)
samples = mcmc.get_samples()
```