Mixture model with discrete data in NumPyro

Hi.

Apologies for the rather long post. I am trying to modify the code from the Gaussian mixture model tutorial (I have also pulled bits of code from various other posts on this forum and elsewhere; if you recognise your code, thank you very much!) to fit a mixture model for categorical, Bernoulli-distributed data. This is the GMM code, which works when I fit it with both HMC and SVI:

import numpy as np

## Generate the data
n = 1000 # Total number of samples
k = 2  # Number of clusters
dim=3 # Number of dimensions
p_real = np.array([0.3, 0.7]) # Probability of choosing each cluster

mu0=[10, 10, 10]
mu1=[-5, -3, -3]
mus=[mu0,mu1]

sigma0=1 
sigma1=0.5
sigmas=[sigma0,sigma1]

clusters = np.random.choice(k, size=1, p=p_real)
data=np.random.multivariate_normal(mus[clusters[0]],sigmas[clusters[0]]*np.eye(dim), (1))
for i in range(1,n):
    clusters = np.random.choice(k, size=1, p=p_real)
    mu=mus[clusters[0]]
    sigma=sigmas[clusters[0]]*np.eye(dim)
    data_point=np.random.multivariate_normal(mu, sigma,(1))
    data=np.concatenate((data,data_point), axis=0)

data

Here is the model:

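import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate

@config_enumerate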
def model(K,dim,data=None):
    
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(jnp.ones(K) / float(K)))

    with numpyro.plate("variables", dim):
        theta = numpyro.sample("theta", dist.HalfCauchy(10.0))
        sigma = jnp.sqrt(theta)
    
    with numpyro.plate('components', K):
        locs = numpyro.sample('locs',dist.MultivariateNormal(jnp.zeros(dim),10*jnp.eye(dim))) 
        
    with numpyro.plate('data', len(data)):
        assignment = numpyro.sample('assignment', dist.Categorical(cluster_proba)) 

        numpyro.sample(
            'obs', 
            dist.MultivariateNormal(locs[assignment, :], covariance_matrix=jnp.diag(sigma)), 
            obs=data
        )

Fitting:

# HMC
from numpyro.infer import HMC, MCMC

rng_key = jax.random.PRNGKey(0)
num_warmup, num_samples = 500, 1000
kernel = HMC(model, step_size=.01, trajectory_length=1,
             adapt_step_size=False, adapt_mass_matrix=False)
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
)
mcmc.run(rng_key, data=data ,K=2,dim=3)

## SVI
import numpyro.infer.autoguide as ag
from numpyro.infer import SVI, TraceEnum_ELBO, init_to_value

k = 2
global_model = numpyro.handlers.block(
    numpyro.handlers.seed(model, jax.random.PRNGKey(0)),
    hide_fn=lambda site: site["name"]
    not in ["cluster_proba", "variables", "theta", "components", "locs"],
)
init_vals = {
    "cluster_proba": jnp.ones(k) / float(k),
    "theta": jnp.sqrt(data.var(axis=0)/2),
    "locs": data[jax.random.categorical(
        jax.random.PRNGKey(0), jnp.ones(len(data)) / len(data), shape=(k,)
    )]
}
guide = ag.AutoDelta(
    global_model,
    init_loc_fn=init_to_value(values=init_vals)
)
optimizer = numpyro.optim.Adam(step_size=0.005)
svi = SVI(model, guide, optimizer, loss=TraceEnum_ELBO())

This all works well and produces estimates that are close to the values I used to generate the data. Now, to modify this model to run on categorical, Bernoulli-distributed data, we first need a multivariate Bernoulli distribution, as that’s not provided as far as I can tell:

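from jax.scipy.special import xlog1py, xlogy
from numpyro.distributions import constraints
from numpyro.distributions.util import clamp_probs

class MultivariateBernoulli(dist.Distribution):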
    support = constraints.real_vector

    def __init__(self, phi):
        super(MultivariateBernoulli, self).__init__(event_shape=(1, )) 
        self.phi = phi

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):
    
        ps_clamped = clamp_probs(self.phi)
        # most of this code is pulled from numpyro.dists.Bernoulli
        return jnp.sum(
            jnp.asarray(
              xlogy(value, ps_clamped) + xlog1py(1 - value, -ps_clamped) # assuming independence of the variables for now.
            ),
            axis=1
        )

And the model, written in a similar way to the GMM above:

@config_enumerate
def discrete_mixture_model(K, X=None):
    
    N, D = X.shape
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(0.5 * jnp.ones(K))) 
    
    with numpyro.plate('components', D):
        with numpyro.plate("cluster", K):
            # Note: this nested plate statement is needed because without it I get a "plate definition needed" error
            phi = numpyro.sample('phi', dist.Beta(2.0, 2.0)) 

    with numpyro.plate('data', N):
        
        assignment = numpyro.sample('assignment', dist.CategoricalProbs(cluster_proba)) 

        numpyro.sample(
            'obs', 
            MultivariateBernoulli(phi[assignment, :]), 
            obs=X
        )

I can’t get this model to fit with HMC or SVI as above. The error messages I get look like this:

<snip>
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\_src\numpy\util.py:425, in _broadcast_to(arr, shape)
    422 nlead = len(shape) - len(arr_shape)
    423 shape_tail = shape[nlead:]
    424 compatible = all(core.definitely_equal_one_of_dim(arr_d, [1, shape_d])
--> 425                  for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
    426 if nlead < 0 or not compatible:
    427   msg = "Incompatible shapes for broadcasting: {} and requested shape {}"

ValueError: safe_zip() argument 2 is shorter than argument 1

Curiously, I can get the model to fit using DiscreteHMCGibbs if I comment out the @config_enumerate decorator. With the decorator left in, I get an error stating that there are no discrete latent variables in the model (which is odd, as there clearly are discrete latent variables in the model).

Can anyone help me to get this model fitting with regular HMC and SVI please?

Many thanks :slight_smile:

DiscreteHMCGibbs works for models with discrete latent variables. If you marginalize out the discrete site (which is what enumeration does), there is no discrete latent variable left in your model. In that case, you can just use NUTS or HMC as usual. For SVI, I think you can use the TraceEnum_ELBO objective.
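
As a rough illustration of what marginalizing out the discrete site means for a Bernoulli mixture, the sketch below sums the assignment out of the likelihood by hand with NumPy. The sizes, weights and variable names are made up for illustration; this is not how NumPyro implements enumeration, only the quantity it ends up computing.

```python
import numpy as np

K, N, D = 2, 5, 3
rng = np.random.default_rng(0)

weights = np.array([0.3, 0.7])                      # mixture weights (plays the role of cluster_proba)
phi = rng.uniform(0.1, 0.9, size=(K, D))            # Bernoulli probabilities per cluster
X = rng.integers(0, 2, size=(N, D)).astype(float)   # illustrative binary data

# log p(x_n | z_n = k): independent Bernoulli variables summed over D, shape (K, N).
p = phi[:, None, :]
log_lik = (X * np.log(p) + (1 - X) * np.log1p(-p)).sum(axis=-1)

# log p(x_n) = logsumexp_k [ log weights_k + log p(x_n | z_n = k) ], shape (N,).
log_marginal = np.logaddexp.reduce(np.log(weights)[:, None] + log_lik, axis=0)

# Only continuous parameters (weights, phi) remain in this expression.
print(log_marginal.shape)  # (5,)
```

Because the assignment no longer appears in the result, a gradient-based sampler only ever sees the continuous parameters.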

Thanks fehiepsi,

The issue is that I don’t understand the difference between the GMM example given above (which fits with HMC and SVI) and the discrete mixture model, which doesn’t fit with regular HMC. They both look to me like they have latent discrete variables in them. Do you have any suggestions for how I can modify the discrete mixture model to behave more like the GMM?

Thanks again :slight_smile:

I think the main difference is in the likelihood. Your log prob computation has sum(..., axis=1). I would expect the axis to be -1, or something similar, to account for batched data/params (under enumeration). The event shape (1,) also looks strange to me. I would expect a non-trivial event shape here.
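
To see why the axis matters, here is a small NumPy-only sketch of the shapes involved. It assumes that enumeration gives the parameters extra leading dimensions, for example a shape like (K, 1, 1, D); the sizes and variable names are illustrative and are not taken from an actual NumPyro trace.

```python
import numpy as np

K, N, D = 3, 600, 6
rng = np.random.default_rng(0)

X = rng.integers(0, 2, size=(N, D)).astype(float)      # observed binary data
phi_plain = rng.uniform(0.1, 0.9, size=(N, D))         # no enumeration: one row per data point
phi_enum = rng.uniform(0.1, 0.9, size=(K, 1, 1, D))    # assumed shape under enumeration

def elementwise_log_prob(x, p):
    """Bernoulli log-probability of each entry, before summing over variables."""
    return x * np.log(p) + (1 - x) * np.log1p(-p)

plain = elementwise_log_prob(X, phi_plain)   # shape (N, D)
enum = elementwise_log_prob(X, phi_enum)     # broadcasts to shape (K, 1, N, D)

# Without extra batch dimensions, axis=1 and axis=-1 are the same axis.
print(plain.sum(axis=1).shape, plain.sum(axis=-1).shape)  # (600,) (600,)

# With enumeration, axis=1 removes a batch dimension and leaves the D variables unsummed.
print(enum.sum(axis=1).shape, enum.sum(axis=-1).shape)    # (3, 600, 6) (3, 1, 600)
```

This is why a log_prob that happens to work on unbatched inputs can still fail once enumeration adds leading dimensions.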

Ah! Thanks - I will experiment with that. When I was working on the log prob I was trying to replicate the output of dist.MultivariateNormal but I guess I could have been more general.

Hi fehiepsi,

I’ve changed the event shape and the summation in the custom distribution so it’s now:

class MultivariateBernoulli(dist.Distribution):
    support = constraints.real_vector

    def __init__(self, phi):
        super(MultivariateBernoulli, self).__init__(event_shape=(len(phi), )) 
        self.phi = phi

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):
    
        ps_clamped = clamp_probs(self.phi)

        return jnp.sum(
            jnp.asarray(
              xlogy(value, ps_clamped) + xlog1py(1 - value, -ps_clamped) # assuming independence of the variables.
            ),
            axis=-1
        )

but I still have the same problem. With this model

@config_enumerate
def discrete_mixture_model(K, X=None):
    
    N, D = X.shape
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(0.5 * jnp.ones(K))) 
    
    with numpyro.plate('components', D):
        with numpyro.plate("cluster", K):
            phi = numpyro.sample('phi', dist.Beta(2.0, 2.0)) 

    with numpyro.plate('data', N):
        
        assignment = numpyro.sample('assignment', dist.CategoricalProbs(cluster_proba)) 
        
        numpyro.sample(
            'obs', 
            MultivariateBernoulli(phi[assignment, :]), 
            obs=X,
        )

And this attempt to run with HMC

rng_key = jax.random.PRNGKey(0)
num_warmup, num_samples = 500, 1000

kernel = HMC(discrete_mixture_model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

I get a very long error message that ends with this:

<snip>
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\_src\numpy\util.py:425, in _broadcast_to(arr, shape)
    422 nlead = len(shape) - len(arr_shape)
    423 shape_tail = shape[nlead:]
    424 compatible = all(core.definitely_equal_one_of_dim(arr_d, [1, shape_d])
--> 425                  for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
    426 if nlead < 0 or not compatible:
    427   msg = "Incompatible shapes for broadcasting: {} and requested shape {}"

ValueError: safe_zip() argument 2 is shorter than argument 1

I was trying to figure out what the issue was so I put in some print statements to look at the array shapes and this is what is output:

cluster_proba (3,)
phi (3, 6)
assignment (600,)
cluster_proba (3,)
phi (3, 6)
assignment (3, 1, 1)

I assume that the assignment changing shape is to do with the enumeration, and I also assume that there shouldn’t be three dimensions in the end? Any suggestions?

Something like len(phi) won’t work for batched phi. I think the event shape should be phi.shape[-1:] and the batch shape phi.shape[:-1].
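
The printed shapes can be traced with plain NumPy. This sketch assumes phi has shape (3, 6) and the enumerated assignment has shape (3, 1, 1), as in the printout; the data array X of shape (600, 6) is random and purely illustrative.

```python
import numpy as np

K, N, D = 3, 600, 6
rng = np.random.default_rng(0)

phi = rng.uniform(0.1, 0.9, size=(K, D))            # per-cluster Bernoulli probabilities
X = rng.integers(0, 2, size=(N, D)).astype(float)   # illustrative binary data

# Under enumeration the assignment holds every cluster index once,
# placed in a new leading dimension: shape (K, 1, 1).
assignment = np.arange(K).reshape(K, 1, 1)

probs = phi[assignment, :]
print(probs.shape)   # (3, 1, 1, 6)

# len(probs) is the size of the first (enumeration) dimension, not the number of variables.
print(len(probs))    # 3

# Splitting on the last axis works for any number of leading batch dimensions.
batch_shape, event_shape = probs.shape[:-1], probs.shape[-1:]
print(batch_shape, event_shape)   # (3, 1, 1) (6,)

# The log-probability then has one entry per (cluster, data point) pair.
log_prob = (X * np.log(probs) + (1 - X) * np.log1p(-probs)).sum(axis=-1)
print(log_prob.shape)   # (3, 1, 600)
```

So the three-dimensional assignment is expected under enumeration; the distribution just has to treat every axis except the last as batch.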

Btw, I’m not sure, but I think your MultivariateBernoulli is equivalent to Bernoulli(phi).to_event(1). Could you check whether they are the same?

Hi fehiepsi,

Thank you so much for your help. This fixes the problem! For completeness, in case someone else is interested, here is the working code.

Custom distribution:

class MultivariateBernoulli(dist.Distribution):
    support = constraints.real_vector

    def __init__(self, phi):
        super(MultivariateBernoulli, self).__init__(batch_shape=phi.shape[:-1], event_shape=phi.shape[-1:]) 
        self.phi = phi

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):
    
        ps_clamped = clamp_probs(self.phi)

        return jnp.sum(
            jnp.asarray(
              xlogy(value, ps_clamped) + xlog1py(1 - value, -ps_clamped) # assuming independence of the variables.
            ),
            axis=-1
        )

Model:

@config_enumerate
def discrete_mixture_model(K, X=None):
    
    N, D = X.shape
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(0.5 * jnp.ones(K))) 
        
    with numpyro.plate('components', D):
        with numpyro.plate("cluster", K):
            phi = numpyro.sample('phi', dist.Beta(2.0, 2.0)) 

    with numpyro.plate('data', N):
        
        assignment = numpyro.sample('assignment', dist.CategoricalProbs(cluster_proba)) 
        
        numpyro.sample(
            'obs', 
            MultivariateBernoulli(phi[assignment, :]), 
            obs=X,
        )

When I check the log probs of Bernoulli(phi).to_event(1), I get the same results as for my MultivariateBernoulli, so you are right about that too. In the long run I need a multivariate Bernoulli distribution that includes correlations between the variables (as in the paper "Multivariate Bernoulli distribution", arXiv:1206.1874), so I will probably stick with the custom class for that.

Thanks again! :slight_smile:
