Model selection between heterogeneous models (different families of distributions)

Hi,

I’m not exactly sure what to call the thing that I’m asking for, but I hope it becomes clear.

I want to fit a model to data for which I have multiple candidate models, but the candidates have different forms. For example, I want to find out whether the data distribution is better explained by a normal distribution, a gamma distribution, or a more complex hierarchical model.

Of course I could do model selection “outside” of SVI and compare the ELBOs of these models after fitting, but I would like to do it directly with SVI, “in” the model. If the candidate models had the same form, I could basically use pyro.plate, similar to the GMM or HMM examples.

But what’s the best way to select between heterogeneous models directly using Pyro inference? Any hints, ideas, clarifications or links to examples would be awesome!

Caveats first, solution second:

In my experience, it can be dangerous to select differently structured models based on the ELBO, because that objective mixes model accuracy with inference accuracy. For example, if one model has an extra million nuisance variables, each incurring variational misfit, its ELBO can be worse even if it fits the observed variables better. For that reason I prefer to select model structure outside of inference, via some posterior predictive accuracy metric such as the continuous ranked probability score (CRPS).
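A minimal plain-Python sketch of a sample-based CRPS (lower is better), assuming you already have posterior predictive draws for each held-out observation from each fitted candidate model. The function names here are my own; Pyro also provides a vectorized `pyro.ops.stats.crps_empirical(pred, truth)`. The sketch uses the identity CRPS = E|X - y| - 0.5 E|X - X'|, where X and X' are independent predictive draws and y is the observation.

```python
from itertools import product


def crps_from_samples(samples, y):
    """Estimate CRPS for one scalar observation y from predictive draws (lower is better).

    Uses CRPS = E|X - y| - 0.5 * E|X - X'|, replacing both expectations by
    averages over the draws (all ordered pairs for the second term).
    """
    n = len(samples)
    if n == 0:
        raise ValueError("need at least one predictive sample")
    accuracy = sum(abs(x - y) for x in samples) / n
    spread = sum(abs(a - b) for a, b in product(samples, repeat=2)) / (n * n)
    return accuracy - 0.5 * spread


def mean_crps(predictive_samples, observations):
    """Average CRPS over held-out points; predictive_samples[i] holds draws for observations[i]."""
    scores = [crps_from_samples(s, y) for s, y in zip(predictive_samples, observations)]
    return sum(scores) / len(scores)


# Hypothetical predictive draws from two fitted models at two held-out points.
held_out = [1.0, 2.0]
model_a_draws = [[0.9, 1.0, 1.1], [1.9, 2.0, 2.1]]  # concentrated near the truth
model_b_draws = [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]]  # biased and wide
print(mean_crps(model_a_draws, held_out), mean_crps(model_b_draws, held_out))
```

Because the score only looks at predictions of observed quantities, nuisance latents and their variational misfit never enter it directly, which is why it sidesteps the ELBO problem described above.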

If your models are structurally similar (e.g. {StudentT, Cauchy, Normal}), you can simply choose via a Bernoulli or Categorical variable and use guide-side enumeration for inference. For example:

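import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.distributions import constraints
from pyro.infer import SVI, TraceEnum_ELBO

# model1..model3 and guide1..guide3 stand for your candidate models and their guides.
def mixture_model(data):
    which = pyro.sample("which", dist.Categorical(torch.ones(3) / 3))
    # Under parallel enumeration every sub-model runs; mask keeps only its own branch.
    poutine.mask(model1, which == 0)(data)
    poutine.mask(model2, which == 1)(data)
    poutine.mask(model3, which == 2)(data)

def mixture_guide(data):
    # Start from the uniform point on the simplex (torch.ones(3) is not on the simplex).
    probs = pyro.param("probs", torch.ones(3) / 3, constraint=constraints.simplex)
    which = pyro.sample("which", dist.Categorical(probs),
                        infer={"enumerate": "parallel"})
    poutine.mask(guide1, which == 0)(data)
    poutine.mask(guide2, which == 1)(data)
    poutine.mask(guide3, which == 2)(data)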

SVI(mixture_model, mixture_guide, ..., TraceEnum_ELBO())
...

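Just take care that your models and guides avoid sample-site name conflicts: with parallel enumeration all three sub-models run in every trace, so each sample site needs a name that is unique across models and across guides.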
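For intuition about what `probs` learns in the enumerated mixture above, here is a short derivation, treating the sub-guides as fixed (my simplifying assumption). Write $p(k)$ for the prior on `which`, $\pi_k$ for `probs[k]`, $q_k(z_k)$ for the guide of sub-model $k$, and $\mathcal{L}_k$ for that sub-model's own ELBO on data $x$. Exact enumeration of `which` gives

$$
\mathcal{L}(\pi) = \sum_k \pi_k \left( \mathcal{L}_k + \log p(k) - \log \pi_k \right),
\qquad
\mathcal{L}_k = \log p(x \mid k) - \mathrm{KL}\big( q_k(z_k) \,\|\, p(z_k \mid x, k) \big),
$$

and maximizing over the simplex gives

$$
\pi_k^\star = \frac{p(k)\, \exp(\mathcal{L}_k)}{\sum_j p(j)\, \exp(\mathcal{L}_j)} .
$$

So `probs` approaches the exact posterior $p(k \mid x) \propto p(k)\, p(x \mid k)$ only when every KL term vanishes; otherwise each model is down-weighted by its own variational gap, which is the ELBO caveat in formula form. Also, since only the $k$-th term depends on sub-guide $k$'s parameters $\theta_k$, their gradient is $\pi_k \, \partial \mathcal{L}_k / \partial \theta_k$, so a sub-model that falls behind early also learns more slowly.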

EDIT: On second thought, inference might be cheaper with sequential rather than parallel enumeration. This should avoid the name-conflict issue, since only one sub-model runs per trace, but it won’t allow vectorized multi-sample inference or sampling:

def mixture_model(data):
    which = pyro.sample("which", dist.Categorical(torch.ones(3) / 3))
    # Only the selected sub-model runs in each trace, so sub-models may reuse site names.
    (model1, model2, model3)[int(which)](data)

def mixture_guide(data):
    probs = pyro.param("probs", torch.ones(3) / 3, constraint=constraints.simplex)
    # Sequential enumeration reruns the guide and model once per value of `which`.
    which = pyro.sample("which", dist.Categorical(probs),
                        infer={"enumerate": "sequential"})
    (guide1, guide2, guide3)[int(which)](data)

SVI(mixture_model, mixture_guide, ..., TraceEnum_ELBO())
...

Thanks a lot for the quick and very useful answer, fritzo! Also, thanks for the warning regarding model selection using ELBO!