How to properly define a Mixture

I am trying to define a mixture model, something like:
y_i ~\mid~ g_i, \hat{y}_{in}, \hat{y}_{out}, x_i, \sigma_i, \sigma_{out} \sim (1 - g_i) \, \mathcal{N}(\hat{y}_{in}, \sigma_i) + g_i \, \mathcal{N}(\hat{y}_{out}, \sigma_{out})

I am using the same distribution family for both components because I saw there is a MixtureSameFamily, but I would also like to know how to define a mixture of different distributions.

My toy example is a line with outliers
\hat{y}_{in}(x ~\mid~\alpha, \beta) = \alpha x + \beta

Each data point is flagged as an outlier (g_i = 1) or an inlier (g_i = 0) by a Bernoulli draw with probability g:

g_i \sim \mathcal{B}(g)

g sets the balance between inliers and outliers: it is the expected fraction of outliers in our data. One can set a weakly informative prior on g as

g \sim \mathcal{U}(0, 0.5)

(hopefully we do not have more than half of the data being made of outliers)

The following does not work

def jax_model_outliers(x=None, y=None, sigma_y=None):
    """Straight line plus outliers as a two-component Normal mixture (first attempt).

    Args:
        x: predictor values; assumed to be a 1-D array with the same length
            as y (20 points in the toy example) -- TODO confirm.
        y: observed values; ``len(y)`` sets the size of the "data" plate.
        sigma_y: per-point uncertainty used as the scale of the inlier
            component; presumably an array of length len(y).

    NOTE(review): this version raises
    ``ValueError: All input arrays must have the same shape.`` -- see the
    notes on the ``mix_`` / ``comp_`` lines below.
    """

    ## Define weakly informative Normal priors on intercept and slope
    beta = numpyro.sample("beta", dist.Normal(0.0, 100))
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 100))

    ## Define Bernoulli inlier / outlier flags according to
    ## a hyperprior fraction of outliers, itself constrained
    ## to [0,.5] for symmetry
    frac_outliers = numpyro.sample('frac_outliers', dist.Uniform(low=0., high=.5))

    ## scale (standard deviation, not variance) of the outlier component;
    ## it is passed as the scale argument of dist.Normal below
    sigma_y_out = numpyro.sample("sigma_y_out", dist.HalfNormal(100))

    with numpyro.plate("data", len(y)):
        ## per-point 0/1 outlier flag, marked for parallel enumeration
        is_outlier = numpyro.sample('is_outlier', 
                                    dist.Bernoulli(frac_outliers), 
                                    infer={'enumerate': 'parallel'})

        ## NOTE(review): component 0 of comp_ is the line (inlier), but it gets
        ## weight ``is_outlier``. The model y_i ~ (1 - g_i) N(line) + g_i N(out)
        ## gives the inlier weight 1 - g_i, so the two entries look swapped.
        ## NOTE(review): MixtureSameFamily already sums over the components,
        ## so the weights are probably meant to be
        ## [1 - frac_outliers, frac_outliers] rather than the sampled 0/1 flag.
        mix_ = dist.Categorical(probs=jnp.array([is_outlier, 1 - is_outlier]))
        ## NOTE(review): ``beta + alpha * x`` has len(x) entries while ``0`` is a
        ## scalar (likewise sigma_y vs sigma_y_out). jnp.array cannot build a
        ## rectangular array from such lists, which is likely the source of the
        ## ValueError. MixtureSameFamily also expects the component axis to be
        ## the LAST batch dimension, e.g. jnp.stack([...], axis=-1).
        comp_ = dist.Normal(jnp.array([beta + alpha * x, 0]),
                            jnp.array([sigma_y, sigma_y_out]))
        mixture = dist.MixtureSameFamily(mix_, comp_)

        # likelihood
        numpyro.sample("obs", mixture, obs=y)
ValueError: All input arrays must have the same shape.

I also tried a different implementation

