Caveats first, solution second:
In my experience, it can be dangerous to select among differently structured models based on the ELBO, because that objective mixes model accuracy (the log evidence) with inference accuracy (the gap between the guide and the true posterior). For example, if one model has an extra million nuisance variables, each incurring some variational misfit, its ELBO can be worse even if it fits the observed variables better. For that reason I prefer to select model structure outside of inference, via some other posterior predictive accuracy metric such as the continuous ranked probability score (CRPS).
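As a rough sketch of that kind of out-of-inference comparison, here is a minimal empirical CRPS in plain Python. The function names, the held-out values, and the two sets of "posterior predictive draws" are all made up for illustration; in practice the draws would come from your fitted models, and I believe Pyro ships a vectorized equivalent as `pyro.ops.stats.crps_empirical`.

```python
import random


def crps_empirical(samples, truth):
    """Empirical CRPS of one scalar observation: E|X - y| - 0.5 * E|X - X'|.

    Lower is better. `samples` are predictive draws, `truth` is the observed value.
    """
    n = len(samples)
    xs = sorted(samples)
    # First term: mean absolute error of the draws against the observation.
    abs_err = sum(abs(x - truth) for x in xs) / n
    # Second term: mean absolute difference over all ordered pairs of draws,
    # computed from the sorted draws in O(n log n) instead of O(n^2).
    spread = 2.0 * sum((2 * i - n + 1) * x for i, x in enumerate(xs)) / (n * n)
    return abs_err - 0.5 * spread


def mean_crps(draws_per_obs, truths):
    """Average CRPS over a set of held-out observations."""
    scores = [crps_empirical(d, y) for d, y in zip(draws_per_obs, truths)]
    return sum(scores) / len(scores)


random.seed(0)
truths = [0.1, -0.4, 2.5]  # hypothetical held-out observations
# Hypothetical predictive draws from two candidate models (stand-ins only).
draws_a = [[random.gauss(0.0, 1.0) for _ in range(500)] for _ in truths]
draws_b = [[random.gauss(0.0, 5.0) for _ in range(500)] for _ in truths]

score_a = mean_crps(draws_a, truths)
score_b = mean_crps(draws_b, truths)
best = "a" if score_a < score_b else "b"  # pick the model with the lower CRPS
```

The point is that the score only looks at predictions of observed quantities, so extra latent variables and their variational misfit do not enter the comparison directly.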
If your models are structurally similar (e.g. {StudentT, Cauchy, Normal}), you can simply choose among them via a Bernoulli or Categorical variable and use guide-side enumeration for inference. For example:
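    import torch
    import pyro
    import pyro.distributions as dist
    from pyro import poutine
    from pyro.distributions import constraints
    from pyro.infer import SVI, TraceEnum_ELBO

    # model1..model3 and guide1..guide3 are your candidate models and their guides.
    def mixture_model(data):
        which = pyro.sample("which", dist.Categorical(torch.ones(3) / 3))
        poutine.mask(model1, mask=(which == 0))(data)
        poutine.mask(model2, mask=(which == 1))(data)
        poutine.mask(model3, mask=(which == 2))(data)

    def mixture_guide(data):
        # Start from uniform weights; the initial value must lie on the simplex.
        probs = pyro.param("probs", torch.ones(3) / 3, constraint=constraints.simplex)
        which = pyro.sample("which", dist.Categorical(probs),
                            infer={"enumerate": "parallel"})
        poutine.mask(guide1, mask=(which == 0))(data)
        poutine.mask(guide2, mask=(which == 1))(data)
        poutine.mask(guide3, mask=(which == 2))(data)

    SVI(mixture_model, mixture_guide, ..., TraceEnum_ELBO())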
...
Just take care that your models and guides avoid sample site name conflicts, since all three run in the same trace.
EDIT: On second thought, inference might be cheaper with sequential rather than parallel enumeration. This should avoid the name conflict issue, since only one model runs per execution, but won’t allow vectorized multi-sample inference or sampling:
    def mixture_model(data):
        which = pyro.sample("which", dist.Categorical(torch.ones(3) / 3))
        # Only the selected model runs in this execution.
        (model1, model2, model3)[int(which)](data)

    def mixture_guide(data):
        # Start from uniform weights; the initial value must lie on the simplex.
        probs = pyro.param("probs", torch.ones(3) / 3, constraint=constraints.simplex)
        which = pyro.sample("which", dist.Categorical(probs),
                            infer={"enumerate": "sequential"})
        (guide1, guide2, guide3)[int(which)](data)

    SVI(mixture_model, mixture_guide, ..., TraceEnum_ELBO())
...
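For intuition about what `probs` should converge to in either variant, here is a small plain-Python sketch. It assumes each component guide has already converged, so that model k contributes a fixed ELBO value; the numbers and helper names below are hypothetical. Under that assumption the mixture objective is sum_k q_k * (log prior_k + elbo_k - log q_k), which is maximized by q_k proportional to prior_k * exp(elbo_k).

```python
import math


def optimal_probs(elbos, prior):
    """Weights q_k maximizing sum_k q_k * (log prior_k + elbo_k - log q_k)."""
    logits = [math.log(p) + e for p, e in zip(prior, elbos)]
    m = max(logits)  # subtract the max before exponentiating, for stability
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    return [w / total for w in weights]


def mixture_elbo(elbos, prior, probs):
    """Value of the mixture objective at the given selection weights."""
    return sum(
        q * (math.log(p) + e - math.log(q))
        for q, p, e in zip(probs, prior, elbos)
        if q > 0
    )


elbos = [-1052.3, -1049.8, -1060.1]  # hypothetical per-model ELBOs
prior = [1 / 3, 1 / 3, 1 / 3]        # matches the uniform Categorical prior on "which"

probs = optimal_probs(elbos, prior)
at_optimum = mixture_elbo(elbos, prior, probs)
at_uniform = mixture_elbo(elbos, prior, prior)
```

Because the weights are driven by per-model ELBOs rather than by evidence, the caveat at the top applies here too: a model whose guide fits poorly gets down-weighted even if the model itself is good.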