I’ve been exploring the use of PyroModule / PyroSample, and ran into what seems like a simple problem with naming. Below is some model code for a multinomial logit model, where I’m using PyroModule[nn.Linear] multiple times to create alternative-specific logits:
```python
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class MNL(PyroModule):
    def __init__(self, nalts, nfeatures):
        super().__init__()
        self.nalts = nalts
        self.nfeatures = nfeatures
        in_features = nfeatures
        out_features = 1
        # One linear utility per alternative, stored in a plain Python list
        self.utils = [PyroModule[nn.Linear](in_features, out_features) for j in range(nalts)]
        # Alternative 0: random weights, no bias term
        self.utils[0].weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.utils[0].bias = None
        # Remaining alternatives: random weights and random bias
        for j in range(1, nalts):
            self.utils[j].weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
            self.utils[j].bias = PyroSample(dist.Normal(0., 1.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        # x has shape (num_obs, nfeatures, nalts)
        choice_logits = torch.zeros((x.shape[0], self.nalts))
        for j in range(self.nalts):
            choice_logits[:, j] = self.utils[j](x[:, :, j]).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Categorical(logits=choice_logits), obs=y)
        return nn.functional.softmax(choice_logits, dim=-1)
```
When I run this model with an AutoGuide, I get an error saying that multiple sample sites are named “weight”, which I suppose is true, since there are multiple nn.Linear modules. Is there an easy solution to this, such as a way of assigning names to PyroSample sites, or should I avoid using nn.Linear in this case?
(This question was also asked here, but I didn’t see a reply.)