N-dimensional tensor sampling


Say I have a probability tensor T of shape (Batch, N, D), normalized along the last dimension. I want to sample an index along this last dimension so that I can use the result to index into a variable with the same shape as T.

I have been trying to do this with pyro.distributions.Categorical, to no avail. Could anyone point out what I'm doing wrong?

import torch
import pyro.distributions

T = torch.randn(128, 4, 5)
T = ptsoftmax(T, -1)  # ptsoftmax: my own plain softmax over the last dim; double-checked that it works as intended
C = pyro.distributions.Categorical(ps=torch.autograd.Variable(T))

This results in:
The expanded size of the tensor (4) must match the existing size (128) at non-singleton dimension 1. at /pytorch/torch/lib/TH/generic/THTensor.c:308

I have also tried constructing C as C = pyro.distributions.Categorical(ps=torch.autograd.Variable(T), vs=torch.autograd.Variable(T)), with the same result.

Thank you very much.

Hi ktf,
Currently, most distributions in Pyro support batching only along a single dimension (dimension 0). The distribution classes are thin wrappers over the corresponding PyTorch/SciPy implementations, so they inherit any limitations of those implementations (in this case, PyTorch's multinomial implementation). Arbitrary batch shapes may work with some distributions, but more by accident than by design. This is definitely something we should document and validate in the constructor.
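To see the limitation described above, here is a small check, assuming a recent PyTorch release rather than the version used in this thread: `torch.multinomial` accepts a 1-D or 2-D probability tensor (one distribution per row) and rejects a 3-D tensor like the (128, 4, 5) one in the question. The exact error message varies across versions, so the sketch only checks that a `RuntimeError` is raised.

```python
import torch

# 2-D input: 4 independent distributions over 5 categories each
probs_2d = torch.softmax(torch.randn(4, 5), dim=-1)
samples = torch.multinomial(probs_2d, num_samples=1)  # one index per row, shape (4, 1)

# 3-D input shaped like the question's T: (Batch, N, D) = (128, 4, 5)
probs_3d = torch.softmax(torch.randn(128, 4, 5), dim=-1)
try:
    torch.multinomial(probs_3d, num_samples=1)
    raised = False
except RuntimeError:
    # multinomial only supports 1-D or 2-D probability tensors
    raised = True
```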

We will support arbitrary batch shapes in a future release (likely the next one). We made an attempt to do this for Categorical and Bernoulli (I think scoring with batch_log_pdf will work as expected, but sampling will not), but decided not to pursue it for the first release. Thanks for reporting this; this behavior should be better documented.
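For reference, later PyTorch releases added `torch.distributions`, whose `Categorical` treats every leading dimension of `probs` as a batch dimension (current Pyro distributions are built on these classes). The sketch below uses that PyTorch API, not the `ps=`/`vs=` Pyro API from this thread, to show the intended behavior: one sample and one log-probability per (Batch, N) entry.

```python
import torch
from torch.distributions import Categorical

# A (Batch, N, D) = (128, 4, 5) probability tensor, normalized over the last dim
T = torch.softmax(torch.randn(128, 4, 5), dim=-1)

dist = Categorical(probs=T)
# Every dimension of probs except the last is treated as a batch dimension
print(dist.batch_shape)  # torch.Size([128, 4])

idx = dist.sample()         # one category index per (batch, n) entry, shape (128, 4)
log_p = dist.log_prob(idx)  # scoring uses the same batch shape, shape (128, 4)
```

Scoring is effectively a lookup of log T at the sampled index along the last dimension, which is why it is much less sensitive to the batch shape than sampling.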

PS: In your case, you would have to flatten T into a batch of 128 x 4 = 512 rows of 5 probabilities each, sample from that, and reshape the sample back to (128, 4) afterwards.
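A minimal sketch of this flatten, sample, reshape workaround, assuming a recent PyTorch and calling `torch.multinomial` directly instead of Pyro's `Categorical`. The tensor `values` is a hypothetical stand-in for the variable with the same shape as T that the question wants to index into; `gather` along the last dimension does that indexing.

```python
import torch

batch, n, d = 128, 4, 5
T = torch.softmax(torch.randn(batch, n, d), dim=-1)

# Merge the two leading dimensions into one batch: (128 * 4, 5) = (512, 5)
flat_probs = T.reshape(batch * n, d)

# Draw one category per row; torch.multinomial accepts 2-D input
flat_idx = torch.multinomial(flat_probs, num_samples=1)  # shape (512, 1)

# Restore the original (Batch, N) shape
idx = flat_idx.reshape(batch, n)  # shape (128, 4)

# Hypothetical tensor with the same shape as T; pick the sampled entry in each row
values = torch.randn(batch, n, d)
picked = values.gather(-1, idx.unsqueeze(-1)).squeeze(-1)  # shape (128, 4)
```

Because `reshape` keeps row-major order, row `i * n + j` of `flat_probs` is exactly `T[i, j]`, so reshaping the samples back lines each index up with its original distribution.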

Okay, thank you very much for your reply! I look forward to the next release.