Hyperprior

I see. I believe you should be able to use the deferred form of `PyroSample`, where you pass it a callable that builds the distribution instead of a distribution object, as shown in this thread on `PyroModule`:

class BNN(PyroModule):
    """Two-layer Bayesian classifier with a hyperprior on the first layer's weight scale.

    ``fc1.weight_scale`` is itself a random variable drawn from an InverseGamma
    hyperprior. The prior on ``fc1.weight`` is built lazily from it through a
    deferred ``PyroSample`` (a callable instead of a fixed distribution).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = PyroModule[nn.Linear](input_size, hidden_size)
        self.out = PyroModule[nn.Linear](hidden_size, output_size)
        # Hyperprior on the weight scale. NOTE: ``...`` is a placeholder; supply the
        # InverseGamma parameters before running this model.
        self.fc1.weight_scale = PyroSample(dist.InverseGamma(...))
        # Deferred PyroSample: the callable receives the module the attribute is set on
        # (here ``self.fc1``). Reading ``module.weight_scale`` therefore draws the
        # hyperprior above before the weight prior is constructed.
        # nn.Linear stores its weight as (out_features, in_features), hence the shape.
        self.fc1.weight = PyroSample(
            lambda module: dist.Normal(0.0, module.weight_scale)
            .expand([hidden_size, input_size])
            .to_event(2)
        )
        ...  # handle other parameters (fc1.bias, out.weight, out.bias) similarly

    def forward(self, x, y_data=None):
        """Run the network on ``x`` and sample (or observe) class labels.

        Args:
            x: input features; assumed to be batch-first, (batch, input_size).
                TODO confirm. The leading dim is used as the data-plate size.
            y_data: optional observed labels (presumably shape (batch,)). When given,
                the ``"obs"`` site is conditioned on them.

        Returns:
            The value of the ``"obs"`` sample site (sampled or observed labels).
        """
        output = F.relu(self.fc1(x))
        output = self.out(output)
        # Normalize over the class axis explicitly. The implicit-dim form is deprecated
        # and would pick dim 0 for 3-D inputs.
        lhat = F.log_softmax(output, dim=-1)
        # Declare observations conditionally independent across the batch dimension.
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Categorical(logits=lhat), obs=y_data)
        return obs
1 Like