- What tutorial are you running? Bayesian regression (Part I)
- What version of Pyro are you using? 0.3.3
Hi, I’ve been following this tutorial to implement a Bayesian neural network in Pyro. I was able to follow it up to the prediction step, where I get a bit confused about the sampling pipeline. In particular, my questions are:
Considering the model and the evaluation code below:
```python
def model(x_data, y_data):
    # ... I'm commenting the prior definition out
    lifted_module = pyro.random_module("module", regression_model, priors)
    lifted_reg_model = lifted_module()
    with pyro.plate("map", len(x_data)):
        prediction_mean = lifted_reg_model(x_data).squeeze(-1)
        pyro.sample("obs", Normal(prediction_mean, scale), obs=y_data)
        return prediction_mean

def evaluate_model(svi, x_data, y_data):
    posterior = svi.run(x_data, y_data)
    post_pred = TracePredictive(wrapped_model, posterior,
                                num_samples=1000).run(x_data, None)
    # _get_samples_and_weights() returns a (samples, weights) tuple,
    # so take the samples before detaching
    marginal = (EmpiricalMarginal(post_pred, ['obs'])
                ._get_samples_and_weights()[0].detach().cpu().numpy())
```
When I call `svi.run(x_data, y_data)`, Pyro is internally executing:
```python
for i in range(self.num_samples):
    guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
    model_trace = poutine.trace(poutine.replay(self.model,
                                               trace=guide_trace)).get_trace(*args, **kwargs)
    yield model_trace, 1.0
```
I understand that `poutine.trace(self.guide).get_trace(*args, **kwargs)` is the way to sample from the trained guide (namely, from the guide’s learned params), with those samples later reused in the joint distribution (i.e. the model) via `poutine.trace(poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs)`.
a) Is what I’m saying correct?
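To make sure I’m reading this right, here is a toy sketch of my mental model of trace/replay. This is my own simplification, not Pyro’s actual poutine code — `trace`, `replay`, `toy_guide`, and `toy_model` below are all made-up stand-ins:

```python
import random

def trace(fn):
    """Toy tracer: record every sample drawn inside fn, keyed by site name."""
    recorded = {}
    def sample(name, sampler):
        recorded[name] = sampler()
        return recorded[name]
    fn(sample)
    return recorded

def replay(fn, guide_trace):
    """Toy replay: re-running fn reuses the guide's recorded value at any
    site the guide sampled, and only draws fresh values at model-only sites."""
    def sample(name, sampler):
        if name in guide_trace:
            return guide_trace[name]   # reuse the guide's sample
        return sampler()               # fresh draw (e.g. an "obs" site)
    def replayed():
        return fn(sample)
    return replayed

def toy_guide(sample):
    sample("w", lambda: random.gauss(0.0, 1.0))

def toy_model(sample):
    w = sample("w", lambda: random.gauss(0.0, 1.0))
    return w

guide_trace = trace(toy_guide)
# the replayed model returns exactly the weight the guide sampled
assert replay(toy_model, guide_trace)() == guide_trace["w"]
```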
b) In this last call, because I’m passing the `x`, `y` test data to `model(x_data, y_data)`, isn’t that going to condition the likelihood on the `y` test data (because of `pyro.sample("obs", Normal(prediction_mean, scale), obs=y_data)`)?
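For reference, my mental model of the `obs` keyword is the toy simplification below (not Pyro’s actual implementation — `pyro_like_sample` is a made-up name):

```python
import random

def pyro_like_sample(name, draw, obs=None):
    """Toy version of pyro.sample's obs semantics: when an observation is
    supplied, the site is conditioned, so no random draw happens and the
    observed value itself is returned (and scored against the likelihood)."""
    if obs is not None:
        return obs      # conditioned: the value is clamped to the data
    return draw()       # unconditioned: actually sample

# with obs=y_data the "obs" site is clamped to the data:
assert pyro_like_sample("obs", lambda: random.gauss(0.0, 1.0), obs=2.5) == 2.5
```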
After having called `svi.run(...)` I should already have my posterior traces. Why do I need to call `TracePredictive(...).run(x_data, None)`, which internally calls `resampled_trace = poutine.trace(poutine.replay(self.model, model_trace)).get_trace(*args, **kwargs)` again?
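My current (possibly wrong) guess for why the model must be re-run is that the posterior traces hold `obs` values conditioned on the training `y`, whereas prediction needs fresh draws of `obs` for each posterior parameter sample. A toy sketch of that idea, using a hypothetical 1-D linear model (`toy_posterior_predictive` is a made-up name, not a Pyro API):

```python
import random

def toy_posterior_predictive(posterior_w_samples, x, scale=1.0):
    """For each posterior sample of the weight w, re-run the likelihood
    with obs=None, i.e. draw a fresh y ~ Normal(w * x, scale)."""
    return [random.gauss(w * x, scale) for w in posterior_w_samples]

# with scale=0 the draws collapse to the predictive means w * x
assert toy_posterior_predictive([1.0, 2.0], 3.0, scale=0.0) == [3.0, 6.0]
```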
Lastly, I see that at several points in the code a `Delta` distribution is instantiated with some values. Is this a trick so that you can recover those original values from a distribution object later, i.e. defining a distribution with a single point of support that can then be sampled from?
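To state my guess concretely, here is a toy stand-in for what I think `Delta` does (`ToyDelta` is my own sketch, not Pyro’s actual class):

```python
class ToyDelta:
    """Toy point-mass distribution: all probability sits on one value,
    so sampling always returns that value."""
    def __init__(self, value):
        self.value = value

    def sample(self):
        # the only possible outcome is the stored value
        return self.value

    def log_prob(self, x):
        # probability 1 at the point (log 1 = 0), zero mass elsewhere
        return 0.0 if x == self.value else float("-inf")

d = ToyDelta(3.14)
assert d.sample() == 3.14
assert d.log_prob(3.14) == 0.0
```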