PyroSample Naming

Hi all,

I’ve been exploring the use of PyroModule / PyroSample, and ran into what seems like a simple problem with naming. Below is some model code for a multinomial logit model, where I’m using PyroModule[nn.Linear] multiple times to create alternative-specific logits:

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class MNL(PyroModule):
    def __init__(self, nalts, nfeatures):
        super().__init__()
        self.nalts = nalts
        self.nfeatures = nfeatures
        in_features = nfeatures
        out_features = 1
        # One linear utility per alternative, stored in a plain Python list
        self.utils = [PyroModule[nn.Linear](in_features, out_features) for j in range(nalts)]
        # Alternative 0: weights only, no intercept
        self.utils[0].weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.utils[0].bias = None
        # Remaining alternatives: weights and an intercept
        for j in range(1, nalts):
            self.utils[j].weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
            self.utils[j].bias = PyroSample(dist.Normal(0., 1.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        # x has shape (batch, nfeatures, nalts)
        choice_logits = torch.zeros((x.shape[0], self.nalts))
        for j in range(self.nalts):
            choice_logits[:, j] = self.utils[j](x[:, :, j]).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Categorical(logits=choice_logits), obs=y)
        return nn.functional.softmax(choice_logits, dim=-1)

When I run this model with an AutoGuide, I get an error saying that multiple sample sites are named “weight”, which I guess is true, since there are multiple nn.Linear layers. Is there an easy solution to this, such as a way to assign names to PyroSample sites, or should I avoid using nn.Linear in this case?
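The duplicate names come from how the layers are stored. A plain Python list attribute is invisible to nn.Module, so the layers are never registered as submodules, and PyroModule appears to build site names from that same dotted attribute path. Plain PyTorch already shows the difference through named_parameters(); the classes PlainList and Registered below are purely illustrative toys, not part of the original model:

```python
import torch.nn as nn


class PlainList(nn.Module):
    def __init__(self, n):
        super().__init__()
        # A plain Python list: nn.Module does not register these layers.
        self.utils = [nn.Linear(3, 1) for _ in range(n)]


class Registered(nn.Module):
    def __init__(self, n):
        super().__init__()
        # nn.ModuleList registers each layer under the names "0", "1", ...
        self.utils = nn.ModuleList([nn.Linear(3, 1) for _ in range(n)])


plain_names = [name for name, _ in PlainList(2).named_parameters()]
registered_names = [name for name, _ in Registered(2).named_parameters()]
print(plain_names)       # []
print(registered_names)  # ['utils.0.weight', 'utils.0.bias', 'utils.1.weight', 'utils.1.bias']
```

Without registration there is no path prefix to tell the layers apart, which is why every layer ends up with a site called just “weight”.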

(This question was also asked here, but I didn’t see a reply.)

Thanks!

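Hi @rtdew1, you’ll need to wrap your list of modules in a PyroModule[nn.ModuleList]. A plain Python list does not register the layers as submodules, so each layer names its site plain “weight”; inside a PyroModule[nn.ModuleList] the sites get unique, path-based names such as “utils.0.weight” and “utils.1.weight”: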

self.utils = PyroModule[torch.nn.ModuleList]([
    PyroModule[nn.Linear](in_features, out_features)
    for j in range(nalts)
])  

Thanks very much, very helpful!