Hello, I am trying to perform linear regression in a probabilistic framework. Since the data is high-dimensional, I am using SVI rather than full Bayesian inference. Initially, I was only interested in point estimates, so I was using AutoDelta. Now I am interested in getting confidence intervals, so I have started experimenting with AutoNormal and AutoMultivariateNormal. AutoMultivariateNormal gives more accurate predictions than AutoNormal. However, AutoDelta gives the best results of the three. What could be the reason behind this? Am I writing the code correctly? How do I determine which guide to use?
```
import numpy as np  # assumed alias for `np`
import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta

# AutoDelta guide: point estimates only
guide = AutoDelta(model)
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_current = svi.init(rng_key=rng, dat=holdData, enr_obs=np.array(holdData['DV']))
# ...

# Read the fitted parameters back out of the final SVI state
final_iter_params = svi.get_params(svi_current)
posterior_samples = guide.sample_posterior(rng, final_iter_params, sample_shape=(1,))
posterior_estimates = {k: np.array(samples.mean(axis=0)) for k, samples in posterior_samples.items()}

# Flatten each estimate into a 1-D array
param_output = {}
for key, value in posterior_estimates.items():
    try:
        param_output[f'{key}_auto_loc'] = posterior_estimates[key].reshape(1, -1)[0]
    except KeyError:
        break
```
```
from numpyro.infer.autoguide import AutoNormal

# AutoNormal guide: independent Normal per latent, so quantiles are available
guide = AutoNormal(model)
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_current = svi.init(rng_key=rng, dat=holdData, enr_obs=np.array(holdData['DV']))
# ...

final_iter_params = svi.get_params(svi_current)
# sample_shape=(1,) draws a single sample from the guide,
# so the mean over axis 0 below is just that one draw
posterior_samples = guide.sample_posterior(rng, final_iter_params, sample_shape=(1,))
posterior_quantiles = guide.quantiles(final_iter_params, [0.05, 0.5, 0.95])
posterior_estimates = {k: np.array(samples.mean(axis=0)) for k, samples in posterior_samples.items()}

# Flatten each estimate into a 1-D array
param_output = {}
for key, value in posterior_estimates.items():
    try:
        param_output[f'{key}_auto_loc'] = posterior_estimates[key].reshape(1, -1)[0]
    except KeyError:
        break

for key, quantile_values in posterior_quantiles.items():
    try:
        # Extract 5th, 50th (median), and 95th percentiles
        q05 = quantile_values[0].reshape(-1)
        q50 = quantile_values[1].reshape(-1)
        q95 = quantile_values[2].reshape(-1)
        param_output[f'{key}_5th_percentile'] = q05
        param_output[f'{key}_median'] = q50
        param_output[f'{key}_95th_percentile'] = q95
    except KeyError:
        break
```
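One thing that may be worth checking in the AutoNormal version: with `sample_shape=(1,)`, the mean over axis 0 is a single random draw from the guide rather than its centre. The sketch below is plain standard-library Python, not NumPyro, and `loc` and `scale` are made-up values standing in for one fitted parameter of a Normal guide. It only illustrates how much noisier a one-draw estimate is than an average over many draws:

```python
import random
import statistics

# Hypothetical values standing in for one latent's fitted Normal guide.
loc, scale = 2.0, 0.5

rng = random.Random(0)  # fixed seed so the run is repeatable


def mean_of_draws(n):
    """Average n independent draws from Normal(loc, scale)."""
    return statistics.fmean(rng.gauss(loc, scale) for _ in range(n))


# Repeat each estimator many times and compare how far it lands from loc.
trials = 500
err_single = statistics.fmean(abs(mean_of_draws(1) - loc) for _ in range(trials))
err_many = statistics.fmean(abs(mean_of_draws(1000) - loc) for _ in range(trials))

print(f"mean abs error, 1 draw:     {err_single:.3f}")
print(f"mean abs error, 1000 draws: {err_many:.3f}")
```

With AutoDelta the draw is deterministic, so this effect does not arise there. For AutoNormal, a larger `sample_shape` or the median returned by `guide.quantiles` would likely give a less noisy point estimate, which may account for part of the gap I am seeing.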