Unexpected(?) Enumerate Error with AutoNormal

Below is a snippet of my code. I want to subsample one dimension (the data dimension, dim=-2) so I can run SVI with the AutoNormal guide, but different sections of my output use different likelihood families, so I've created this nested plate structure.

    with numpyro.plate(
        "data",
        size=self.n,
        subsample_size=100 if inference_method == "svi" else self.n,
        dim=-2,
    ) as ind:
        mu = f(X[ind], ...)  # tensor of size (K, 100, T)

        for family in data_set:
            # rows of mu (along K) that belong to this likelihood family
            k_indices = data_set[family]["indices"]
            linear_predictor = mu[k_indices]
            exposure = data_set[family]["exposure"][:, ind, :]
            obs = data_set[family]["Y"][:, ind, :]
            mask = data_set[family]["mask"][:, ind, :]

            with numpyro.plate(f"features_{family}", size=len(k_indices), dim=-3):
                if family == "gaussian":
                    rate = linear_predictor
                    dist = Normal(rate, expanded_sigmas[:, ind, :] / exposure)
                elif family == "poisson":
                    rate = jnp.exp(linear_predictor + exposure)
                    dist = Poisson(rate)
                elif family == "binomial":
                    rate = linear_predictor
                    dist = BinomialLogits(logits=rate, total_count=exposure.astype(int))
                elif family == "beta":
                    rate = jsci.special.expit(linear_predictor)
                    dist = BetaProportion(rate, exposure * sigma_beta)

                y = sample(f"likelihood_{family}", dist, obs, obs_mask=mask)

I'm using NumPyro version 0.18.0, but I'm running into this strange error:

RuntimeError: This algorithm might only work for discrete sites with enumerate support. But the MaskedDistribution distribution at site likelihood_poisson_unobserved does not have enumerate support.

I'm unsure where this is coming from, since I'm using Trace_ELBO as the loss (which doesn't do enumeration) together with the AutoNormal guide.

As usual, I've answered my own question.

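This section referencing the .mask function answered my question for me. It's because obs_mask creates a latent variable under the hood (the likelihood_poisson_unobserved site named in the error) to impute the masked-out entries, and for the Poisson family that latent variable is discrete, which AutoNormal can't handle.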