# `pyro.plate` error with `torch.diag()`

#1

Hi, it seems there is some trouble when using the `torch.diag` function inside `pyro.plate`.

For example:

```python
import torch as t
import pyro
import pyro.distributions as dist

val = t.ones([2, 2])

def model():
    with pyro.plate('dim', 5):
        lv_z = pyro.sample("lv_z", dist.Categorical(t.tensor([0.5, 0.5])))
        print(lv_z.shape)
        print(val[lv_z].shape)
        print(t.diag(val[lv_z]).shape)

model()
```

The output is:

```
torch.Size([5])
torch.Size([5, 2])
torch.Size([2])
```

I expected the last shape to be `[5, 2, 2]`, but I get `torch.Size([2])`.

Is there a bug in my code?

Any suggestions would be appreciated.

Thanks!

#2

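This is a limitation of the `torch.diag()` interface: it does not support broadcasting over batch dimensions, because its behavior depends on the number of dimensions of its input. I recommend avoiding `torch.diag()` here.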

• When `x` is a vector, `torch.diag(x)` is a matrix whose diagonal is `x`.
• When `x` is a matrix, `torch.diag(x)` is a vector which is the diagonal of `x`.

Thus for vectors `torch.diag(torch.diag(x)) == x`.
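Both behaviors are easy to check in plain PyTorch, outside any Pyro model. The sketch below also shows why the question's last `print` does not give `[5, 2, 2]`: the `[5, 2]` tensor `val[lv_z]` is treated as a single matrix, so `torch.diag` returns its main diagonal, whose length is `min(5, 2) = 2`.

```python
import torch

v = torch.tensor([1.0, 2.0])
m = torch.diag(v)                    # vector -> matrix with v on its diagonal
print(m.shape)                       # torch.Size([2, 2])

# A 2-D input is treated as one matrix, not as a batch of vectors,
# so torch.diag returns its main diagonal, of length min(5, 2) = 2.
batch = torch.ones(5, 2)
d = torch.diag(batch)
print(d.shape)                       # torch.Size([2])

# For vectors, the two behaviors undo each other.
print(torch.equal(torch.diag(torch.diag(v)), v))   # True
```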

#3

Hi, thanks for your reply.

I replaced `torch.diag()` with `torch.diagflat()`, but it produces a tensor of shape `torch.Size([10, 10])`, which is not what I want.
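The `[10, 10]` shape comes from the fact that `torch.diagflat` flattens its input before building the diagonal matrix: a `[5, 2]` tensor becomes a 10-element vector, and the result is a single 10x10 matrix. A minimal check, assuming the default `offset=0`:

```python
import torch

x = torch.rand(5, 2)
flat_diag = torch.diagflat(x)        # equivalent to torch.diag(x.flatten())
print(flat_diag.shape)               # torch.Size([10, 10])
print(torch.equal(flat_diag, torch.diag(x.flatten())))   # True
```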

How can I generate a tensor of shape `[5, 2, 2]`, i.e., one 2x2 diagonal matrix for each independent sample?

Thanks!

#4

@FreeAJust You can use the following trick:

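```python
import torch

x = torch.rand(5, 2)          # 5 samples, each holding the 2 diagonal entries
y = x.new_zeros((5, 2, 2))    # batch of 2x2 zero matrices, same dtype/device as x
# View each 2x2 matrix as 4 flat elements; the diagonal is every (2+1)-th one.
y.view(5, -1)[..., ::2+1] = x
```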
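Why the trick works: `view(5, -1)` exposes each 2x2 matrix as a row of 4 contiguous elements, and in a row-major n x n matrix the diagonal entries sit at flat indices 0, n+1, 2(n+1), ..., so slicing with step `n + 1` selects exactly the diagonal. Because `y` is freshly allocated and contiguous, the view shares its storage and the assignment fills `y` in place. The sketch below generalizes this to any trailing size `n` and any number of leading batch dimensions (the helper name `batch_diag` is hypothetical), and compares it against `torch.diag_embed`, a built-in PyTorch function that places the last dimension on the diagonal and keeps leading dimensions as batch dimensions.

```python
import torch

def batch_diag(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (..., n) tensor -> (..., n, n) diagonal matrices."""
    *batch, n = x.shape
    y = x.new_zeros((*batch, n, n))          # contiguous, same dtype/device as x
    # Flatten each n x n matrix; the diagonal is every (n+1)-th flat element.
    y.view(*batch, n * n)[..., :: n + 1] = x
    return y

x = torch.rand(5, 2)
y = batch_diag(x)
print(y.shape)                               # torch.Size([5, 2, 2])
print(torch.equal(y, torch.diag_embed(x)))   # True
```

Inside the original model, `torch.diag_embed(val[lv_z])` would likewise return shape `[5, 2, 2]`. Handling arbitrary leading dimensions matters in Pyro because inference (e.g. enumeration) can add extra batch dimensions on the left of a sample's shape.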

#5

Thanks! It helps a lot!