def jax_model_outliers(x=None, y=None, sigma_y=None):
    """Second attempt at the line-plus-outliers mixture, built with jnp.stack.

    Assumes x, y and sigma_y are 1-D arrays of the same length (20 in the
    toy example) -- TODO confirm.

    NOTE(review): this version raises
    ``ValueError: Incompatible shapes for broadcasting: ((20,), (2,))``.
    jnp.stack is called without ``axis``, so the two components are stacked
    along axis 0 and probs/locs/scales come out with shape (2, len(x)).
    The mixture then treats the last axis (len(x)) as the component axis and
    2 as the batch axis, which does not broadcast against the len(y) plate.
    """

    ## Define weakly informative Normal priors on intercept and slope
    beta = numpyro.sample("beta", dist.Normal(0.0, 100))
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 100))

    ## Define Bernoulli inlier / outlier flags according to
    ## a hyperprior fraction of outliers, itself constrained
    ## to [0,.5] for symmetry
    frac_outliers = numpyro.sample('frac_outliers', dist.Uniform(low=0., high=.5))

    ## scale (standard deviation, not variance) of the outlier component;
    ## it is passed as the scale argument of dist.Normal below
    sigma_y_out = numpyro.sample("sigma_y_out", dist.HalfNormal(100))

    with numpyro.plate("data", len(y), dim=-1):
        ## per-point 0/1 outlier flag (a Bernoulli sample, despite the "p_"
        ## name), marked for parallel enumeration
        p_outlier = numpyro.sample('p_outlier', 
                                   dist.Bernoulli(frac_outliers), 
                                   infer={'enumerate': 'parallel'})
        
        ## NOTE(review): stacked along axis 0 -> (2, len(y)); the size-2
        ## component axis has to be last (jnp.stack(..., axis=-1)).
        ## Because p_outlier is a 0/1 sample, probs also holds integers,
        ## whereas Categorical expects floating-point probabilities.
        probs = jnp.stack([1 - p_outlier, p_outlier])
        
        mix_ = dist.Categorical(probs=probs)
        ## same axis issue: locs and scales come out as (2, len(x))
        locs = jnp.stack([beta + alpha * x, jnp.zeros(len(x))])
        scales = jnp.stack([sigma_y, sigma_y_out * jnp.ones(len(x))])
        comp_ = dist.Normal(locs, scales)
        mixture = dist.MixtureSameFamily(mix_, comp_)

        # likelihood
        numpyro.sample("obs", mixture, obs=y)

ValueError: Incompatible shapes for broadcasting: ((20,), (2,))

x, y, and sy have size 20; only the mixing vector mix_ has size 2. This suggests to me that MixtureSameFamily is not doing what I think it does.

Thanks for your help.

If you print out the shapes of the variables, you might see that

# Debugging aid: inspect the shapes around the mixing-probability stack.
# p_outlier is (len(y),) without enumeration and (2, 1) with enumeration.
print(p_outlier.shape)
# without an axis argument, jnp.stack puts the new size-2 axis first
probs = jnp.stack([1 - p_outlier, p_outlier])
print(probs.shape)

returns unexpected shapes. I think you need jnp.stack(..., -1) there. Similarly, you need to specify the axis to stack along in the other places. In any case, please print out the shapes of the variables and make sure they match the enumeration semantics (note that p_outlier is enumerated, so its shape will be either (len(y),) when enumeration is not activated or (2, 1) when enumeration is activated).

