Hi, it seems there is some trouble when using the torch.diag function inside pyro.plate.
For example:
import torch as t
import pyro
import pyro.distributions as dist

val = t.ones([2, 2])

def model():
    with pyro.plate('dim', 5):
        lv_z = pyro.sample("lv_z", dist.Categorical(t.tensor([0.5, 0.5])))
        print(lv_z.shape)
        print(val[lv_z].shape)
        print(t.diag(val[lv_z]).shape)

model()
The output is:
torch.Size([5])
torch.Size([5, 2])
torch.Size([2])
I expected the last shape to be [5, 2, 2], but I get [2].
Is there a bug in my code?
Any suggestions would be appreciated.
Thanks!
fritzo
This is an issue with the torch.diag() interface, which does not support broadcasting over batch dimensions. I recommend avoiding torch.diag().
- When x is a vector, torch.diag(x) is a matrix whose diagonal is x.
- When x is a matrix, torch.diag(x) is a vector which is the diagonal of x.
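Thus, for a vector x, torch.diag(torch.diag(x)) == x. In your example, val[lv_z] has shape [5, 2], so torch.diag treats it as a single matrix and returns its main diagonal, which has length min(5, 2) = 2.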
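A small sketch in plain PyTorch (outside of Pyro) of the two behaviors listed above, and of why a [5, 2] input comes back with shape [2]. torch.diag_embed is not mentioned elsewhere in this thread; it is included here as one batch-aware alternative, which places the last dimension on the diagonal and treats the leading dimensions as a batch.

```python
import torch

v = torch.tensor([1.0, 2.0])
m = torch.diag(v)               # vector in -> 2 x 2 matrix with v on its diagonal
print(m.shape)                  # torch.Size([2, 2])
print(torch.diag(m))            # matrix in -> its diagonal: tensor([1., 2.])

# The case from the question: 5 samples of a length-2 vector stacked as [5, 2].
batch = torch.ones(5, 2)
print(torch.diag(batch).shape)  # treated as ONE matrix -> torch.Size([2])

# Batch-aware alternative: each row becomes the diagonal of its own 2 x 2 matrix.
print(torch.diag_embed(batch).shape)  # torch.Size([5, 2, 2])
```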
Hi, thanks for your reply.
I replaced torch.diag() with torch.diagflat(), but it generates a tensor of shape torch.Size([10, 10]), which is not what I want.
How can I generate a tensor of shape [5, 2, 2], i.e. one [2, 2] diagonal matrix for each independent sample?
Thanks!
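The [10, 10] shape comes from torch.diagflat flattening its input before building the diagonal matrix, so the [5, 2] batch structure is lost. A minimal check, with a random x standing in for val[lv_z]:

```python
import torch

x = torch.rand(5, 2)            # stand-in for val[lv_z]
# diagflat flattens x to 10 elements first, then builds a single 10 x 10 matrix
d = torch.diagflat(x)
print(d.shape)                                  # torch.Size([10, 10])
print(torch.equal(d, torch.diag(x.flatten())))  # True
```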
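@FreeAJust You can use the following trick:
import torch
x = torch.rand(5, 2)
y = x.new_zeros((5, 2, 2))
# each row of the (5, 4) view is one flattened 2 x 2 matrix; step 2+1 = 3 hits its diagonal
y.view(5, -1)[..., ::2+1] = x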
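The trick relies on the fact that in a row-major n x n matrix flattened to n*n entries, the diagonal sits at every (n+1)-th index. A sketch generalizing it to any size n and any number of leading batch dimensions; batched_diag is a hypothetical helper name, not part of PyTorch or Pyro. For comparison it is checked against torch.diag_embed, which builds the same batched diagonal matrices directly.

```python
import torch

def batched_diag(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [..., n] -> [..., n, n] with x on each diagonal."""
    n = x.shape[-1]
    y = x.new_zeros(tuple(x.shape) + (n,))  # contiguous zeros, so .view() is safe
    # Flatten each n x n matrix to n*n entries; indices 0, n+1, 2(n+1), ...
    # are exactly the diagonal positions (0, 0), (1, 1), ..., (n-1, n-1).
    y.view(*x.shape[:-1], n * n)[..., :: n + 1] = x
    return y

x = torch.rand(5, 2)
y = batched_diag(x)
print(y.shape)                              # torch.Size([5, 2, 2])
print(torch.equal(y, torch.diag_embed(x)))  # True
```

Inside the model from the question, torch.diag_embed(val[lv_z]) should likewise give shape [5, 2, 2], one diagonal matrix per sample in the plate.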