I'm working through part 1 of the Bayesian linear regression tutorial, using the AutoDiagonalNormal guide.
However, I'm having trouble drawing posterior samples for the sigma parameter,
which has a U(0, 10) prior and is therefore constrained.
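I had a similar problem in NumPyro, which I solved as follows:

```python
from jax import random
from numpyro import optim
from numpyro.infer import SVI, Trace_ELBO

svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), data=data)
svi_result = svi.run(random.PRNGKey(0), 2000)
post = guide.get_posterior(svi_result.params)  # distribution over the flat unconstrained latents
rng_key, nsamples = random.PRNGKey(1), 800
unconstrained_samples = post.sample(rng_key, sample_shape=(nsamples,))
constrained_samples = guide._unpack_and_constrain(unconstrained_samples, svi_result.params)
```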
I'd like to do the same thing in Pyro, but I can't find an equivalent of NumPyro's private `_unpack_and_constrain`
method.
More specifically, in the linear regression tutorial I noticed that the posterior marginal quantiles
are returned in constrained space (the median of sigma is 0.9182, which matches HMC):
```python
quant = guide.quantiles([0.5])
print(quant)
# {'sigma': [tensor(0.9182)], 'linear.weight': [tensor([[-1.8635, -0.1926, 0.3305]])], 'linear.bias': [tensor([9.1682])]}
```
But when I sample from the posterior, the samples are not transformed (the mean of the sigma component is -2.2926):
```python
import torch

post = guide.get_posterior()
nsamples = 800
samples = post.sample(sample_shape=(nsamples,))
print(torch.mean(samples, dim=0))  # first entry is E[logit(sigma / 10)], not E[sigma]
# tensor([-2.2926, -1.8662, -0.1933, 0.3319, 9.1655])
```
IIUC, a scale parameter s with a U(0, B) prior is mapped to the unconstrained value t = logit(s/B),
so it can be transformed back with s = sigmoid(t) * B. This checks out: sigmoid(-2.2926) * 10 = 0.9174, close to the 0.9182 median above.
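The back-transform can be checked with a few lines of plain Python. This is a sketch that assumes only the logit/sigmoid mapping described above; the helper names `to_unconstrained` and `to_constrained` and the 0.05 posterior spread are made up for illustration. It also shows why comparing against the quantile is fair: a median survives a monotone map exactly, while a mean generally does not.

```python
import math
import random
import statistics

B = 10.0  # upper bound of the U(0, B) prior on sigma

def to_unconstrained(s, upper=B):
    """t = logit(s / upper): maps (0, upper) onto the real line."""
    p = s / upper
    return math.log(p / (1.0 - p))

def to_constrained(t, upper=B):
    """s = sigmoid(t) * upper: the inverse of to_unconstrained."""
    return upper / (1.0 + math.exp(-t))

print(round(to_constrained(-2.2926), 4))  # 0.9174

# Hypothetical unconstrained draws; 801 of them (odd count) so the median
# is an actual sample. The 0.05 spread is an assumed posterior sd.
random.seed(0)
ts = [random.gauss(-2.2926, 0.05) for _ in range(801)]
sigmas = [to_constrained(t) for t in ts]

# The median commutes with the monotone back-transform...
print(statistics.median(sigmas) == to_constrained(statistics.median(ts)))  # True
# ...but the mean does not in general (close here only because the spread is small).
print(statistics.mean(sigmas), to_constrained(statistics.mean(ts)))
```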
But how do I do this programmatically? I tried the following two things, but neither result makes sense to me:
```python
transform = guide.get_transform()

trans_samples = transform(samples)
print(torch.mean(trans_samples, dim=0))
# tensor([-2.4197, -2.1300, -0.2015, 0.3586, 9.7502])

trans_samples = transform.inv(samples)
print(torch.mean(trans_samples, dim=0))
# tensor([-0.0175, -0.0195, -0.0148, 0.0158, -0.0411])
```
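One possible reading of those two outputs, assuming (from my reading of the Pyro source, not confirmed here) that `AutoDiagonalNormal.get_transform()` returns the affine map `loc + scale * z` from the standard-normal base distribution to the unconstrained posterior, rather than the constraint bijection. Under that assumption, `transform(samples)` applies the affine map a second time and `transform.inv(samples)` recovers the N(0, 1) noise, which would explain a mean near zero. The sketch below reproduces this with plain `torch.distributions`; `loc` is taken from the printed means and `scale` is made up. The map that actually sends sigma to (0, 10) is `biject_to(constraints.interval(0., 10.))`.

```python
import torch
from torch.distributions import biject_to, constraints
from torch.distributions.transforms import AffineTransform

torch.manual_seed(0)

# Stand-ins for the fitted guide: loc from the printed means, scale made up.
loc = torch.tensor([-2.2926, -1.8662, -0.1933, 0.3319, 9.1655])
scale = torch.full((5,), 0.05)

# Unconstrained posterior samples, like post.sample(...) above.
samples = loc + scale * torch.randn(800, 5)

# Assumption: this is what get_transform() returns for AutoDiagonalNormal.
affine = AffineTransform(loc, scale)
pushed_twice = affine(samples)    # loc + scale * samples: affine map applied again
base_noise = affine.inv(samples)  # (samples - loc) / scale: N(0, 1) noise, mean ~ 0

# The constraint bijection for sigma: real line -> (0, 10), i.e. sigmoid(t) * 10.
to_sigma = biject_to(constraints.interval(0.0, 10.0))
sigma = to_sigma(samples[:, 0])

print(pushed_twice.mean(dim=0))
print(base_noise.mean(dim=0))
print(sigma.median())
```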
More generally, it would be great if Pyro autoguides had a `sample_posterior` method
matching the one in NumPyro, returning the same kind of output as HMC (a dict of constrained samples keyed by site name).
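For illustration, here is a rough sketch of what such a helper would need to do for a diagonal-Normal guide: sample the flat unconstrained vector, split it into per-site chunks, and push each chunk through `biject_to(support)`. `sample_posterior` here is a hypothetical function, not a Pyro API, and the site layout (sigma first, then the 1x3 weight, then the bias) is an assumption based on the order of the printed means. `loc` comes from those means; `scale` is made up.

```python
import math
import torch
from torch.distributions import Normal, biject_to, constraints

def sample_posterior(loc, scale, sites, num_samples):
    """Hypothetical helper (not a Pyro API): sample a diagonal-Normal posterior
    over the flat unconstrained latent vector, split it into named sites, and
    map each site onto its support."""
    flat = Normal(loc, scale).sample((num_samples,))  # (num_samples, D)
    out, pos = {}, 0
    for name, shape, support in sites:
        size = math.prod(shape)  # 1 for a scalar site
        chunk = flat[:, pos:pos + size].reshape((num_samples,) + tuple(shape))
        out[name] = biject_to(support)(chunk)  # unconstrained -> constrained
        pos += size
    return out

# Assumed site layout, in the order of the printed means.
sites = [
    ("sigma", (), constraints.interval(0.0, 10.0)),
    ("linear.weight", (1, 3), constraints.real),
    ("linear.bias", (1,), constraints.real),
]
loc = torch.tensor([-2.2926, -1.8662, -0.1933, 0.3319, 9.1655])
scale = torch.full((5,), 0.05)  # made-up posterior standard deviations

torch.manual_seed(0)
post_samples = sample_posterior(loc, scale, sites, num_samples=800)
for name, value in post_samples.items():
    print(name, tuple(value.shape), value.mean(dim=0))
```

The output is a dict of constrained samples per site, which is the shape of result HMC's samples come in.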