Is there a way to vectorize the plate in the following model?
def model(data):
assert(type(data['y']) == torch.Tensor)
assert(data['y'].shape == torch.Size([91, 17]))
assert(data['observed'] == torch.Tensor)
assert(data['observed'].shape == torch.Size([91, 17]))
d = data['y'].shape[-1]
valid_data = data['y'].clone()
valid_data = torch.nan_to_num(valid_data, nan=0.0)
B = pyro.param('B', torch.zeros([d, d]))
c_mu = pyro.param('c', torch.tensor(20.0).expand([d]))
c_sig = pyro.param('c_sig', torch.eye(d), constraint=constraints.positive_definite)
for i in pyro.plate('data', data['y'].shape[0]):
pyro.sample(f"obs_{i}", dist.MultivariateNormal(c_mu+torch.mm(B, data['observed'][i].unsqueeze(-1)).squeeze(-1), c_sig).mask(data['observed'][i]).to_event(1),
obs=valid_data[i])
I tried a few iterations of,
with pyro.plate('data', data['y'].shape[0]):
pyro.sample(f"obs", dist.MultivariateNormal(c_mu+torch.mm(data['observed'], B), c_sig).mask(data['observed']),
obs=valid_data)
but keep running into shape mismatch errors I haven’t been able to solve. Thanks for any help with the issue!