I am still rather new to Pyro, so getting data into the model is still a work in progress. I should also mention that I am thinking of using Pyro for learning on PGMs. I was wondering about having:
a) multivariate data in the model
b) multivariate outputs.
I was looking at the Bayesian regression tutorial and wanted to make sure I understood the setup. For multivariate data, it seems I can just add the additional columns as arguments to the model function and then use them as in the example below. HOWEVER, the `log_gdp` fields are not indexed as the outcome variable, which I find a bit confusing; it seems like `obs` is reserved for the outcome variable. So I just wanted to check this first part.
```python
import pyro
import pyro.distributions as dist

def model(is_cont_africa, ruggedness, log_gdp):
    # Priors on the intercept, the three slopes, and the noise scale
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    # One conditionally independent observation per row of data
    with pyro.plate("data", len(ruggedness)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)
```
What I was not clear about is how to handle multivariate outputs. Is it as simple as adding another `pyro.sample(..., obs=...)` statement inside the `pyro.plate`, or would that create other problems? Do I need to do anything special with variable naming and indexing to keep things aligned?
If there is a good example of this, please let me know. I am working my way through the tutorials, but have not hit them all yet.