Inference for simple DAG with mixed distributions


#1

I’m very new to Pyro, Bayesian models, and VI, despite having worked through many papers and tutorials, so I hope you’ll forgive some very naive questions…

I’m trying to construct a model from which to perform inference, and I’m starting from a small toy situation, in which I have three attributes that are collected: job, sex and income. These are categorical, binary and interval attributes, respectively, and I’m assuming they form a DAG in which sex depends upon job, and income depends upon both sex and job.

I have made an attempt to create a model for this scenario:

import torch
import pyro
import pyro.distributions as dist

# Data is tensor of format
# [[job, sex, income], ...]
cols = {'job':0,
        'sex':1, 
        'income':2}
n_job_categories = 4


def model(data):
    
    # Start off assuming all jobs equally likely
    job_cats = pyro.sample("job_cats", dist.Dirichlet(torch.ones(n_job_categories) / n_job_categories))

    # Sex of person is dependent upon job - but assume no bias initially
    # and limit to simple case of either male or female to provide
    # binary attribute example
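    # One (alpha, beta) pair per job category, since these are indexed by job below
    job_sex_alpha = torch.full((n_job_categories,), 10.)
    job_sex_beta = torch.full((n_job_categories,), 10.)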

    # Income is spread normally with parameters determined
    # by job and sex
    inc_job_sex_norm = torch.ones([n_job_categories, 2]) * 10000
    inc_job_sex_var = torch.ones([n_job_categories, 2]) * 100
    
    for i in pyro.plate("workers", len(data)):
        datum = data[i]
        
        job = pyro.sample("obs_job_{}".format(i), 
                                  dist.Categorical(job_cats), 
                                  obs=datum[cols['job']])

        sex_prior = pyro.sample("sex_prior_{}".format(i), dist.Beta(job_sex_alpha[job.long()], job_sex_beta[job.long()]))
        sex = pyro.sample("obs_sex_{}".format(i), 
                                  dist.Bernoulli(sex_prior), 
                                  obs=datum[cols['sex']])

        inc_prior = pyro.sample("inc_prior_{}".format(i), dist.Normal(inc_job_sex_norm[job.long()][sex.long()], inc_job_sex_var[job.long()][sex.long()]))
        sex = pyro.sample("obs_sex_{}".format(i), 
                                  dist.Bernoulli(sex_prior), 
                                  obs=datum[cols['sex']])

I’d greatly appreciate answers to the following:

  1. Is the model constructed correctly, in line with the original description?
  2. If I wish to use SVI to infer p(job|sex,income), is there a way to automatically generate a guide for this model? If not, how should I go about manually constructing a guide for such a mixed-distribution model?
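
For reference, the DAG described above (sex depends on job, income depends on both) corresponds to the factorization below, and the query in question 2 follows from it by Bayes' rule. This is a sketch that treats the distribution parameters as fixed; the index j over job categories is notation introduced here.

```latex
p(\text{job}, \text{sex}, \text{income})
  = p(\text{job})\, p(\text{sex} \mid \text{job})\, p(\text{income} \mid \text{job}, \text{sex})

p(\text{job} = j \mid \text{sex}, \text{income})
  = \frac{p(\text{job} = j)\, p(\text{sex} \mid \text{job} = j)\, p(\text{income} \mid \text{job} = j, \text{sex})}
         {\sum_{j'} p(\text{job} = j')\, p(\text{sex} \mid \text{job} = j')\, p(\text{income} \mid \text{job} = j', \text{sex})}
```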

#2

Hi @beldaz, the model looks sensible, except for the last three lines, which I assume should be income = pyro.sample(..., dist.Normal(...), obs=...).

I would recommend using discrete enumeration for this model, combined with an AutoDelta guide on the continuous parameters (or AutoDiagonalNormal or AutoMultivariateNormal once you get AutoDelta working). You should be able to do something like

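from pyro import poutine
from pyro.contrib.autoguide import AutoDelta  # pyro.infer.autoguide in newer Pyro releases
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

guide = AutoDelta(poutine.block(model, expose_fn=lambda msg: "prior" in msg["name"]))
svi = SVI(model, guide, Adam(...), TraceEnum_ELBO(...))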

Also, inference will be much faster if you can use a vectorized pyro.plate; it will be some work to get the indexing right, but then the model parameters should train quickly. I’d recommend reading the enumeration tutorial and taking a look at some of the HMM examples with multiple latent variables.


#3

Thanks @fritzo. That’s a great help.