Mixture in guide but not model

I have a scenario where I am trying to use a guide which corresponds to a mixture but the model does not (although it does end up being multimodal). Minimal version (gist since I’m failing at formatting: https://gist.github.com/davidaknowles/1d45c56b40ffcf573cc6a5743c6e2f25):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate
from torch.distributions import constraints

data = 2. * torch.Tensor([-1., -0.5, -0.5, .5, .8, 1.])

def model(data):
    guide_efficacy = pyro.sample('guide_efficacy', dist.Beta(1., 1.).expand([len(data)]).to_event(1))
    gene_essentiality = pyro.sample("gene_essentiality", dist.Normal(0., 5.))
    mean = gene_essentiality * guide_efficacy
    with pyro.plate("data", len(data)):
        obs = pyro.sample("obs", dist.Normal(mean, 1.), obs=data)

def guide(data):
    prob = pyro.param("prob", torch.tensor(0.5), constraint=constraints.unit_interval)
    z = pyro.sample('assignment', dist.Bernoulli(prob)).long()
    ge_mean = pyro.param("ge_mean", torch.ones(2))
    ge_scale = pyro.param("ge_scale", torch.ones(2), constraint=constraints.positive)
    gene_essentiality = pyro.sample("gene_essentiality", dist.Normal(ge_mean[z], ge_scale[z]))
    guide_efficacy_a = pyro.param('guide_efficacy_a', torch.ones([2, len(data)]), constraint=constraints.positive)
    guide_efficacy_b = pyro.param('guide_efficacy_b', torch.ones([2, len(data)]), constraint=constraints.positive)
    guide_efficacy = pyro.sample("guide_efficacy", dist.Beta(guide_efficacy_a[z, :], guide_efficacy_b[z, :]))
    return z, gene_essentiality, guide_efficacy

TraceEnum_ELBO().loss(model, config_enumerate(guide, "parallel"), data)

The error I’m getting is
ValueError: Error while packing tensors at site 'guide_efficacy':
Invalid tensor shape.
Allowed dims: -2
Actual shape: (2, 1, 6)

Any pointers much appreciated.

hi @daknowles i’m not entirely sure if that (namely the use of auxiliary discrete latent variables in the guide that do not appear in the model) is expected to work. depending on what your precise use case is it might be easier to use MixtureSameFamily. would that be sufficient for your use case?
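
for reference, a rough (untested) sketch of the kind of thing i mean by MixtureSameFamily; the component locations/scales below are just placeholders:

import torch
import pyro.distributions as dist

# two-component normal mixture; the mixing weights and component
# parameters here are placeholders, not values tuned to your problem
mixing = dist.Categorical(torch.tensor([0.5, 0.5]))
components = dist.Normal(torch.tensor([-2.0, 2.0]), torch.tensor([1.0, 1.0]))
mixture = dist.MixtureSameFamily(mixing, components)
sample = mixture.sample()  # scalar draw from the mixture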


Thanks, helpful to know it’s not an expected use case.
I don’t think MixtureSameFamily cuts it because I have two sets of parameters sharing the same mixture assignment. My current solution is to use NUTS instead, which should suffice!
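For completeness, by NUTS I mean roughly the following (untested sketch reusing the model above; the sample counts are arbitrary):

from pyro.infer import MCMC, NUTS

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=200)
mcmc.run(data)
posterior_samples = mcmc.get_samples()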

@daknowles does something like this work for you?

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate
from torch.distributions import constraints

data = 2. * torch.Tensor( [-1., -0.5, -0.5, .5, .8, 1.] )


@config_enumerate
def model(data):
    gene_essentiality = pyro.sample("gene_essentiality", dist.Normal(0., 5.))
    prob = pyro.param("prob", torch.tensor(0.5), constraint=constraints.unit_interval)
    z = pyro.sample('assignment', dist.Bernoulli(prob)).long()
    with pyro.plate("data", len(data)):
        guide_efficacy = pyro.sample('guide_efficacy', dist.Beta(1., 1.))
        mean = gene_essentiality * guide_efficacy
        pyro.sample("obs", dist.Normal(mean, 1.), obs=data)

@config_enumerate
def guide(data):
    prob = pyro.param("prob", torch.tensor(0.5), constraint=constraints.unit_interval)
    z = pyro.sample('assignment', dist.Bernoulli(prob)).long()
    ge_mean = pyro.param("ge_mean", torch.ones(2))
    ge_scale = pyro.param("ge_scale", torch.ones(2), constraint=constraints.positive)
    guide_efficacy_a = pyro.param('guide_efficacy_a', torch.ones([2, len(data)]), constraint=constraints.positive)
    guide_efficacy_b = pyro.param('guide_efficacy_b', torch.ones([2, len(data)]), constraint=constraints.positive)
    gene_essentiality = pyro.sample("gene_essentiality", dist.Normal(ge_mean[z], ge_scale[z]))
    with pyro.plate("data", len(data)):
        guide_efficacy = pyro.sample("guide_efficacy", dist.Beta(guide_efficacy_a[z.squeeze(-1)],
                                                                 guide_efficacy_b[z.squeeze(-1)]))

TraceEnum_ELBO(max_plate_nesting=1).loss(model, guide, data)

i basically put a “dummy” auxiliary variable in the model that has no downstream dependencies. this will compute the ELBO on the extended space, i.e. the expectation with respect to the discrete assignment variable is outside of the logs.
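
roughly, my understanding of the objective this computes (a sketch, writing theta = (gene_essentiality, guide_efficacy)):

$$
\mathcal{L}_{\text{ext}}
= \mathbb{E}_{q(z)\,q(\theta \mid z)}\big[\log p(x, \theta) + \log p(z) - \log q(z) - \log q(\theta \mid z)\big]
= \sum_z q(z)\, \mathbb{E}_{q(\theta \mid z)}\big[\log p(x, \theta) - \log q(\theta \mid z)\big] - \mathrm{KL}\big(q(z) \,\|\, p(z)\big)
$$

i.e. the sum over the assignment variable sits outside the log terms.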

do you really want z outside of the plate? not sure what the precise use case is.

Got distracted by other things but just wanted to loop back to say this does work, thank you! I guess it introduces a new term into the ELBO (due to the prior on z), but since you put a uniform prior it’s a constant (so no effect on inference).

Now I just need to figure out how to integrate this into the rest of the model. Happy New Year!

@daknowles from what i recall, without writing down any equations, this does change the inference because you’re moving the summation outside the logarithms. actually this can be understood as a “hierarchical variational model”, and you would presumably want to learn the prior on the auxiliary variable to get a tighter ELBO.
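
to spell out what i mean (rough sketch, not a careful derivation): writing q(theta) = sum_z q(z) q(theta | z) for the actual mixture guide, the usual ELBO and the extended-space objective above differ by an expected KL,

$$
\mathcal{L}_{\text{mix}} = \mathbb{E}_{q(\theta)}\big[\log p(x, \theta) - \log q(\theta)\big],
\qquad
\mathcal{L}_{\text{mix}} - \mathcal{L}_{\text{ext}} = \mathbb{E}_{q(\theta)}\big[\mathrm{KL}\big(q(z \mid \theta) \,\|\, p(z)\big)\big] \ge 0,
$$

so the extended-space bound is looser, and making p(z) learnable (so it can track q(z | theta)) shrinks the gap; that is the hierarchical variational model picture.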
