Greetings, Pyro community!
I am trying to develop a generalized mixture model.
I have read the Pyro GMM tutorial carefully and am using it as a starting point.
The main action takes place here:
```python
with pyro.plate('data', len(data)):
    assignment = pyro.sample('assignment', dist.Categorical(weights))
    pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)
```
Things to note:
- the main limitation of the tutorial's approach is that all mixture components must come from the same distribution family (homogeneous components)
- the parameters must be tensors, so that `locs[assignment]` broadcasts correctly
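To illustrate why the parameters need to be tensors: under parallel enumeration with `max_plate_nesting=1`, Pyro replaces `assignment` by all of its values arranged along a new dimension to the left of the data plate, i.e. a tensor of shape `(K, 1)`. The plain-PyTorch sketch below only mimics that (sizes and values are arbitrary) and shows that `locs[assignment]` then broadcasts against the data, whereas a Python list of heterogeneous objects cannot be indexed by such a tensor.

```python
import torch
from torch.distributions import Normal

K, N = 3, 5
locs = torch.tensor([-1.0, 0.0, 1.0])  # one location per component, packed in a tensor
data = torch.zeros(N)

# Mimic parallel enumeration with max_plate_nesting=1: every value of
# 'assignment' on a new dimension left of the data plate, shape (K, 1).
assignment = torch.arange(K).reshape(K, 1)

selected = locs[assignment]  # shape (K, 1): one location per enumerated value
log_prob = Normal(selected, 1.0).log_prob(data)  # broadcasts to (K, N)

# Heterogeneous components held in a Python list cannot be gathered this way.
components = ["line", "hinge", "spline"]  # hypothetical component objects
try:
    components[assignment]
    list_indexable = True
except TypeError:
    list_indexable = False

print(selected.shape, log_prob.shape, list_indexable)
```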
A typical example of heterogeneous components would be a mixture of (piecewise) linear regressions, where the component parameters may differ in shape as well as in meaning.
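As a concrete, hypothetical instance, take one plain linear component with parameters `(slope, intercept)` and one piecewise-linear "hinge" component with `(breakpoint, slope_left, slope_right, intercept)`. Even without enumeration, the assignment can be marginalized by hand by stacking the per-component log-likelihoods into a `(K, N)` table and applying `logsumexp`; the sketch below does this in plain PyTorch with made-up data and parameters. Each component needs its own code path, which is what rules out the single `dist.Normal(locs[assignment], scale)` call.

```python
import torch
from torch.distributions import Normal

x = torch.linspace(0.0, 4.0, 9)
y = 2.0 * x + 1.0  # toy data generated by the linear component

# Hypothetical parameters with different shapes and different meanings.
line = torch.tensor([2.0, 1.0])              # (slope, intercept)
hinge = torch.tensor([2.0, 0.5, -1.0, 1.0])  # (breakpoint, slope_left, slope_right, intercept)


def line_mean(p, x):
    return p[0] * x + p[1]


def hinge_mean(p, x):
    b, s1, s2, c = p
    # continuous piecewise-linear curve with a kink at x = b
    return c + s1 * torch.minimum(x, b) + s2 * torch.relu(x - b)


weights = torch.tensor([0.7, 0.3])
noise = 0.5

# (K, N) table: log-likelihood of every point under every component.
comp_lp = torch.stack([
    Normal(line_mean(line, x), noise).log_prob(y),
    Normal(hinge_mean(hinge, x), noise).log_prob(y),
])
# Marginalize the assignment by hand: sum_n log sum_k w_k p_k(y_n).
log_lik = torch.logsumexp(weights.log().unsqueeze(-1) + comp_lp, dim=0).sum()
print(log_lik)
```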
For enumeration to work, the mixture parameters have to be packed into a single tensor. The mixture itself would then be wrapped in a distribution, `WrapperDist`, that unpacks the parameters, implements `expand`, and aggregates the calls to the `log_prob` of the mixture components.
Has anyone already tried something similar? If so, could you share your approach?
Or is this simply not possible within the current framework?
I have found that even extending the GMM to such a generalized mixture is not easy.
I have run into issues with correctly expanding the tensor along the batch dimension, with the bare initialization of `WrapperDist` in the trace, etc.
I can share a snippet if that would help.