Vectorize Plate with Heterogeneous Known Parameters

Is there a way to vectorize the plate in the following model?

```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model(data):
    assert isinstance(data['y'], torch.Tensor)
    assert data['y'].shape == torch.Size([91, 17])
    assert isinstance(data['observed'], torch.Tensor)
    assert data['observed'].shape == torch.Size([91, 17])

    d = data['y'].shape[-1]
    # replace missing (NaN) entries so they can be passed as obs
    valid_data = torch.nan_to_num(data['y'].clone(), nan=0.0)

    B = pyro.param('B', torch.zeros([d, d]))
    c_mu = pyro.param('c', torch.tensor(20.0).expand([d]))
    c_sig = pyro.param('c_sig', torch.eye(d), constraint=constraints.positive_definite)

    # sequential plate: one 17-dimensional observation per row
    for i in pyro.plate('data', data['y'].shape[0]):
        loc = c_mu + torch.mm(B, data['observed'][i].unsqueeze(-1)).squeeze(-1)
        pyro.sample(f"obs_{i}",
                    dist.MultivariateNormal(loc, c_sig).mask(data['observed'][i]).to_event(1),
                    obs=valid_data[i])
```

I tried a few variations of:

```python
with pyro.plate('data', data['y'].shape[0]):
    pyro.sample("obs", dist.MultivariateNormal(c_mu+torch.mm(data['observed'], B), c_sig).mask(data['observed']),
                obs=valid_data)
```

but keep running into shape mismatch errors I haven’t been able to solve. Thanks for any help with the issue!

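I don't believe you can mask out individual components of a distribution with a non-trivial event shape (here `(17,)`). In Pyro, `.mask()` is broadcast against the distribution's batch shape, and `MultivariateNormal.log_prob` already combines all 17 components into a single log-density per row, so there is no per-component term left to mask.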
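To see this concretely, here is a small check in plain PyTorch (`pyro.distributions.MultivariateNormal` is built on `torch.distributions.MultivariateNormal`), using the shapes from the question. The zero mean, identity covariance and random data are placeholders, not the model's actual parameters.

```python
import torch
from torch.distributions import MultivariateNormal

N, d = 91, 17  # shapes from the question
loc = torch.zeros(N, d)   # placeholder mean, one per row
cov = torch.eye(d)        # placeholder shared covariance

mvn = MultivariateNormal(loc, covariance_matrix=cov)
y = torch.randn(N, d)

print(mvn.batch_shape)  # torch.Size([91]): one distribution per row
print(mvn.event_shape)  # torch.Size([17]): the 17 components form one event
lp = mvn.log_prob(y)
print(lp.shape)         # torch.Size([91]): components are already combined per row
```

A mask on this distribution can therefore only switch whole rows on or off, not individual entries within a row.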

Instead, I think you need to restrict each multivariate normal to its observed block (i.e. use the marginal distribution of the observed components). However, the observed pattern differs from row to row, so in general that won't be vectorizable. So instead you can fill in the non-observed blocks with something trivial: an identity block in the covariance, zero cross-covariance with the observed components, zero mean and a zero observation. Since this filler doesn't depend on your parameters, it only adds a constant to the log-likelihood and won't affect inference.
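A minimal sketch of that padding, again in plain PyTorch. `pad_unobserved` is a hypothetical helper (not a Pyro API), and it assumes `observed` is a boolean mask and that all rows share one covariance matrix. The check at the end compares the padded log-density with the per-row marginal over the observed components, shifted by the constant that each padded standard-normal term evaluated at 0 contributes, namely `-0.5 * log(2 * pi)`.

```python
import math
import torch
from torch.distributions import MultivariateNormal

def pad_unobserved(loc, cov, y, observed):
    """Hypothetical helper. loc: (N, d), cov: (d, d), y: (N, d) possibly NaN
    where unobserved, observed: (N, d) bool. Returns padded loc, cov, y."""
    m = observed.to(loc.dtype)                    # 1.0 where observed
    both = m.unsqueeze(-1) * m.unsqueeze(-2)      # (N, d, d): 1.0 where row and column observed
    eye = torch.eye(cov.shape[-1], dtype=cov.dtype)
    # keep the observed block, zero the cross terms, unit diagonal for unobserved
    cov_pad = cov * both + eye * (1.0 - m).unsqueeze(-1)
    loc_pad = loc * m                                       # unobserved means -> 0
    y_pad = torch.where(observed, y, torch.zeros_like(y))   # unobserved values (incl. NaN) -> 0
    return loc_pad, cov_pad, y_pad

torch.manual_seed(0)
N, d = 4, 5
A = torch.randn(d, d)
cov = A @ A.T + d * torch.eye(d)          # some positive-definite covariance
loc = torch.randn(N, d)
observed = torch.rand(N, d) > 0.3
observed[:, 0] = True                     # the reference loop below needs >= 1 observed per row
y = torch.randn(N, d).masked_fill(~observed, float("nan"))  # missing data as NaN

loc_p, cov_p, y_p = pad_unobserved(loc, cov, y, observed)
lp_padded = MultivariateNormal(loc_p, covariance_matrix=cov_p).log_prob(y_p)  # shape (N,)

# reference: marginal MVN over the observed components, row by row
const = 0.5 * math.log(2 * math.pi)
ref = []
for n in range(N):
    idx = observed[n].nonzero(as_tuple=True)[0]
    marg = MultivariateNormal(loc[n, idx], covariance_matrix=cov[idx][:, idx])
    k = d - idx.numel()                   # number of padded components in this row
    ref.append(marg.log_prob(y[n, idx]) - k * const)
ref = torch.stack(ref)

print(torch.allclose(lp_padded, ref, atol=1e-4))
```

In the model, the padded tensors could then be used in a single vectorized statement such as `with pyro.plate('data', N): pyro.sample('obs', dist.MultivariateNormal(loc_p, covariance_matrix=cov_p), obs=y_p)`, with no `.mask()` at all; that Pyro part is an untested assumption. Gradients with respect to `loc` and `cov` still flow through the observed blocks, while the padded entries contribute only constants.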