Byaesian Regression with batch data

I am following the pyro example, Bayesian Regression 1. Instead of throwing the data all at once, I use Pytorch Dataloader to load the data. For Trace_ELBO() loss, I am specifying the size and sub_sample size parameter. The problem runs smoothly except when I plot Posterior predictive distribution with 90% CI I get a straight line and not an interval. I checked and the obs variable in samples is not changing at all. What might be the reason for this, and how to fix this?

Here size of the dataset is 170. I have tried 8 and 170 as the batch_size/sub_sample size, and both give the same issue. If I understand this correctly, batch_size of 170 should bring the problem back to the one mentioned in the pyro example, but when I do it the results do not match.

Code here:

class gdpDataset(Dataset):
    def __init__(self,df, x_cols=["cont_africa","rugged","cont_africa_x_rugged"], y_col="rgdppc_2000"):
        self.X = torch.tensor(df[x_cols].values)
        self.y = torch.tensor(df[y_col].values)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self,idx):
        return self.X[idx], self.y[idx]
    
dataset_torch = gdpDataset(df)
dataloader_torch = DataLoader(dataset_torch, shuffle=True, batch_size=170)
class BayesianRegression(PyroModule):
    def __init__(self):
        super().__init__()
        self.l1 = PyroModule[nn.Linear](3,1)
        self.l1.weight = PyroSample(dist.Normal(0.,1.).expand([1,3]).to_event(2))
        self.l1.bias = PyroSample(dist.Normal(0.,10.).expand([1]).to_event(1))
        
    def forward(self,x,y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0.,10.))
        mean = self.l1(x).squeeze(-1)
        with pyro.plate("data", size=170, subsample_size=x.shape[0]) as ind:
            obs = pyro.sample("obs",dist.Normal(mean, sigma), obs=y)
        return mean
model = BayesianRegression()
guide = AutoDiagonalNormal(model)
optim = pyro.optim.Adam({"lr":0.03})
loss_fn = Trace_ELBO()
num_epochs = 1500
svi = SVI(model,guide, optim, loss=loss_fn)

def train():
    loss = 0
    for x_data,y_data in dataloader_torch:
        loss_step = svi.step(x_data.float(), y_data.float())
        loss+=loss_step
    return loss
pyro.clear_param_store()
for j in range(num_epochs):
    # calculate the loss and take a gradient step
    loss = train()
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))
predictive = Predictive(model, guide=guide, num_samples=800, return_sites=("l1.weight", "obs", "_RETURN"))

samples = predictive(torch.tensor(df[["cont_africa","rugged","cont_africa_x_rugged"]].values).float(), \
                     torch.tensor(df["rgdppc_2000"].values).float())

def summary(samples):
    site_stats = {}
    for k,v in samples.items():
        site_stats[k] = {
            "mean": torch.mean(v,0),
            "std": torch.std(v,0),
            "5%": v.kthvalue(int(len(v)*0.05), dim=0)[0],
            "95%": v.kthvalue(int(len(v)*0.95),dim=0)[0]
        }
    return site_stats

pred_summary = summary(samples)
mu = pred_summary["_RETURN"]
y = pred_summary["obs"]
predictions = pd.DataFrame({
    "cont_africa": df["cont_africa"],
    "rugged": df["rugged"],
    "mu_mean": mu["mean"],
    "mu_perc_5": mu["5%"],
    "mu_perc_95": mu["95%"],
    "y_mean": y["mean"],
    "y_perc_5": y["5%"],
    "y_perc_95": y["95%"],
    "true_gdp": df["rgdppc_2000"],
})

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
african_nations = predictions[predictions["cont_africa"] == 1]
non_african_nations = predictions[predictions["cont_africa"] == 0]
african_nations = african_nations.sort_values(by=["rugged"])
non_african_nations = non_african_nations.sort_values(by=["rugged"])
fig.suptitle("Posterior predictive distribution with 90% CI", fontsize=16)
ax[0].plot(non_african_nations["rugged"],
           non_african_nations["mu_mean"])
ax[0].fill_between(non_african_nations["rugged"],
                   non_african_nations["y_perc_5"],
                   non_african_nations["y_perc_95"],
                   alpha=0.5)
ax[0].plot(non_african_nations["rugged"],
           non_african_nations["true_gdp"],
           "o")
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
idx = np.argsort(african_nations["rugged"])
ax[1].plot(african_nations["rugged"],
           african_nations["mu_mean"])
ax[1].fill_between(african_nations["rugged"],
                   african_nations["y_perc_5"],
                   african_nations["y_perc_95"],
                   alpha=0.5)
ax[1].plot(african_nations["rugged"],
           african_nations["true_gdp"],
           "o")
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");

1 Like