pyro.sample from a PyTorch Gumbel-softmax

sampled_indexes = pyro.sample(f"{address}_{index}", F.gumbel_softmax(char_dist, hard=True, dim=2), obs=observed[index]).squeeze(0)

This throws an error from Pyro, but I thought Pyro supported PyTorch distributions as well. Someone recommended RelaxedOneHotCategorical instead, but that doesn't seem to let me choose the dimension the softmax is applied over. char_dist is a 1 x batch_size x categories tensor, so I need the softmax over dim=2 (the categories dimension).
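One likely cause of the error: `F.gumbel_softmax` returns a sampled tensor, not a distribution object, while `pyro.sample` expects a distribution as its second argument. Also, `torch.distributions` treat the last dimension of `logits` as the category axis. For a 1 x batch_size x categories tensor, that is already dim=2, so `RelaxedOneHotCategorical` should not need a `dim` argument. Below is a minimal sketch using only PyTorch. The sizes, the temperature of 0.5, and the random `char_logits` are hypothetical stand-ins for `char_dist`. It also applies the straight-through trick by hand to get hard one-hot samples, similar to `hard=True`. In a Pyro model you would pass the distribution object itself to `pyro.sample` (Pyro wraps this class as `pyro.distributions.RelaxedOneHotCategorical`, and it also appears to provide a `RelaxedOneHotCategoricalStraightThrough` variant worth checking in its docs).

```python
import torch
import torch.nn.functional as F
from torch.distributions import RelaxedOneHotCategorical

torch.manual_seed(0)

# Hypothetical sizes and logits with the same layout as char_dist:
# 1 x batch_size x categories
batch_size, categories = 4, 5
char_logits = torch.randn(1, batch_size, categories)

# torch.distributions use the LAST dim of logits as the category axis,
# which for this layout is dim=2. The leading dims become the batch shape.
temperature = torch.tensor(0.5)  # assumed temperature, tune as needed
relaxed = RelaxedOneHotCategorical(temperature, logits=char_logits)
print(relaxed.batch_shape, relaxed.event_shape)  # torch.Size([1, 4]) torch.Size([5])

# rsample() draws a differentiable "soft" one-hot sample over the last dim.
soft = relaxed.rsample()

# Straight-through trick: the forward value is a hard one-hot vector,
# while gradients flow through the soft sample.
index = soft.argmax(dim=-1)
hard = F.one_hot(index, num_classes=categories).to(soft.dtype)
straight_through = hard - soft.detach() + soft

# Category indices per batch element, mirroring the .squeeze(0) in the question.
indices = index.squeeze(0)
```

One caveat, stated with some uncertainty: if `observed[index]` holds exact one-hot vectors, the relaxed distribution's `log_prob` may be non-finite at those points. A discrete distribution such as `OneHotCategorical` may be a better fit for the observed characters.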