Thank you for the answer!

I adapted code from here for the Gaussian mixture and it suits me.

```
class MixtureGaussian(dist.Distribution):
    """Mixture of univariate Gaussians, with components along the last axis.

    Parameters
    ----------
    loc, scale : array-like
        Per-component means and standard deviations; broadcast together
        with ``mixing_probs``.
    mixing_probs : array-like
        Mixture weights; the last axis indexes the mixture components.
    validate_args : bool, optional
        Passed through to ``dist.Distribution``.
    """

    def __init__(self, loc, scale, mixing_probs, validate_args=None):
        # Broadcast every parameter to one common shape; the trailing
        # dimension is the component dimension.
        expand_shape = jax.lax.broadcast_shapes(
            jnp.shape(loc), jnp.shape(scale), jnp.shape(mixing_probs))
        self._gaussian = dist.Normal(loc=loc, scale=scale).expand(expand_shape)
        self._categorical = dist.Categorical(
            jnp.broadcast_to(mixing_probs, expand_shape))
        # Component axis is consumed by the mixture, so it is excluded
        # from the batch shape.
        super(MixtureGaussian, self).__init__(
            batch_shape=expand_shape[:-1], validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Ancestral sampling: draw from every component, then select one
        per draw with a categorical index."""
        key, key_idx = random.split(key)
        samples = self._gaussian.sample(key, sample_shape)
        ind = self._categorical.sample(key_idx, sample_shape)
        # Pick the sampled component along the last axis and drop it.
        return jnp.take_along_axis(samples, ind[..., None], -1)[..., 0]

    def log_prob(self, value):
        """log p(x) = logsumexp_k [ log w_k + log N(x | mu_k, sigma_k) ]."""
        # Broadcast value against the component axis before scoring.
        probs_mixture = self._gaussian.log_prob(value[..., None])
        sum_probs = self._categorical.logits + probs_mixture
        return jax.nn.logsumexp(sum_probs, axis=-1)
```