Greetings, Pyro community!

I am trying to develop a generalized mixture model.

I have carefully read the Pyro GMM tutorial and am using it as the basis.

The main action takes place here:

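```python
import pyro
import pyro.distributions as dist

# inside the model; `weights`, `locs` and `scale` are defined earlier in the model
with pyro.plate('data', len(data)):
    assignment = pyro.sample('assignment', dist.Categorical(weights))
    pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)
```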

Things to note:

- the main constraint of the tutorial's approach is that all mixture components must come from the same distribution family (homogeneous components)
- the parameters must be tensors so that they can be broadcast
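The broadcasting requirement can be illustrated with plain PyTorch. The sketch below assumes what Pyro does under parallel enumeration with `max_plate_nesting=1`: instead of one sampled assignment per data point (shape `(N,)`), the site receives all `K` values in a new dimension left of the plate (shape `(K, 1)`). The sizes and values are arbitrary.

```python
import torch
from torch.distributions import Normal

K, N = 3, 5
locs = torch.tensor([-1.0, 0.0, 1.0])  # one loc per component, packed into one tensor
data = torch.randn(N)

# Ordinary sampling: one assignment per data point, shape (N,)
assignment = torch.randint(K, (N,))
print(locs[assignment].shape)  # torch.Size([5])

# Parallel enumeration: all K values in a dim left of the plate, shape (K, 1)
assignment_enum = torch.arange(K).reshape(K, 1)
lp = Normal(locs[assignment_enum], 1.0).log_prob(data)
print(lp.shape)  # torch.Size([3, 5]): one row of log densities per component
```

Both cases go through the same indexing expression, which is why the parameters have to be a single indexable tensor.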

A typical example of heterogeneous components would be a mixture of (piecewise-)linear regressions, whose components may have parameters that differ in shape as well as in meaning.
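One possible workaround, sketched here under assumptions, is to avoid packing the parameters at all: evaluate every component's mean on every data point, stack the results into a single `(K, N)` tensor, and index that with the assignment. The component names (`line_mean`, `hinge_mean`) and parameter values below are hypothetical, chosen only to show two regressions with differently shaped parameter sets.

```python
import torch
from torch.distributions import Normal

# Hypothetical heterogeneous components, each with its own parameter set.
line_params = {'slope': torch.tensor(2.0), 'intercept': torch.tensor(0.5)}
hinge_params = {'slope_left': torch.tensor(1.0), 'slope_right': torch.tensor(-3.0),
                'knot': torch.tensor(0.0), 'intercept': torch.tensor(1.0)}

def line_mean(x, p):
    return p['slope'] * x + p['intercept']

def hinge_mean(x, p):
    # Continuous piecewise-linear function with one breakpoint at p['knot'].
    dx = x - p['knot']
    return p['intercept'] + torch.where(dx < 0, p['slope_left'] * dx, p['slope_right'] * dx)

def component_means(x):
    # Pack outputs (not parameters) into one (K, N) tensor.
    return torch.stack([line_mean(x, line_params), hinge_mean(x, hinge_params)])

def select(means, assignment):
    # Handles a sampled assignment of shape (N,) and an enumerated one of shape (K, 1).
    n = means.shape[-1]
    return means[assignment, torch.arange(n)]

x = torch.linspace(-2.0, 2.0, 5)
y = torch.randn(5)
means = component_means(x)  # shape (2, 5)

sampled = torch.randint(2, (5,))
print(select(means, sampled).shape)  # torch.Size([5])
enumerated = torch.arange(2).reshape(2, 1)
lp = Normal(select(means, enumerated), 0.1).log_prob(y)
print(lp.shape)  # torch.Size([2, 5])
```

In a Pyro model, `select(means, assignment)` would take the place of `locs[assignment]` in the `obs` site. Pyro's `pyro.ops.indexing.Vindex` is another option for this kind of broadcasting-aware indexing.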

For enumeration to work, the parameters of the mixture would have to be packed into a single tensor, and the mixture itself wrapped in a distribution, `WrapperDist`, that unpacks the parameters, implements `expand`, and aggregates the calls to `rsample` and `log_prob` of the mixture components.

Perhaps someone has already tried something similar? If so, could you share it?

Or is this not possible within the current framework?

I realized that even extending the GMM to such a generalized mixture is not easy.

I have run into issues with correctly expanding the tensor to the batch dimension, with the bare initialization of `WrapperDist` in the trace, etc.

I could share a snippet if that would be helpful.

Thoughts?