def jax_model_outliers(x=None, y=None, sigma_y=None):
    """Third attempt: the same mixture model, instrumented with shape prints.

    Assumes x, y and sigma_y are 1-D arrays of the same length (20 in the
    toy example) -- TODO confirm.

    NOTE(review): the printed shapes are locs/scales/probs (2, 20) and
    mix (2,). Because jnp.stack is called without ``axis``, the size-2
    component axis ends up first. The Categorical therefore sees 20
    categories with batch shape (2,), instead of 2 categories with batch
    shape (20,).
    """

    ## Define weakly informative Normal priors on intercept and slope
    beta = numpyro.sample("beta", dist.Normal(0.0, 100))
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 100))

    ## Define Bernoulli inlier / outlier flags according to
    ## a hyperprior fraction of outliers, itself constrained
    ## to [0,.5] for symmetry
    frac_outliers = numpyro.sample('frac_outliers', 
                                   dist.Uniform(low=0., high=.5))

    ## scale (standard deviation, not variance) of the outlier component;
    ## it is passed as the scale argument of dist.Normal below
    sigma_y_out = numpyro.sample("sigma_y_out", dist.HalfNormal(100))

    ## Earlier idea, kept for reference only (not used below): sample the
    ## inlier and outlier predictions as separate sites.
    #ypred_in = numpyro.sample("ypred_in", 
    #                          dist.Normal(beta + alpha * x, sigma_y))
    # 
    #ypred_out = numpyro.sample("ypred_out", dist.Normal(jnp.zeros(len(x)), 
    #                                                    sigma_y_out))

    with numpyro.plate("data", len(y), dim=-1):
        ## per-point 0/1 outlier flag (a Bernoulli sample, despite the "p_"
        ## name), marked for parallel enumeration
        p_outlier = numpyro.sample('p_outlier', 
                                   dist.Bernoulli(frac_outliers), 
                                   infer={'enumerate': 'parallel'})
        
        ## NOTE(review): all three stacks use the default axis, so the
        ## component axis comes first -> (2, len(x)); it needs to be last
        probs = jnp.stack([1 - p_outlier, p_outlier])
        mix_ = dist.Categorical(probs=probs)
        locs = jnp.stack([beta + alpha * x, jnp.zeros(len(x))])
        scales = jnp.stack([sigma_y, sigma_y_out * jnp.ones(len(x))])

        ## debug output; in NumPyro, mix_.shape() is the distribution's
        ## batch + event shape, which prints (2,) here
        print("shapes",
        "\nlocs", locs.shape,
        "\nscales", scales.shape, 
        "\nprobs", probs.shape, 
        "\nmix", mix_.shape(), 
        "\np_outlier", p_outlier.shape)

        comp_ = dist.Normal(locs, scales)
        mixture = dist.MixtureSameFamily(mix_, comp_)

        # likelihood
        numpyro.sample("obs", mixture, obs=y)

## Draw the model graph; x, y and sy are the toy data arrays, defined
## outside this excerpt (presumably length-20 arrays).
numpyro.render_model(jax_model_outliers, model_args=(x, y, sy,),
                     render_distributions=True)
shapes 
locs (2, 20) 
scales (2, 20) 
probs (2, 20) 
mix (2,) 
p_outlier (20,)

If I use

probs = jnp.stack([1 - p_outlier, p_outlier], -1)
shapes 
locs (2, 20) 
scales (2, 20) 
probs (20, 2) 
mix (20,) 
p_outlier (20,)

AssertionError: Component distribution batch shape last dimension (size=20) needs to correspond to the mixture_size={mixture_size}!

I’m not sure why you didn’t use jnp.stack(..., -1) in the other places, but are those shapes and that error message expected? Here are some questions you can try to answer:

  • is the error message clear?
  • what is the component distribution batch shape? is it expected?
  • what is its last dimension? is 20 the expected number of components?
  • what is the mixture size? does it have the same value as the number of components?
  • do locs and scales have the expected shapes?

20 is the length of the data (toy example)
2 is the number of components to the mixture.

the default is axis=-1 in jnp.stack.

Is there a way to define a variable that is an operation like the following?

jnp.log(1 - p_outlier) * ypred_in +
            jnp.log(p_outlier) * ypred_out)

Please check jnp.stack for its usage.

is 20 the expected number of components?

Based on your message, I guess this is not expected? Can you figure out why the error message says that the number of components is 20 rather than 2? With locs shape = (2, 20), what is the number of components for this locs parameter? The error message says that the number of components is 20. Why? If we swap its axes to (20, 2), will the number of components of locs be 2, …

I’m not sure what you have figured out from your last message, so I just sketched some questions that I asked myself.

Is there a way to define a variable that is an operation like the following?
jnp.log(1 - p_outlier) * ypred_in +
jnp.log(p_outlier) * ypred_out)

Those are numerical computations so you can just set

a = jnp.log(1 - p_outlier) * ypred_in + jnp.log(p_outlier) * ypred_out

for that variable a. Are you trying to calculate the log probability by hand?

If you are asking for the topic model, I think you can remove the is_outlier variable and define

mix_ = dist.Categorical(probs=jnp.stack([1 - p_outlier, p_outlier], -1))

If you want to use enumeration (rather than MixtureSameFamily), then keep is_outlier, remove mix_ and define

locs = jnp.where(is_outlier, locs_out, locs_in)
scales = jnp.where(is_outlier, scales_out, scales_in)

then use Normal(locs, scales) for the likelihood.

So trying to clarify, the number of components is 2 (here 2 Gaussians) and we have 20 observed points.

probs=jnp.stack([1 - p_outlier, p_outlier], -1)
gives an array of shape (20, 2), but we need (2, 20) to align with the rest — unless I need to switch them all?

