This seems to work. Though personally I haven't found the PSIS diagnostic all that useful: for most problems where you reach for variational inference for scalability reasons, the posterior approximation is mediocre at best and the PSIS diagnostic will come out bad. If you really care about high-quality posterior estimates, you should run MCMC instead (a minimal NUTS sketch is at the bottom).
import torch
import pyro
import pyro.distributions as dist
import numpy as np
import pandas as pd
from torch import Tensor
from dataclasses import dataclass
from typing import List, Dict, Optional, Union, Any
from pyro.infer.autoguide import init_to_mean
pyro.enable_validation(True)
@dataclass
class PriorParams:
    """Class for specifying the parameters of the prior distributions used in the linear regression model."""
    # Regression coefficient priors.
    weight_loc: float = 0.0
    weight_scale: float = 1.0
    # Bias term prior.
    bias_loc: float = 0.0
    bias_scale: float = 1.0
    # Observation noise prior.
    sigma_scale: float = 10.0
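# Example: the defaults can be overridden per-field when constructing the dataclass
# (the value below is illustrative, not a recommendation).
custom_priors = PriorParams(sigma_scale=1.0)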
def linear_regression_model(x: torch.Tensor,
                            y: Optional[torch.Tensor] = None,
                            prior_params: Optional[PriorParams] = None) -> torch.Tensor:
    """Implements a Bayesian linear regression model that places priors on the model parameters and observation noise.

    Parameters
    ----------
    x
        The input feature matrix of shape (num_examples, num_features).
    y
        The response/target variable vector of size num_examples.
    prior_params
        A dataclass that stores the parameter values of the prior distributions used in the model.
        If `prior_params` is None, the default values from the class are used.

    Returns
    -------
    mean
        The regression line.
    """
    if prior_params is None:
        prior_params = PriorParams()
    # Regression coefficient priors.
    weight = pyro.sample(
        'weight', dist.Normal(prior_params.weight_loc, prior_params.weight_scale).expand([x.shape[1]]).to_event(1))
    # Bias term prior (aka the intercept or independent coefficient).
    bias = pyro.sample('bias', dist.Normal(prior_params.bias_loc, prior_params.bias_scale).expand([1]).to_event(1))
    # Observation noise prior (the random error in the response variable 'y').
    sigma = pyro.sample('sigma', dist.HalfNormal(scale=prior_params.sigma_scale).expand([1]).to_event(1))
    # Expected mean: a linear combination of the input features, the weight vector and the bias.
    mean = torch.matmul(x, weight.unsqueeze(-1)).squeeze(-1) + bias
    # Likelihood. Note: to_event(1) treats all observations as a single event rather than using a
    # pyro.plate, which is why psis_diagnostic below is called with max_plate_nesting=0.
    pyro.sample('obs', dist.Normal(mean, sigma).to_event(1), obs=y)
    return mean
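# A quick way to sanity-check the model's sample-site shapes with poutine.trace
# (an illustrative sketch; `x_toy` is a made-up input, not part of the original example).
import pyro.poutine as poutine
x_toy = torch.randn(5, 3)
print(poutine.trace(linear_regression_model).get_trace(x_toy).format_shapes())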
class BayesianRegressionModel:
    def __init__(self, random_state: int = 42):
        """Bayesian Regression Model.

        Attributes
        ----------
        random_state
            Seed used by the random number generator. This allows for repeatable sampling procedures.
        """
        # Seed used by the random number generator.
        self.random_state = random_state
        # Model evaluation metrics.
        self.training_losses: Optional[List[float]] = None
        # Pyro-specific attributes.
        self._model = None
        self._guide = None

    def fit(self,
            x_data: torch.Tensor,
            y_data: torch.Tensor,
            num_iterations: int = 1000,
            learning_rate: float = 0.001) -> None:
        """Fits the Bayesian Regression Model.

        Parameters
        ----------
        x_data
            Tensor of input features.
        y_data
            Tensor of observation values.
        num_iterations
            The number of iterations to run stochastic variational inference (SVI) for.
        learning_rate
            The learning rate used for optimizing the Evidence Lower Bound (ELBO) objective.

        Returns
        -------
        None
        """
        # Reset the random number generator.
        pyro.set_rng_seed(self.random_state)
        # Linear regression model.
        self._model = linear_regression_model
        # Pyro guide: a mean-field normal approximation to the posterior.
        self._guide = pyro.infer.autoguide.AutoNormal(self._model, init_loc_fn=init_to_mean)
        # Clear any parameters left over from a previous run.
        pyro.clear_param_store()
        # ELBO objective.
        elbo = pyro.infer.Trace_ELBO()
        # Stochastic Variational Inference (SVI).
        optimizer = pyro.optim.Adam({'lr': learning_rate})
        svi = pyro.infer.SVI(self._model, self._guide, optimizer, elbo)
        # Run stochastic gradient descent.
        self.training_losses = []
        for step in range(num_iterations):
            # Calculate the loss and take a gradient step.
            loss = svi.step(x_data, y_data)
            self.training_losses.append(loss)
            if step % 100 == 0:
                print(f'[iteration {step + 1}] ELBO loss: {loss:.4f}')
        # Examine the optimized guide parameters by fetching them from Pyro's param store.
        self._guide.requires_grad_(False)
        for name, value in pyro.get_param_store().items():
            print(f'{name}: {value.data.cpu().numpy()}')
    def predict(self,
                x_data: torch.Tensor,
                num_samples: int = 800) -> Dict[Any, Dict[str, Union[Tensor, Any]]]:
        """Makes predictions for new observations.

        Parameters
        ----------
        x_data
            Tensor of input features.
        num_samples
            The number of samples to generate from the trained model for estimating the regression line and
            posterior predictive distribution.

        Returns
        -------
        pred_summary
            Dictionary keyed by sample site, including predictions from the fitted regression line (the
            '_RETURN' site) as well as estimated percentiles of the posterior predictive distribution
            (the 'obs' site).
        """
        # Reset the random number generator.
        pyro.set_rng_seed(self.random_state)
        # Generate samples from the posterior predictive distribution.
        predictive = pyro.infer.Predictive(
            self._model, guide=self._guide, num_samples=num_samples, return_sites=('obs', '_RETURN'), parallel=True)
        posterior_samples = predictive(x_data)
        pred_summary = self.summary(posterior_samples)
        return pred_summary

    @staticmethod
    def summary(samples) -> Dict[Any, Dict[str, Union[Tensor, Any]]]:
        """Summarizes posterior samples with the mean, standard deviation and 5%/95% percentiles per site."""
        site_stats = {}
        for k, v in samples.items():
            site_stats[k] = {
                "mean": torch.mean(v, 0),
                "std": torch.std(v, 0),
                "5%": torch.quantile(v, 0.05, dim=0),
                "95%": torch.quantile(v, 0.95, dim=0),
            }
        return site_stats
# Load the 'rugged' dataset of per-country terrain ruggedness and GDP.
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")
df = data[["cont_africa", "rugged", "rgdppc_2000"]].copy()
# Drop rows with missing GDP and work with log GDP per capita.
df = df[np.isfinite(df.rgdppc_2000)]
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
# Interaction term between being in Africa and terrain ruggedness.
df["cont_africa_x_rugged"] = df["cont_africa"] * df["rugged"]
data = torch.tensor(df[["cont_africa", "rugged", "cont_africa_x_rugged", "rgdppc_2000"]].values, dtype=torch.float)
x_data, y_data = data[:, :-1], data[:, -1]
blr = BayesianRegressionModel()
blr.fit(x_data, y_data)
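# Optional convergence check: plot the ELBO trace recorded during fit()
# (a minimal sketch; assumes matplotlib is installed).
import matplotlib.pyplot as plt
plt.plot(blr.training_losses)
plt.xlabel('SVI step')
plt.ylabel('ELBO loss')
plt.show()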
y_pred = blr.predict(x_data)
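# The returned dict is keyed by sample site: '_RETURN' holds the regression-line summary
# and 'obs' the posterior predictive. For example, the mean prediction and a 90% interval:
mu = y_pred['_RETURN']['mean']
lower, upper = y_pred['obs']['5%'], y_pred['obs']['95%']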
k_hat = pyro.infer.importance.psis_diagnostic(
    blr._model, blr._guide, x_data, y_data, num_particles=500000, max_plate_nesting=0)
print("k_hat", k_hat)