Hi.
Apologies for the rather long post. I am trying to modify the code from the Gaussian mixture model tutorial (I have also pulled bits of code from various other posts on this forum and elsewhere - if you recognise your code, thank you very much!) to build a mixture model where the data are Bernoulli-distributed categorical variables. Here is the GMM code, which works when I fit it with both HMC and SVI:
## Generate the data
import numpy as np

n = 1000  # Total number of samples
k = 2     # Number of clusters
dim = 3   # Number of dimensions
p_real = np.array([0.3, 0.7])  # Probability of choosing each cluster
mu0 = [10, 10, 10]
mu1 = [-5, -3, -3]
mus = [mu0, mu1]
sigma0 = 1
sigma1 = 0.5
sigmas = [sigma0, sigma1]

clusters = np.random.choice(k, size=1, p=p_real)
data = np.random.multivariate_normal(mus[clusters[0]], sigmas[clusters[0]] * np.eye(dim), (1))
for i in range(1, n):
    clusters = np.random.choice(k, size=1, p=p_real)
    mu = mus[clusters[0]]
    sigma = sigmas[clusters[0]] * np.eye(dim)
    data_point = np.random.multivariate_normal(mu, sigma, (1))
    data = np.concatenate((data, data_point), axis=0)
data
Here is the model:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import numpyro.infer.autoguide as ag
from numpyro.contrib.funsor import config_enumerate
from numpyro.infer import HMC, MCMC, SVI, TraceEnum_ELBO, init_to_value

@config_enumerate
def model(K, dim, data=None):
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(jnp.ones(K) / float(K)))
    with numpyro.plate("variables", dim):
        theta = numpyro.sample("theta", dist.HalfCauchy(10.0))
    sigma = jnp.sqrt(theta)
    with numpyro.plate('components', K):
        locs = numpyro.sample('locs', dist.MultivariateNormal(jnp.zeros(dim), 10 * jnp.eye(dim)))
    with numpyro.plate('data', len(data)):
        assignment = numpyro.sample('assignment', dist.Categorical(cluster_proba))
        numpyro.sample(
            'obs',
            dist.MultivariateNormal(locs[assignment, :], covariance_matrix=jnp.diag(sigma)),
            obs=data,
        )
Fitting:
# HMC
rng_key = jax.random.PRNGKey(0)
num_warmup, num_samples = 500, 1000
kernel = HMC(model, step_size=0.01, trajectory_length=1,
             adapt_step_size=False, adapt_mass_matrix=False)
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
)
mcmc.run(rng_key, data=data, K=2, dim=3)
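To check the fit against the generating values, the standard MCMC helpers are enough (a minimal sketch, nothing specific to this model):
mcmc.print_summary()                     # summary of the continuous sites
posterior_samples = mcmc.get_samples()
print(posterior_samples["locs"].mean(axis=0))           # estimated cluster means
print(posterior_samples["cluster_proba"].mean(axis=0))  # estimated mixture weights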
## SVI
k = 2
global_model = numpyro.handlers.block(
    numpyro.handlers.seed(model, jax.random.PRNGKey(0)),
    hide_fn=lambda site: site["name"]
    not in ["cluster_proba", "variables", "theta", "components", "locs"],
)
init_vals = {
    "cluster_proba": jnp.ones(k) / float(k),
    "theta": jnp.sqrt(data.var(axis=0) / 2),
    "locs": data[jax.random.categorical(
        jax.random.PRNGKey(0), jnp.ones(len(data)) / len(data), shape=(k,)
    )],
}
guide = ag.AutoDelta(
    global_model,
    init_loc_fn=init_to_value(values=init_vals),
)
optimizer = numpyro.optim.Adam(step_size=0.005)
svi = SVI(model, guide, optimizer, loss=TraceEnum_ELBO())
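The optimisation itself is then just the usual SVI loop; roughly like this (the number of steps is arbitrary):
svi_result = svi.run(jax.random.PRNGKey(1), 3000, data=data, K=k, dim=dim)
estimates = guide.median(svi_result.params)  # point estimates from the AutoDelta guide
print(estimates["locs"])            # estimated cluster means
print(estimates["cluster_proba"])   # estimated mixture weights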
This all works well and produces estimates that are close to the values I used to generate the data. Now, to modify this model to run on Bernoulli-distributed categorical data, the first thing I need is a multivariate Bernoulli distribution, since as far as I can tell NumPyro does not provide one:
from jax.scipy.special import xlog1py, xlogy
from numpyro.distributions import constraints
from numpyro.distributions.util import clamp_probs

class MultivariateBernoulli(dist.Distribution):
    support = constraints.real_vector

    def __init__(self, phi):
        super(MultivariateBernoulli, self).__init__(event_shape=(1,))
        self.phi = phi

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):
        # most of this code is pulled from numpyro.distributions.Bernoulli
        ps_clamped = clamp_probs(self.phi)
        return jnp.sum(
            xlogy(value, ps_clamped) + xlog1py(1 - value, -ps_clamped),  # assuming independence of the variables for now
            axis=1,
        )
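As a quick sanity check of log_prob on its own (made-up probabilities and shapes, just to confirm it returns one log-probability per row):
test_phi = jnp.array([[0.2, 0.5, 0.9]] * 4)   # per-variable probabilities, shape (4, 3)
test_x = jnp.array([[0., 1., 1.],
                    [1., 0., 1.],
                    [0., 0., 0.],
                    [1., 1., 1.]])            # 4 binary observations, shape (4, 3)
print(MultivariateBernoulli(test_phi).log_prob(test_x))  # expected shape: (4,)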
And the model, written in a similar way to the GMM above:
@config_enumerate
def discrete_mixture_model(K, X=None):
    N, D = X.shape
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(0.5 * jnp.ones(K)))
    with numpyro.plate('components', D):
        with numpyro.plate("cluster", K):
            # Note: this nested plate statement is needed because without it
            # I get a "plate definition needed" error
            phi = numpyro.sample('phi', dist.Beta(2.0, 2.0))
    with numpyro.plate('data', N):
        assignment = numpyro.sample('assignment', dist.CategoricalProbs(cluster_proba))
        numpyro.sample(
            'obs',
            MultivariateBernoulli(phi[assignment, :]),
            obs=X
        )
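The data I fit to is binary with shape (N, D). A sketch of how such a test set can be generated (the per-cluster probabilities here are made-up values), together with the trace-shape printout I use to see where the broadcasting goes wrong (numpyro.util.format_shapes is the standard shape-debugging helper):
# Synthetic Bernoulli data: two clusters with different per-variable probabilities
phi_real = np.array([[0.9, 0.8, 0.1],
                     [0.1, 0.2, 0.9]])      # per-cluster probabilities, shape (K, D)
z = np.random.choice(k, size=n, p=p_real)   # true cluster for each row
X = jnp.asarray(np.random.binomial(1, phi_real[z]).astype(np.float32))  # shape (n, D)

trace = numpyro.handlers.trace(
    numpyro.handlers.seed(discrete_mixture_model, jax.random.PRNGKey(0))
).get_trace(K=2, X=X)
print(numpyro.util.format_shapes(trace))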
I can't get this model to fit with either HMC or SVI as above. The error messages I get look like this:
<snip>
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\_src\numpy\util.py:425, in _broadcast_to(arr, shape)
422 nlead = len(shape) - len(arr_shape)
423 shape_tail = shape[nlead:]
424 compatible = all(core.definitely_equal_one_of_dim(arr_d, [1, shape_d])
--> 425 for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
426 if nlead < 0 or not compatible:
427 msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
ValueError: safe_zip() argument 2 is shorter than argument 1
Curiously, I can get the model to fit using DiscreteHMCGibbs if I comment out the @config_enumerate decorator. With the decorator in place, I instead get an error stating that there are no discrete latent variables in the model (which is odd, as there clearly are).
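For reference, the DiscreteHMCGibbs run that does fit looks roughly like this (with @config_enumerate commented out on discrete_mixture_model; NUTS as the inner kernel is just an example choice):
from numpyro.infer import DiscreteHMCGibbs, NUTS

gibbs_kernel = DiscreteHMCGibbs(NUTS(discrete_mixture_model))
gibbs_mcmc = MCMC(gibbs_kernel, num_warmup=500, num_samples=1000)
gibbs_mcmc.run(jax.random.PRNGKey(0), K=2, X=X)
gibbs_mcmc.print_summary()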
Can anyone help me get this model to fit with regular HMC and SVI, please?
Many thanks