a = jnp.log(1 - p_outlier) * ypred_in + jnp.log(p_outlier) * ypred_out
This is a simple mixture model of 2 distributions, which could also be of different families.
But that raises an error saying that this is not a distribution object.

locs = jnp.where(is_outlier, locs_out, locs_in) is a different way, but that won’t generalize well.

I tried to use axis=-1 everywhere.
The following gives me a very obscure message:

def jax_model_outliers(x=None, y=None, sigma_y=None):
    """Fourth attempt: the mixture model with the component axis stacked last.

    Assumes x, y and sigma_y are 1-D arrays of the same length (20 in the
    toy example) -- TODO confirm.

    With ``axis=-1`` everywhere, probs/locs/scales are (len(x), 2). The
    Categorical now has 2 categories per data point, matching the 2 Normal
    components.

    NOTE(review): this version raises
    ``ValueError: data type <class 'numpy.int32'> not inexact``. p_outlier is
    a Bernoulli sample, so ``probs`` built from it is an int32 array, and
    Categorical presumably requires floating-point probabilities. There are
    two options:
      * keep the enumerated flag and cast, e.g. ``p_outlier.astype(float)``
        (equivalent to selecting one component per point, like the
        jnp.where approach); or
      * let MixtureSameFamily marginalise the assignment: drop the Bernoulli
        site and use ``jnp.stack([1 - frac_outliers, frac_outliers], -1)`` as
        the mixing probabilities.
    """

    ## Define weakly informative Normal priors on intercept and slope
    beta = numpyro.sample("beta", dist.Normal(0.0, 100))
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 100))

    ## Define Bernoulli inlier / outlier flags according to
    ## a hyperprior fraction of outliers, itself constrained
    ## to [0,.5] for symmetry
    frac_outliers = numpyro.sample('frac_outliers', 
                                   dist.Uniform(low=0., high=.5))

    ## scale (standard deviation, not variance) of the outlier component;
    ## it is passed as the scale argument of dist.Normal below
    sigma_y_out = numpyro.sample("sigma_y_out", dist.HalfNormal(100))

    ## Earlier idea, kept for reference only (not used below): sample the
    ## inlier and outlier predictions as separate sites.
    # ypred_in = numpyro.sample("ypred_in", 
    #                          dist.Normal(beta + alpha * x, sigma_y))
     
    # ypred_out = numpyro.sample("ypred_out", dist.Normal(jnp.zeros(len(x)), 
    #                                                    sigma_y_out))

    with numpyro.plate("data", len(y), dim=-1):
        ## per-point 0/1 outlier flag (a Bernoulli sample, despite the "p_"
        ## name), marked for parallel enumeration
        p_outlier = numpyro.sample('p_outlier', 
                                   dist.Bernoulli(frac_outliers), 
                                   infer={'enumerate': 'parallel'})
        
        ## component axis last -> (len(y), 2). NOTE(review): this array is
        ## integer-typed because p_outlier is, which triggers the
        ## "not inexact" error when it is used as Categorical probs.
        probs = jnp.stack([1 - p_outlier, p_outlier], axis=-1)
        mix_ = dist.Categorical(probs=probs)
        ## component 0 = the line (inlier), component 1 = outlier centred at 0
        locs = jnp.stack([beta + alpha * x, jnp.zeros(len(x))], -1)
        scales = jnp.stack([sigma_y, sigma_y_out * jnp.ones(len(x))], -1)

        ## debug output; in NumPyro, mix_.shape() is the distribution's
        ## batch + event shape
        print("shapes",
        "\nlocs", locs.shape,
        "\nscales", scales.shape, 
        "\nprobs", probs.shape, 
        "\nmix", mix_.shape(), 
        "\np_outlier", p_outlier.shape,
        #"\nypred_in", ypred_in.shape,
        #"\nypred_out", ypred_out.shape,
        )

        comp_ = dist.Normal(locs, scales)
        mixture = dist.MixtureSameFamily(mix_, comp_)

        # likelihood
        numpyro.sample("obs", mixture, obs=y)

## Draw the model graph; x, y and sy are the toy data arrays, defined
## outside this excerpt (presumably length-20 arrays).
numpyro.render_model(jax_model_outliers, model_args=(x, y, sy,),
                     render_distributions=True)

ValueError: data type <class 'numpy.int32'> not inexact

Any idea what this means?