Generalizing the mixture model

Greetings, Pyro community!

I am trying to develop a generalized mixture model.
I have read the Pyro GMM tutorial carefully and am taking it as my starting point.

The main action takes place here:

with pyro.plate('data', len(data)):
    # draw a mixture component index for each data point
    assignment = pyro.sample('assignment', dist.Categorical(weights))
    # score each observation under its assigned component's Normal
    pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

Things to note:

  • the main constraint of the tutorial's approach is that all mixture components must come from the same distribution family (homogeneous components)
  • the parameters have to be tensors so that they broadcast correctly (a quick shape check follows this list)
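
For concreteness, here is the shape check I have in mind (illustrative values): under parallel enumeration the assignment picks up an extra leftmost dimension, and indexing a parameter tensor with it has to broadcast against the data.

import torch

locs = torch.tensor([-2.0, 3.0])       # one location per mixture component
assignment = torch.tensor([[0], [1]])  # parallel enumeration: shape (2, 1)
print(locs[assignment].shape)          # torch.Size([2, 1]) -> broadcasts against data of shape (N,)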

A typical example of heterogeneous components would be a mixture of (piecewise-)linear regressions. The parameters may then have different shapes (and also different semantics), as the hypothetical illustration below shows.
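
To make the mismatch concrete, here is a purely hypothetical pair of parameter sets for two such components; nothing like this can be stacked into one rectangular tensor and indexed by assignment the way locs is in the tutorial:

import torch

linear_params = {'weight': torch.randn(2)}      # slope and intercept
piecewise_params = {'knots': torch.randn(3),    # different shapes and
                    'slopes': torch.randn(4)}   # different semantics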

For enumeration to work, the parameters of the mixture would have to be packed into a tensor. The mixture itself would then be packed into a wrapper distribution, WrapperDist, that unpacks the parameters, implements expand, and aggregates the calls to rsample and log_prob of the mixture components.
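
Here is a minimal sketch of the kind of wrapper I mean. To keep it simple it does not support enumeration at all; instead it marginalizes the assignment inside log_prob with a logsumexp, and it assumes scalar events and components sharing a common support. WrapperDist is my placeholder name, not an existing Pyro class:

import torch
import pyro.distributions as dist
from pyro.distributions import TorchDistribution


class WrapperDist(TorchDistribution):
    # Mixture over heterogeneous components; the assignment is
    # marginalized inside log_prob rather than enumerated.
    arg_constraints = {}  # parameter validation is skipped in this sketch

    def __init__(self, weights, components, validate_args=None):
        self.weights = weights              # (K,) mixture probabilities
        self.components = list(components)  # K distributions, same event_shape
        super().__init__(self.components[0].batch_shape,
                         self.components[0].event_shape,
                         validate_args=validate_args)

    @property
    def support(self):
        # assumes every component has the same support (e.g. all real-valued)
        return self.components[0].support

    def expand(self, batch_shape, _instance=None):
        # pyro.plate calls this through its broadcast handler
        new = self._get_checked_instance(WrapperDist, _instance)
        new.weights = self.weights
        new.components = [c.expand(torch.Size(batch_shape)) for c in self.components]
        super(WrapperDist, new).__init__(torch.Size(batch_shape),
                                         self.event_shape, validate_args=False)
        return new

    def log_prob(self, value):
        # stack per-component log densities into (K, ...) and reduce
        lp = torch.stack([c.log_prob(value) for c in self.components])
        log_w = self.weights.log().reshape((-1,) + (1,) * (lp.dim() - 1))
        return torch.logsumexp(log_w + lp, dim=0)

    def sample(self, sample_shape=torch.Size()):
        # ancestral sampling; assumes scalar events (event_shape == ())
        idx = dist.Categorical(self.weights).sample(sample_shape + self.batch_shape)
        draws = torch.stack([c.sample(sample_shape) for c in self.components])
        return draws.gather(0, idx.unsqueeze(0)).squeeze(0)

Marginalizing inside log_prob sidesteps the parameter-packing problem entirely, at the cost of K log_prob evaluations per site and of losing the explicit assignment site in the trace.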

Perhaps someone has already tried something similar, if so, can you share?
Or perhaps that is not possible given the current framework?

I have realized that even extending the GMM to such a generalized mixture is not easy.

I have run into issues with correctly expanding tensors to the batch dimension, with the bare re-initialization of WrapperDist in the trace, etc.
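
In particular, as far as I can tell, inside pyro.plate the broadcast handler calls .expand on the distribution to match the plate's batch shape, which is why expand in the sketch above has to re-expand every wrapped component. A hypothetical usage sketch with that WrapperDist:

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model(data):
    weights = pyro.param('weights', torch.ones(2) / 2,
                         constraint=constraints.simplex)
    # heterogeneous components with a common (real-valued) support
    components = [dist.Normal(0., 1.), dist.StudentT(4., 0., 1.)]
    with pyro.plate('data', len(data)):
        # the plate broadcasts this site: WrapperDist.expand is called here
        pyro.sample('obs', WrapperDist(weights, components), obs=data)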

I can share a longer snippet if that would be helpful.

Thoughts?

Hi @kerm. Were you able to figure out an approach to your problem? I’m also interested in this and I’m stuck as well.