Hi, it seems there is some trouble when using the `torch.diag` function inside `pyro.plate`.

For example:

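```
import torch as t
import pyro
import pyro.distributions as dist

val = t.ones([2, 2])

def model():
    # Draw 5 independent categorical samples (one per plate element)
    with pyro.plate('dim', 5):
        lv_z = pyro.sample("lv_z", dist.Categorical(t.tensor([0.5, 0.5])))
        print(lv_z.shape)
        print(val[lv_z].shape)
        print(t.diag(val[lv_z]).shape)

model()
```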

The output is:

```
torch.Size([5])
torch.Size([5, 2])
torch.Size([2])
```

I expected the last shape to be [5, 2, 2], but I get [2].
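For reference, here is a minimal PyTorch-only sketch, assuming `val[lv_z]` behaves like an ordinary `[5, 2]` tensor, that reproduces the shapes above. With a 2-D input, `torch.diag` extracts the main diagonal instead of building one diagonal matrix per row, whereas `torch.diag_embed` treats all leading dimensions as batch dimensions.

```python
import torch

# Stand-in for val[lv_z]: a batch of 5 vectors of length 2 (shape [5, 2]).
x = torch.ones(5, 2)

# With a 2-D input, torch.diag returns the main diagonal as a 1-D tensor
# of length min(5, 2) = 2; dim 0 is NOT treated as a batch dimension.
print(torch.diag(x).shape)        # torch.Size([2])

# torch.diag_embed places the last dimension on the diagonal of a new
# square matrix and keeps every leading dimension as a batch dimension.
print(torch.diag_embed(x).shape)  # torch.Size([5, 2, 2])
```

Under this reading, replacing `t.diag(val[lv_z])` with `t.diag_embed(val[lv_z])` in the model would likely give the expected `[5, 2, 2]` shape.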

Is there a bug in my code?

Any suggestions would be appreciated.

Thanks!