Hello devs. I’m not sure why I’m getting this error
TypeError: cumsum requires ndarray or scalar arguments, got <class 'list'> at position 0.
when running this code:
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.mixtures import MixtureGeneral
rng_key = jax.random.PRNGKey(0)
component_distributions = [
    dist.TruncatedDistribution(dist.Normal(0, 1), low=0),
    dist.TruncatedDistribution(dist.Normal(0, 100), low=0),
]
mixed = MixtureGeneral(
    mixing_distribution=dist.Categorical(probs=[0.9, 0.1]),
    component_distributions=component_distributions
)
mixed.sample(rng_key, sample_shape=(10,))
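For what it's worth, the same message seems reproducible with jax alone. This is only a minimal sketch, and it assumes that the plain Python list given as `probs` reaches `jnp.cumsum` without being converted to an array first (I have not confirmed that this is what happens inside numpyro). The names `probs_list`, `probs_array`, and `running_total` are just for illustration.

```python
import jax.numpy as jnp

# Assumption: this mirrors what happens to the list passed as `probs` above.
probs_list = [0.9, 0.1]

raised = False
message = ""
try:
    # jax.numpy functions do not accept Python lists as array arguments.
    jnp.cumsum(probs_list, axis=-1)
except TypeError as err:
    raised = True
    message = str(err)

print("TypeError raised:", raised)
print(message)

# Converting the list to a JAX array first avoids the TypeError.
probs_array = jnp.asarray(probs_list)
running_total = jnp.cumsum(probs_array, axis=-1)
print(running_total)
```

If that assumption holds, passing `probs=jnp.array([0.9, 0.1])` to `dist.Categorical` in the code above would probably avoid the error, but I would like to understand whether that is the intended usage.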
Help is much appreciated. Thanks!