Tracking latent variable values

Hello, I have the Pyro module shown below, where particle_transformer is a PyroModule BNN and histogram is a kernel density estimator written in PyTorch (so it stays differentiable).
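
The histogram function itself is not shown in the post. For readers who want something runnable, here is one possible sketch, assuming a Gaussian kernel and assuming histogram(samples, bins, bandwidth) returns the estimated density evaluated at each bin location. This is a hypothetical stand-in, not the author's implementation.

```python
import math

import torch


def histogram(samples, bins, bandwidth):
    """Hypothetical Gaussian-KDE 'histogram' (assumed signature, not the author's).

    Returns the kernel density estimate of the 1-D tensor `samples`
    evaluated at each location in `bins`. Built only from differentiable
    torch ops, so gradients flow back to `samples`.
    """
    # Scaled pairwise differences, shape (n_bins, n_samples).
    diff = (bins.unsqueeze(-1) - samples.unsqueeze(0)) / bandwidth
    # Standard normal kernel values.
    kernel = torch.exp(-0.5 * diff**2) / math.sqrt(2.0 * math.pi)
    # KDE: f(x) = 1 / (n * h) * sum_i K((x - x_i) / h)
    return kernel.mean(dim=-1) / bandwidth


torch.manual_seed(0)
bins = torch.linspace(-4.0, 4.0, 81)
bandwidth = 0.5 * (bins[1] - bins[0])  # same rule as in ProjectionModel
points = torch.randn(10000, requires_grad=True)

density = histogram(points, bins, bandwidth)
density.sum().backward()  # gradients reach the input samples
```

Because every step is a plain tensor operation, the estimate can sit inside a Pyro model and still be trained by SVI through the samples that feed it.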

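```python
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule

# `histogram` is my own differentiable KDE function (defined elsewhere).


class ProjectionModel(PyroModule):
    def __init__(self, particle_transformer, bins, tkwargs=None):
        super().__init__()

        tkwargs = tkwargs or {}

        self.particle_transformer = particle_transformer  # PyroModule BNN
        self.register_buffer("bins", bins)
        self.register_buffer("bandwidth", 0.5 * (bins[1] - bins[0]))
        # Float literal so that sigma is a floating-point tensor.
        self.register_buffer("sigma", torch.tensor(10.0, **tkwargs))

        # Fixed base samples that the BNN transforms.
        self.register_buffer("x", torch.randn((10000, 2), **tkwargs))

    def generate_dist(self):
        return self.particle_transformer(self.x)

    def forward(self, y=None):
        z = self.generate_dist()

        # z[:, 0] is already 1-D, so no transpose is needed.
        projection = histogram(z[:, 0], self.bins, self.bandwidth)

        if y is not None:
            with pyro.plate("data", y.shape[0]):
                # dist.Normal takes a standard deviation (scale), not a variance.
                pyro.sample("obs", dist.Normal(projection, self.sigma**2), obs=y)

        return projection
```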

I can train this module effectively with pyro.infer.SVI and predict its output with pyro.infer.Predictive. However, I would like to call the generate_dist method with BNN weights drawn from the trained guide (i.e. the approximate posterior), or at least obtain posterior samples of the latent variable z. What is the best way to do this?