import logging, os, torch, pyro
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pyro.optim as optim
import pyro.distributions as dist
from torch import nn
from torch.distributions import constraints
from functools import partial
from pyro.nn import PyroModule, PyroSample
from pyro.infer import Predictive
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
pyro.set_rng_seed(1)
assert pyro.__version__.startswith('1.3.0')
plt.style.use('default')
logging.basicConfig(format='%(message)s', level=logging.INFO)
pyro.enable_validation(True)
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
rugged_data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")
def model(is_cont_africa, ruggedness, log_gdp):
a = pyro.sample("a", dist.Normal(0., 10.))
b_a = pyro.sample("bA", dist.Normal(0., 1.))
b_r = pyro.sample("bR", dist.Normal(0., 1.))
b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
sigma = pyro.sample("sigma", dist.Uniform(8.0, 10.))
mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
with pyro.plate("data", len(ruggedness)):
pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)
def guide(is_cont_africa, ruggedness, log_gdp):
a_loc = pyro.param('a_loc', torch.tensor(0.))
a_scale = pyro.param('a_scale', torch.tensor(1.), constraint=constraints.positive)
sigma_loc = pyro.param('sigma_loc', torch.tensor(1.), constraint=constraints.positive)
weights_loc = pyro.param('weights_loc', torch.randn(3))
weights_scale = pyro.param('weights_scale', torch.ones(3), constraint=constraints.positive)
# sigma = pyro.param("sigma", constraint=torch.distributions.constraint.positive)
a = pyro.sample("a", dist.Normal(a_loc, a_scale))
b_a = pyro.sample("bA", dist.Normal(weights_loc[0], weights_scale[0]))
b_r = pyro.sample("bR", dist.Normal(weights_loc[1], weights_scale[1]))
b_ar = pyro.sample("bAR", dist.Normal(weights_loc[2], weights_scale[2]))
sigma = pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))
mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
def summary(samples):
site_stats = {}
for site_name, values in samples.items():
marginal_site = pd.DataFrame(values)
describe = marginal_site.describe(percentiles=[.05, 0.25, 0.5, 0.75, 0.95]).transpose()
site_stats[site_name] = describe[["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
return site_stats
df = rugged_data[["cont_africa", "rugged", "rgdppc_2000"]]
df = df[np.isfinite(df.rgdppc_2000)]
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
train = torch.tensor(df.values, dtype=torch.float)
svi = SVI(model, guide, optim.Adam({"lr": .05}),loss=Trace_ELBO())
is_cont_africa, ruggedness, log_gdp1 = train[:, 0], train[:, 1], train[:, 2]
pyro.clear_param_store()
num_iters = 2
for i in range(num_iters):
elbo = svi.step(is_cont_africa, ruggedness, log_gdp1)
if i % 500 == 0:
logging.info("Elbo loss: {}".format(elbo))
Thank you very much for your reply. The program runs to elbo = svi.step(is_cont_africa, ruggedness, log_gdp1). trace_struct.py indicates an error at def compute_log_prob(self, site_filter=lambda name, site: True)