Suppose batch size is 8 and vector length is 50. Basically, this is what happens:
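import torch
import pyro
from pyro.distributions import Normal

a = torch.randn(8, 50)  # means, one 50-dim vector per batch item
b = torch.randn(8, 50)  # passed as the scale of the Normal
with pyro.plate("c"):
    d = pyro.sample('d', Normal(a, b).to_event(1))
e = Normal(a, b).to_event(1).log_prob(d)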
We get e = tensor([nan] * 8).
I can see why Pyro might not know how to calculate the log probability for a converted distribution, but this is the same pattern as the one used in the semi-supervised tutorial. So if I want to sample from MultivariateNormal(a[i], diag(b[i])) for each i = 1, ..., 8 in the mini-batch, what's a correct and efficient way of doing it?