I see. I believe you should be able to use the deferred version of PyroSample, as in, e.g., this thread on PyroModule:
class BNN(PyroModule):
    """Two-layer Bayesian classifier with a hierarchical prior on ``fc1.weight``.

    ``fc1.weight`` gets a Normal prior whose scale is itself a random
    variable (``fc1.weight_scale``). Because the weight prior depends on
    another sampled attribute, it is declared with the *deferred* form of
    ``PyroSample``: a callable that receives the owning module and returns a
    distribution.

    Args:
        input_size: Number of input features for ``fc1``.
        hidden_size: Number of output features of ``fc1`` / input features of
            ``out``.
        output_size: Number of output features of ``out`` (used as the
            Categorical logits in ``forward``).
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.fc1 = PyroModule[nn.Linear](input_size, hidden_size)
        self.out = PyroModule[nn.Linear](hidden_size, output_size)

        # NOTE: the ``...`` below is a placeholder from the original example;
        # supply concrete InverseGamma parameters before running.
        self.fc1.weight_scale = PyroSample(dist.InverseGamma(...))

        # Deferred prior: the callable is handed the module that owns the
        # attribute (self.fc1) and builds the distribution from that module's
        # sampled ``weight_scale``. The shape matches nn.Linear's weight,
        # i.e. [hidden_size, input_size].
        self.fc1.weight = PyroSample(
            lambda layer: dist.Normal(0., layer.weight_scale)
            .expand([hidden_size, input_size])
            .to_event(2)
        )
        ...  # handle other parameters similarly

    def forward(self, x, y_data=None):
        """Run the network on ``x`` and sample/observe the class label.

        Args:
            x: Input passed to ``fc1``; presumably a tensor whose last
                dimension is ``input_size`` -- confirm against the caller.
            y_data: Optional observed labels. When given, the ``"obs"`` site
                is conditioned on them; when ``None``, labels are sampled.

        Returns:
            The value of the ``"obs"`` sample site.
        """
        hidden = F.relu(self.fc1(x))
        logits = self.out(hidden)
        # Normalise over the class axis explicitly. Omitting ``dim`` is
        # deprecated and picks dim 0 for 3-D inputs, which would normalise
        # over the wrong axis.
        lhat = F.log_softmax(logits, dim=-1)
        obs = pyro.sample("obs", dist.Categorical(logits=lhat), obs=y_data)
        return obs