Pyro.plate error with torch.diag()

Hi, it seems there is some trouble when using the torch.diag function inside pyro.plate.

For example:

import torch as t
import pyro
import pyro.distributions as dist

val = t.ones([2,2])
def model():
    with pyro.plate('dim',5) :
        lv_z = pyro.sample("lv_z", dist.Categorical(t.tensor([0.5,0.5])))
        print(lv_z.shape)
        print(val[lv_z].shape)
        print(t.diag(val[lv_z]).shape)

model()

The output is:

torch.Size([5])
torch.Size([5, 2])
torch.Size([2])

I think the last shape should be [5,2,2], but I get [2].

Is there a bug in my code?

Please give some suggestions.

Thanks!

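This is an issue with the torch.diag() interface: it does not support batching (broadcasting over leading dimensions), and its behavior depends on the rank of its input. In your model, val[lv_z] has shape [5, 2], which torch.diag treats as a single matrix, so it returns that matrix's main diagonal of length 2. I recommend avoiding torch.diag() here.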

  • When x is a vector, torch.diag(x) is a matrix whose diagonal is x.
  • When x is a matrix, torch.diag(x) is a vector which is the diagonal of x.

Thus for vectors torch.diag(torch.diag(x)) == x.
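A quick sketch checking both behaviors, and reproducing the torch.Size([2]) from the question with a plain tensor standing in for the Pyro sample (the variable names here are only for illustration):

```python
import torch

v = torch.tensor([1.0, 2.0])
m = torch.diag(v)               # vector in -> 2x2 matrix with v on the diagonal
print(m.shape)                  # torch.Size([2, 2])
print(torch.diag(m))            # matrix in -> its diagonal: tensor([1., 2.])

# A batch of 5 vectors is just a 5x2 matrix to torch.diag,
# so it returns the main diagonal (length min(5, 2) = 2), not 5 matrices.
batch = torch.ones(5, 2)
print(torch.diag(batch).shape)  # torch.Size([2])
```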

Hi, thanks for your reply.

I replaced torch.diag() with torch.diagflat(), but it generates a tensor with shape torch.Size([10, 10]), which is not what I want.
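The [10, 10] shape comes from torch.diagflat flattening its input before building the diagonal matrix: a [5, 2] tensor becomes a 10-element vector. A small sketch confirming this, with a random tensor standing in for val[lv_z]:

```python
import torch

x = torch.rand(5, 2)
d = torch.diagflat(x)  # input is flattened to 10 elements first
print(d.shape)         # torch.Size([10, 10])

# Same result as building a diagonal matrix from the flattened tensor
print(torch.equal(d, torch.diag(x.flatten())))  # True
```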

How can I generate a tensor with shape [5, 2, 2], i.e. one [2, 2] diagonal matrix for each independent sample?

Thanks!

@FreeAJust You can use the following trick:

import torch

x = torch.rand(5, 2)
y = x.new_zeros((5, 2, 2))
# every (2+1)-th entry of each flattened 2x2 block is a diagonal entry
y.view(5, -1)[..., ::2+1] = x
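Why the trick works: viewing each n x n block as a row of n*n elements puts the diagonal entries at flat positions 0, n+1, 2(n+1), ..., so the step ::n+1 selects exactly the diagonal. A sketch generalizing it to any n, and comparing it with torch.diag_embed, which PyTorch also provides for building a batch of diagonal matrices directly (the helper name batched_diag is hypothetical, not part of PyTorch or Pyro):

```python
import torch

def batched_diag(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn a (batch, n) tensor into (batch, n, n) diagonal matrices."""
    batch, n = x.shape
    y = x.new_zeros((batch, n, n))
    # In the flattened (batch, n*n) view, diagonal entries sit at 0, n+1, 2(n+1), ...
    y.view(batch, -1)[..., :: n + 1] = x
    return y

x = torch.rand(5, 2)
y = batched_diag(x)
print(y.shape)  # torch.Size([5, 2, 2])

# torch.diag_embed builds the same tensor and also handles extra batch dimensions
print(torch.equal(y, torch.diag_embed(x)))  # True
```

In the model from the question, t.diag_embed(val[lv_z]) would then have shape [5, 2, 2].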

Thanks! It helps a lot!