Hi, there seems to be a problem when using the `torch.diag` function inside `pyro.plate`.

For example:

```
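import torch as t
import pyro
import pyro.distributions as dist

val = t.ones([2, 2])

def model():
    # Everything sampled inside the plate gets a batch dimension of size 5.
    with pyro.plate('dim', 5):
        lv_z = pyro.sample("lv_z", dist.Categorical(t.tensor([0.5, 0.5])))
        print(lv_z.shape)
        print(val[lv_z].shape)
        print(t.diag(val[lv_z]).shape)

model()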
```

The output is:

```
torch.Size([5])
torch.Size([5, 2])
torch.Size([2])
```

I think the last shape should be [5,2,2], but I get [2].

Is there any bug in my code?

Please give some suggestions.

Thanks!

fritzo
This is an issue with the `torch.diag()` interface not supporting broadcasting. I recommend avoiding `torch.diag()`.

- When `x` is a vector, `torch.diag(x)` is a matrix whose diagonal is `x`.
- When `x` is a matrix, `torch.diag(x)` is a vector which is the diagonal of `x`.

Thus for vectors `torch.diag(torch.diag(x)) == x`.
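The two cases above can be checked directly. The following is a minimal sketch with illustrative variable names (`m`, `d`, `batch` are not from the thread); it also shows why the `[5, 2]` tensor in the question is treated as a single 5x2 matrix rather than as a batch of five vectors:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])

m = torch.diag(x)   # vector in, matrix out: shape (3, 3)
d = torch.diag(m)   # matrix in, vector out: shape (3,)
roundtrip_ok = torch.equal(d, x)

# Same shape as val[lv_z] in the question. torch.diag sees one 5x2 matrix
# and returns its main diagonal, which has min(5, 2) = 2 entries.
batch = torch.ones(5, 2)
batch_diag = torch.diag(batch)

print(m.shape, d.shape, batch_diag.shape, roundtrip_ok)
```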

Hi, thanks for your reply.

I replaced `torch.diag()` with `torch.diagflat()`, but it generates a variable with the shape `torch.Size([10, 10])`, which is different from what I want.

How can I generate a variable with the shape [5,2,2], so that for each independent sample a diagonal matrix with shape [2,2] is generated?

Thanks!

@FreeAJust You can use the following trick:

```
import torch

x = torch.rand(5, 2)
y = x.new_zeros((5, 2, 2))
# In each flattened 2x2 matrix the diagonal sits at every (2+1)-th position.
y.view(5, -1)[..., ::2+1] = x
```
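The strided assignment works because, in a flattened n x n matrix, the diagonal entries sit at every (n + 1)-th position. As a hedged alternative that is not mentioned in the thread, `torch.diag_embed` should build the same batched diagonal matrices directly, assuming a PyTorch version that provides it. The sketch below compares the two and also shows why `torch.diagflat` produced a 10 x 10 result; the names `n`, `z`, and `flat` are illustrative:

```python
import torch

x = torch.rand(5, 2)
n = x.shape[-1]

# Strided-view trick from the reply, written for a general n.
y = x.new_zeros((5, n, n))
y.view(5, -1)[..., ::n + 1] = x

# Alternative (an assumption, not from the thread): put the last
# dimension of x on the diagonal of two new trailing dimensions.
z = torch.diag_embed(x)

# torch.diagflat flattens its input to 10 elements first,
# so it returns a single 10 x 10 diagonal matrix.
flat = torch.diagflat(x)

print(y.shape, z.shape, flat.shape)
```

Both `y` and `z` have shape `[5, 2, 2]`, with one diagonal matrix per sample in the plate.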
