Hi, it seems there is some trouble when using the torch.diag function inside pyro.plate.
Here is a minimal example:
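```python
import torch as t
import pyro
import pyro.distributions as dist

val = t.ones([2, 2])

def model():
    with pyro.plate('dim', 5):
        lv_z = pyro.sample("lv_z", dist.Categorical(t.tensor([0.5, 0.5])))
        print(lv_z.shape)               # one sample per plate element
        print(val[lv_z].shape)          # index rows of val with lv_z
        print(t.diag(val[lv_z]).shape)  # the shape in question

model()
```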
The output is:
```
torch.Size([5])
torch.Size([5, 2])
torch.Size([2])
```
I expected the last shape to be [5, 2, 2], but I get [2].
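For comparison, here is a plain PyTorch check (no Pyro involved) of how torch.diag treats a 2-D input, next to torch.diag_embed. This is a minimal sketch that assumes the goal is one diagonal 2×2 matrix per plate element; the tensor x simply stands in for val[lv_z].

```python
import torch

# Stand-in for val[lv_z]: a [5, 2] tensor (values are arbitrary).
x = torch.ones(5, 2)

# For a 2-D input, torch.diag *extracts* the main diagonal,
# so the result has length min(5, 2) = 2.
print(torch.diag(x).shape)        # torch.Size([2])

# torch.diag_embed treats the last dim as diagonal entries and
# builds one diagonal matrix per row, giving shape [5, 2, 2].
print(torch.diag_embed(x).shape)  # torch.Size([5, 2, 2])
```

So torch.diag is not batched over the plate dimension: it sees a single 5×2 matrix and returns its diagonal.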
Is there a bug in my code?
Any suggestions would be appreciated.