Hi all,
I’ve been exploring the use of PyroModule / PyroSample, and ran into what seems like a simple problem with naming. Below is some model code for a multinomial logit model, where I’m using PyroModule[nn.Linear] multiple times to create alternative-specific logits:
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class MNL(PyroModule):
    def __init__(self, nalts, nfeatures):
        super().__init__()
        self.nalts = nalts
        self.nfeatures = nfeatures
        in_features = nfeatures
        out_features = 1
        # One linear layer per alternative, kept in a plain Python list
        self.utils = [PyroModule[nn.Linear](in_features, out_features) for j in range(nalts)]
        # Alternative 0 has no intercept
        self.utils[0].weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.utils[0].bias = None
        for j in range(1, nalts):
            self.utils[j].weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
            self.utils[j].bias = PyroSample(dist.Normal(0., 1.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        # x has shape (batch, nfeatures, nalts)
        choice_logits = torch.zeros((x.shape[0], self.nalts))
        for j in range(self.nalts):
            choice_logits[:, j] = self.utils[j](x[:, :, j]).squeeze()
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Categorical(logits=choice_logits), obs=y)
        return nn.functional.softmax(choice_logits, dim=-1)
When I run this model with an AutoGuide, I get an error saying that multiple sample sites are named “weight,” which I guess is true, since there are multiple nn.Linear layers. Is there an easy solution to this, like a way of assigning names to PyroSample, or should I avoid using nn.Linear in this case?
(This question was also asked here, but I didn’t see a reply.)
Thanks!