Generating the output of a Pyro neural network model based on the specified prior distribution

I have converted a part of my frequentist neural network model (my_frequentist_model) into a Pyro model.

my_frequentist_model consists of a main body and an output head. I applied Pyro priors only to the weights of the output head, because the main body weights are pre-trained and I do not want them to change. I used the following code to do this:

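```python
import pyro.distributions as dist
from pyro.nn import module

# convert `my_frequentist_model` into a Pyro module (in place).
module.to_pyro_module_(my_frequentist_model)

# convert the output head of `my_frequentist_model` into a Bayesian layer:
# replace each parameter with a PyroSample whose prior has the parameter's shape.
for m in my_frequentist_model.output_head.modules():
    for name, value in list(m.named_parameters(recurse=False)):
        setattr(m, name, module.PyroSample(
            prior=dist.Normal(0, 1).expand(value.shape).to_event(value.dim())))
```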

Now, after running the conversion code, if I call my_frequentist_model(my_input), will Pyro compute the output of my_frequentist_model using new output head weights drawn from their prior distribution (dist.Normal(0, 1))? Or will my_frequentist_model(my_input) compute the output using the old frequentist output head weights?

Thank you,

I believe so (EDIT: that is, the weights would be drawn from the prior). Note that you can verify that the weights are being sampled by using pyro.poutine.trace to record the sample sites that appear during execution of the model:

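```python
import pyro.poutine

# record every Pyro statement executed while the model runs
with pyro.poutine.trace() as tr:
    my_output = my_frequentist_model(my_input)

for name, node in tr.trace.nodes.items():
    if node["type"] == "sample":
        print(name, node["value"].shape)
```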

Sorry, I got slightly confused. Does this mean that the output of my_frequentist_model(my_input) is based on new output head weights drawn from the prior distribution (dist.Normal(0, 1)) that I specified?
:S Thank you,

It looks like you’re using the API correctly to me, but it’s hard to be conclusive about the behavior of code I can’t run. I suggest using the snippet in my previous post to verify for yourself that sample sites for the weights appear in the trace as expected, and using a debugger to inspect the behavior of your code as it runs.


After running the snippet, I get the following output:


So what does this output imply…? Thank you :S

> what does this output imply…?

It means that the parameters weight and bias of output_head are being sampled from the priors you set for them, and that the output of output_head is computed using those samples.


Thank you, I appreciate your help.