If you do it directly (without `random_module`), then you can do whatever you want: just sample each layer from the same distribution. If you use `random_module` and want to sample every layer from the same distribution `my_dist`, you can just pass that distribution into `random_module`:
```python
import pyro
import pyro.distributions as dist

my_dist = dist.Normal(...)  # fill in the prior's parameters
lifted_module = pyro.random_module('random nn', nn_module, my_dist)
```
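If you only want certain layers lifted, pass a dict of priors instead. The keys are parameter names as reported by `nn_module.named_parameters()` (e.g. `'layer_1.weight'`, not just `'layer_1'`), and any parameter not in the dict stays an ordinary, non-random parameter. This should work:
```python
# Only the weights of these three layers get priors; biases stay point estimates.
priors = {'layer_1.weight': dist_1, 'layer_2.weight': dist_1, 'layer_3.weight': dist_2}
lifted_module = pyro.random_module('random nn', nn_module, priors)
```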