I am still rather new to Pyro, so getting data into the model is still a work in progress. I should also clarify that I am thinking of using Pyro for learning on PGMs (probabilistic graphical models). Specifically, I was wondering about having
a) multivariate data in the model
b) multivariate outputs.
I was looking at the Bayesian regression tutorial and just wanted to make sure I understood some of the setup. For multivariate input data, it seems I can just add the additional columns as arguments to the model function and then use them as in the example below. HOWEVER, the `ruggedness` or `log_gdp` fields are not indexed as the outcome variable, which I find a bit confusing. It seems that `pyro.sample()` with `obs` is reserved for the outcome variable. So I just wanted to check this first part.
```python
import pyro
import pyro.distributions as dist

def model(is_cont_africa, ruggedness, log_gdp):
    # Priors on the intercept, slopes, and noise scale
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    # The predictors only enter through the mean; they are not sampled
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    with pyro.plate("data", len(ruggedness)):
        # log_gdp is the observed outcome
        pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)
```
The thing I am not clear about is how to handle multivariate outputs. Is it as simple as adding another `pyro.sample(..., obs=...)` inside the `pyro.plate`, or would that create other problems? Do I need to do anything special with the variable naming and indexing to keep things aligned?
If there is a good example of this, please let me know. I am working my way through the tutorials, but have not hit them all yet.