I am trying to define a mixture model in NumPyro.

Something like

y_i ~\mid~ \hat{y}_{in}, \hat{y}_{out}, x_i, \sigma_i, \sigma_{out} \sim (1 - g_i) \, \mathcal{N}(\hat{y}_{in}, \sigma_i) + g_i \, \mathcal{N}(\hat{y}_{out}, \sigma_{out})

I am using the same distribution family for both components because I see there is a `MixtureSameFamily`,

but I would also like to know how to define a mixture with a different distribution for each component.

My toy example is a line with outliers

\hat{y}_{in}(x ~\mid~\alpha, \beta) = \alpha x + \beta

Each data point has a Bernoulli flag g_i (0 or 1) marking it as an outlier or not, drawn with probability g:

g_i \sim \mathcal{B}(g)

g sets the ratio of inliers to outliers: it corresponds to the fraction of outliers in our data. One can set a weakly informative prior on g as

g \sim \mathcal{U}(0, 0.5)

(hopefully no more than half of the data are outliers).
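For concreteness, toy data like what I have in mind can be generated as follows (all parameter values here are just illustrative):

```python
import numpy as np

rng = np.random.default_rng(42)
N = 20
alpha_true, beta_true = 2.0, 1.0   # illustrative slope and intercept
g_true = 0.2                       # illustrative outlier fraction
sigma_y = 0.5 * np.ones(N)         # per-point inlier uncertainties

x = np.sort(rng.uniform(0.0, 10.0, N))
is_outlier = rng.random(N) < g_true                     # Bernoulli(g) flags
y = alpha_true * x + beta_true + rng.normal(0.0, sigma_y)
# outliers scatter widely around zero, i.e. \hat{y}_{out} = 0
y[is_outlier] = rng.normal(0.0, 10.0, size=is_outlier.sum())
```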

The following does not work:

```python
def jax_model_outliers(x=None, y=None, sigma_y=None):
    ## Define weakly informative Normal priors
    beta = numpyro.sample("beta", dist.Normal(0.0, 100))
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 100))
    ## Define Bernoulli inlier / outlier flags according to
    ## a hyperprior fraction of outliers, itself constrained
    ## to [0, 0.5] for symmetry
    frac_outliers = numpyro.sample("frac_outliers", dist.Uniform(low=0.0, high=0.5))
    ## scale of the outlier component
    sigma_y_out = numpyro.sample("sigma_y_out", dist.HalfNormal(100))
    with numpyro.plate("data", len(y)):
        is_outlier = numpyro.sample("is_outlier",
                                    dist.Bernoulli(frac_outliers),
                                    infer={"enumerate": "parallel"})
        mix_ = dist.Categorical(probs=jnp.array([is_outlier, 1 - is_outlier]))
        comp_ = dist.Normal(jnp.array([beta + alpha * x, 0]),
                            jnp.array([sigma_y, sigma_y_out]))
        mixture = dist.MixtureSameFamily(mix_, comp_)
        # likelihood
        numpyro.sample("obs", mixture, obs=y)
```

```
ValueError: All input arrays must have the same shape.
```

I also tried a different implementation:

```python
def jax_model_outliers(x=None, y=None, sigma_y=None):
    ## Define weakly informative Normal priors
    beta = numpyro.sample("beta", dist.Normal(0.0, 100))
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 100))
    ## Define Bernoulli inlier / outlier flags according to
    ## a hyperprior fraction of outliers, itself constrained
    ## to [0, 0.5] for symmetry
    frac_outliers = numpyro.sample("frac_outliers", dist.Uniform(low=0.0, high=0.5))
    ## scale of the outlier component
    sigma_y_out = numpyro.sample("sigma_y_out", dist.HalfNormal(100))
    with numpyro.plate("data", len(y), dim=-1):
        ## define the linear model
        p_outlier = numpyro.sample("p_outlier",
                                   dist.Bernoulli(frac_outliers),
                                   infer={"enumerate": "parallel"})
        probs = jnp.stack([1 - p_outlier, p_outlier])
        mix_ = dist.Categorical(probs=probs)
        locs = jnp.stack([beta + alpha * x, jnp.zeros(len(x))])
        scales = jnp.stack([sigma_y, sigma_y_out * jnp.ones(len(x))])
        comp_ = dist.Normal(locs, scales)
        mixture = dist.MixtureSameFamily(mix_, comp_)
        # likelihood
        numpyro.sample("obs", mixture, obs=y)
```

```
ValueError: Incompatible shapes for broadcasting: ((20,), (2,))
```

`x`, `y`, and `sigma_y` are of size 20; only the mixing vector `mix_` is of size 2. This suggests to me that `MixtureSameFamily` is not doing what I think it does.

Thanks for your help.