Hi, it seems there is a problem when using the torch.diag function inside pyro.plate.
For example:
```python
import torch as t
import pyro
import pyro.distributions as dist

val = t.ones([2, 2])

def model():
    with pyro.plate('dim', 5):
        lv_z = pyro.sample("lv_z", dist.Categorical(t.tensor([0.5, 0.5])))
        print(lv_z.shape)
        print(val[lv_z].shape)
        print(t.diag(val[lv_z]).shape)

model()
```
The output is:
torch.Size([5])
torch.Size([5, 2])
torch.Size([2])
I think the last shape should be [5,2,2], but I get [2].
Is there a bug in my code?
Any suggestions would be appreciated.
Thanks!
fritzo
This is an issue with the torch.diag() interface not supporting broadcasting: it does not treat leading dimensions as batch dimensions. I recommend avoiding torch.diag().
- When x is a vector, torch.diag(x) is a matrix whose diagonal is x.
- When x is a matrix, torch.diag(x) is a vector which is the diagonal of x.
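Thus, for vectors, torch.diag(torch.diag(x)) == x.
In the model above, val[lv_z] has shape [5, 2], so torch.diag(val[lv_z]) returns the main diagonal of that 5x2 matrix, which has min(5, 2) = 2 entries.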
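A quick check of the two behaviors described above, plus what happens with a [5, 2] batch like val[lv_z]. This is a small sketch using random values in place of the model's tensors:

```python
import torch

x = torch.tensor([1.0, 2.0])
m = torch.diag(x)      # vector -> 2x2 matrix with x on its diagonal
print(m.shape)         # torch.Size([2, 2])
print(torch.diag(m))   # matrix -> its diagonal, i.e. x again

# A [5, 2] "batch" is just one 5x2 matrix to torch.diag,
# so it returns that matrix's diagonal instead of 5 diagonal matrices.
batch = torch.rand(5, 2)
print(torch.diag(batch).shape)  # torch.Size([2])
```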
Hi, thanks for your reply.
I replaced torch.diag() with torch.diagflat(), but it generates a variable with shape torch.Size([10, 10]), which is not what I want.
How can I generate a variable with shape [5,2,2], i.e. one diagonal [2,2] matrix for each independent sample?
Thanks!
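The [10, 10] shape comes from torch.diagflat flattening its input before building the diagonal matrix, so the 5 x 2 = 10 values all land on one big diagonal. A minimal check, assuming a random [5, 2] tensor in place of val[lv_z]:

```python
import torch

x = torch.rand(5, 2)
# torch.diagflat flattens x to 10 values first, so the batch structure is lost
d = torch.diagflat(x)
print(d.shape)  # torch.Size([10, 10])
# The diagonal of the result is just the flattened input
print(torch.equal(torch.diag(d), x.flatten()))  # True
```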
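@FreeAJust You can use the following trick:

```python
import torch

x = torch.rand(5, 2)
y = x.new_zeros((5, 2, 2))
# Flatten each 2x2 matrix to 4 entries; every (2+1)-th entry (indices 0 and 3) is on the diagonal
y.view(5, -1)[..., ::2+1] = x
```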
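The same stride trick generalizes to any matrix size n: when an n x n matrix is flattened to n*n entries, its diagonal sits at every (n+1)-th position. The sketch below wraps it in a helper; batch_diag is a hypothetical name, not a PyTorch or Pyro function. It also compares the result against torch.diag_embed, which recent PyTorch versions provide for building batched diagonal matrices directly. The last lines assume plate size 5 as in the question and use torch.distributions.Categorical as a stand-in for the pyro.sample call, so the snippet runs without Pyro.

```python
import torch

def batch_diag(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map a [..., n] tensor to a [..., n, n] batch of diagonal matrices."""
    n = x.size(-1)
    batch_shape = tuple(x.shape[:-1])
    y = x.new_zeros(batch_shape + (n, n))
    # Diagonal entries of a flattened n x n matrix are at 0, n+1, 2(n+1), ..., n*n - 1
    y.view(batch_shape + (n * n,))[..., :: n + 1] = x
    return y

x = torch.rand(5, 2)
y = batch_diag(x)
print(y.shape)                              # torch.Size([5, 2, 2])
print(torch.equal(y, torch.diag_embed(x)))  # True

# Applied to the original example (Categorical stands in for pyro.sample in a plate of size 5)
val = torch.ones(2, 2)
lv_z = torch.distributions.Categorical(torch.tensor([0.5, 0.5])).sample((5,))
print(batch_diag(val[lv_z]).shape)          # torch.Size([5, 2, 2])
```

In a model you could equally call torch.diag_embed(val[lv_z]) directly; the helper mainly shows why the ::2+1 slice works.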