Error mixing two Truncated Normal distributions using dist.TruncatedDistribution

Hello devs. I’m not sure why I’m getting this error:

TypeError: cumsum requires ndarray or scalar arguments, got  at position 0.

when running this code:

```python
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.mixtures import MixtureGeneral

rng_key = jax.random.PRNGKey(0)

component_distributions = [
    dist.TruncatedDistribution(dist.Normal(0, 1), low=0),
    dist.TruncatedDistribution(dist.Normal(0, 100), low=0),
]

mixed = MixtureGeneral(
    mixing_distribution=dist.Categorical(probs=[0.9, 0.1]),
    component_distributions=component_distributions
)

mixed.sample(rng_key, sample_shape=(10,))
```

Help is much appreciated. Thanks!

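Looks like you need probs=jax.numpy.array([0.9, 0.1]) instead of a plain Python list. JAX's cumsum only accepts arrays or scalars, which is what the error message is complaining about.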

Thank you @martinjankowiak, that solved it!
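
To see where the error comes from, the snippet below calls jnp.cumsum directly, once with an array and once with a plain list. It assumes, based only on the function name in the error message, that Categorical's sampler applies cumsum to probs; the snippet itself exercises only jax.numpy, not NumPyro.

```python
import jax.numpy as jnp

probs_list = [0.9, 0.1]
probs_array = jnp.array(probs_list)

# jnp.cumsum accepts a JAX array and returns the running sum of the weights.
cdf = jnp.cumsum(probs_array)

# A plain Python list is rejected with a TypeError, matching the
# "cumsum requires ndarray or scalar arguments" message in the question.
try:
    jnp.cumsum(probs_list)
    list_rejected = False
except TypeError as err:
    list_rejected = True
    print(err)
```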