So I can convert these to vectors, stack them on top of each other, and then pass the result to dist.Categorical()?
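That's right. Assuming a, b and c each have shape (510, 10, 3), torch.cat([a, b, c]) gives you a tensor of size (510*3, 10, 3); torch.stack([a, b, c]) would instead add a new leading dimension and give (3, 510, 10, 3). I am guessing that your trailing dim has size 3 to account for each of the weight values.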
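A quick way to check the shapes, assuming (hypothetically) that a, b and c each have shape (510, 10, 3). Random tensors stand in for the real weights here. torch.cat joins along an existing dim, torch.stack inserts a new one, and Categorical treats the trailing dim as the categories, so everything before it becomes the batch shape:

```python
import torch
from torch.distributions import Categorical

# Hypothetical stand-ins for a, b, c: non-negative weights of shape (510, 10, 3)
a = torch.rand(510, 10, 3)
b = torch.rand(510, 10, 3)
c = torch.rand(510, 10, 3)

# torch.cat joins along an existing dim (dim 0 by default)
catted = torch.cat([a, b, c])
print(catted.shape)  # torch.Size([1530, 10, 3])

# torch.stack inserts a new dim (dim 0 by default)
stacked = torch.stack([a, b, c])
print(stacked.shape)  # torch.Size([3, 510, 10, 3])

# Categorical normalizes the weights along the trailing dim (3 categories)
d = Categorical(probs=catted)
print(d.batch_shape)  # torch.Size([1530, 10])
```

pyro.distributions.Categorical wraps the same PyTorch class, so the shape behaviour should carry over.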
By default, a three-category Categorical in Pyro returns 0, 1 and 2. However, I want -1, 0, 1, so could I do 1 - dist.Categorical(a, b, c).sample()?
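You probably want to do 1 - dist.Categorical(weights).sample(), where weights is the single stacked tensor; dist.Categorical takes one tensor of weights, not a, b and c as separate arguments. By itself, dist.Categorical(..) only gives you a distribution instance, and .sample() draws category indices 0, 1, 2. Note that 1 - ... maps those to 1, 0, -1 (category 0 becomes 1); if you want category 0 to become -1, use ... - 1 instead. In Pyro you can do something like
values = 1 - pyro.sample("cat", dist.Categorical(weights)), which draws the sample and records it as a named sample site "cat" so that Pyro's inference algorithms can track it.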
If you have any distribution-specific questions, I would suggest bringing them over to the PyTorch distributions channel, where you'll likely get a faster response.