Hello, I have the Pyro module shown below, where `particle_transformer`
is a `PyroModule` Bayesian neural network (BNN) and `histogram`
is a kernel-density-estimator function written in PyTorch (so it stays differentiable).
class ProjectionModel(PyroModule):
    """Project BNN-transformed base samples onto a differentiable histogram.

    A fixed cloud of standard-normal points ``x`` (drawn once at
    construction) is pushed through ``particle_transformer``. The first
    coordinate of the result is binned with ``histogram``, and the resulting
    projection is scored against observations ``y`` with a Normal likelihood.

    The transformed samples are recorded as the deterministic site ``"z"``.
    This lets you obtain their posterior distribution after SVI training,
    e.g.::

        predictive = Predictive(model, guide=guide, num_samples=100,
                                return_sites=["z"])
        z_samples = predictive()["z"]  # leading dim = num_samples

    For a single posterior draw of the BNN weights, replay
    ``generate_dist`` against a guide trace::

        guide_trace = poutine.trace(guide).get_trace()
        z = poutine.replay(model.generate_dist, trace=guide_trace)()
    """

    def __init__(self, particle_transformer, bins, tkwargs=None):
        """
        Args:
            particle_transformer: PyroModule BNN applied to the base samples.
            bins: bin locations passed to ``histogram``. Presumably 1-D and
                uniformly spaced; the KDE bandwidth is derived from the first
                spacing only, so at least two entries are required.
            tkwargs: optional tensor-factory kwargs (e.g. ``dtype``,
                ``device``) used when creating the ``sigma`` and ``x`` buffers.
        """
        super().__init__()
        tkwargs = tkwargs or {}
        self.particle_transformer = particle_transformer
        self.register_buffer("bins", bins)
        # Half the first bin spacing is used as the KDE bandwidth.
        self.register_buffer("bandwidth", 0.5 * (bins[1] - bins[0]))
        # Float literal so the buffer is floating-point: an int64 buffer is not
        # cast by .to(dtype)/.double() and would become an integer Normal scale.
        self.register_buffer("sigma", torch.tensor(10.0, **tkwargs))
        # Fixed base samples, drawn once; z therefore varies only through the
        # (sampled) BNN weights.
        self.register_buffer("x", torch.randn((10000, 2), **tkwargs))

    def generate_dist(self):
        """Return the BNN applied to the fixed base samples ``x``."""
        return self.particle_transformer(self.x)

    def forward(self, y=None):
        """Compute the histogram projection and, if ``y`` is given, condition on it.

        Args:
            y: optional observations compared against the projection under
                ``Normal(projection, sigma**2)``, inside a ``"data"`` plate
                of size ``y.shape[0]``.

        Returns:
            The histogram projection of the first coordinate of ``z``.
        """
        z = self.generate_dist()
        # Expose z as a trace site so Predictive / poutine.trace can return its
        # posterior samples. Deterministic sites do not add to the log-density.
        z = pyro.deterministic("z", z)
        # z[:, 0] is already 1-D, so no transpose is needed (`.T` on non-2-D
        # tensors is deprecated in PyTorch).
        projection = histogram(z[:, 0], self.bins, self.bandwidth)
        if y is not None:
            with pyro.plate("data", y.shape[0]):
                # NOTE(review): Normal's second argument is the standard
                # deviation, so sigma**2 is used as the std here; confirm this
                # is intended (rather than Normal(projection, self.sigma)).
                pyro.sample("obs", dist.Normal(projection, self.sigma ** 2), obs=y)
        return projection
I can train this module effectively with `pyro.infer.SVI` and make predictions of the module output with `pyro.infer.Predictive`. However, I would like to call the `generate_dist`
method using the BNN's posterior (guide) parameters, or at least obtain posterior samples of the latent variable `z`.
What is the best way to do this?