I’ve built a classifier from a mixture model, following the GMM tutorial, for a model that uses subsampling and data masking to handle missing data. Inference runs and produces output, but the classifier only predicts cluster membership for `subsample_size` data points rather than for the full dataset. Does anyone know what could be causing this?
I am using an AutoDelta guide, seeded by trying a number of initializations as in the tutorials. Below are my model, the classifier, and the call to the classifier after running SVI. All of these follow the GMM tutorial, extended to matrix (rather than vector) data and adding support for missing data and subsampling.
Model
@config_enumerate
def betaMixtureModel(
    data,
    num_clusters,
    data_mask=None,
    subsample_size=None,
):
    """Beta mixture model over a 2-D data matrix, with optional masking and subsampling.

    Each row of ``data`` is assigned to one of ``num_clusters`` components.
    Every component has its own per-column mean (``locs``) and precision
    (``precisions``), and the observations are scored with ``BetaProportion``.

    Args:
        data: 2-D array of shape ``(N, P)``: rows are data points, columns are
            measurements. NOTE(review): BetaProportion presumably requires the
            observed values to lie in (0, 1); confirm upstream.
        num_clusters: number of mixture components ``K``.
        data_mask: optional mask, forwarded as ``obs_mask``. Presumably
            same shape as ``data``, with True marking observed entries (the
            NumPyro ``obs_mask`` convention); verify against the caller.
        subsample_size: when not None, the "data" plate draws only this many
            row indices per call. When None, all ``N`` rows are used.
    """
    batch_mask = None  # stays None (no obs_mask) unless a mask is supplied
    data_mask_passed = data_mask is not None
    N, P = data.shape
    K = num_clusters
    # Plate layout: measurements on dim -1; components and data both on dim -2.
    # Components and data are never nested together, so sharing dim -2 is fine.
    measurement_plate = numpyro.plate("measurements", P, dim=-1)
    component_plate = numpyro.plate("components", K, dim=-2)
    # NOTE: the plate's subsample indices become a "plate" site in any trace.
    # Handlers such as replay can override them from another trace.
    data_plate = numpyro.plate("data", size=N, dim=-2, subsample_size=subsample_size)
    # Global variables.
    # Learnable Dirichlet concentration, constrained to stay above 1e-6.
    concentration = numpyro.param(
        "concentration",
        1e0,
        constraints=constraints.greater_than(1e-6)
    )
    # Mixture weights over the K components (symmetric Dirichlet prior).
    weights = numpyro.sample(
        "weights",
        dist.Dirichlet(concentration * jnp.ones(K) / K),
    )
    # Component plates
    # locs and precisions get batch shape (K, P): one value per component and measurement.
    with component_plate as k:
        with measurement_plate as p:
            locs = numpyro.sample(
                "locs",
                dist.Beta(1., 1.),
            )
            precisions = numpyro.sample(
                "precisions",
                dist.Gamma(1., 1.),
            )
    # `ind` holds the row indices selected by the data plate (all rows when not subsampled).
    with data_plate as ind:
        batch = data[ind]
        if data_mask_passed:
            batch_mask = data_mask[ind]
        # Local variables.
        # Discrete per-row component assignment, marginalised by parallel
        # enumeration (already implied by @config_enumerate).
        assignment = numpyro.sample(
            "assignment",
            dist.Categorical(weights),
            infer={"enumerate": "parallel"},
        )
        with measurement_plate as p:
            # Vindex gathers locs/precisions at [assignment, p]: for each row,
            # the assigned component's parameter in every column. It also
            # broadcasts over any extra enumeration dims added to assignment.
            numpyro.sample(
                "obs",
                dist.BetaProportion(
                    Vindex(locs)[..., assignment, p],
                    Vindex(precisions)[..., assignment, p]
                ),
                obs=batch,
                obs_mask=batch_mask,
            )
Classifier function
The function follows the GMM tutorial, with its arguments updated to match my model.
def classifier(
    trained_model,
    data,
    num_clusters,
    data_mask=None,
    subsample_size=None,
    temperature=0,
    rng_key=PRNGKey(1024),
    first_available_dim=-3
):
    """Infer the discrete "assignment" site of a trained mixture model.

    Args:
        trained_model: model callable whose global latents have been fixed,
            e.g. by replaying a trained guide.
        data, num_clusters, data_mask, subsample_size: forwarded positionally
            to ``trained_model``.
        temperature: forwarded to ``infer_discrete``. Per the NumPyro docs,
            0 gives MAP assignments and 1 samples them from the posterior.
        rng_key: JAX PRNG key, used both by ``infer_discrete`` and to seed
            the model.
        first_available_dim: leftmost dim free for enumeration. It must lie
            left of the plates used by the model.

    Returns:
        The value of the "assignment" site. It has one entry per row that the
        model's "data" plate visits during this call.
    """
    # Enumeration must not collide with the data plate, hence first_available_dim.
    discrete_model = infer_discrete(
        trained_model,
        temperature=temperature,
        first_available_dim=first_available_dim,
        rng_key=rng_key,
    )
    seeded_model = handlers.seed(discrete_model, rng_key)
    model_args = (data, num_clusters, data_mask, subsample_size)
    model_trace = handlers.trace(seeded_model).get_trace(*model_args)
    assignment_site = model_trace["assignment"]
    return assignment_site["value"]
Calling the classifier
# Fix the trained guide: seed its randomness and substitute the learned params.
trained_global_guide = handlers.substitute(
    handlers.seed(global_guide, PRNGKey(seed)),
    global_svi_result.params,
)  # substitute trained params
guide_trace = (
    handlers.trace(trained_global_guide)
    .get_trace(data, num_clusters, data_mask)
)  # record the globals
# The autoguide rebuilds the model's plates at their training-time sizes, so
# guide_trace also contains a "data" *plate* site. Its value is a random
# subsample of `subsample_size` row indices. `handlers.replay` substitutes
# plate sites as well as sample sites, so replaying the full guide trace would
# force the model's data plate onto those indices even with
# subsample_size=None. That is why only `subsample_size` rows got classified.
# Keep only global sample sites: drop plate sites and any site inside "data".
global_trace = {
    name: site
    for name, site in guide_trace.items()
    if site["type"] == "sample"
    and all(frame.name != "data" for frame in site["cond_indep_stack"])
}
trained_model = handlers.replay(betaMixtureModel, trace=global_trace)  # replay the globals only
# subsample_size=None -> the data plate now spans every row of `data`.
temp_0_assignments = classifier(
    trained_model, data, num_clusters, data_mask, None, rng_key=PRNGKey(0)
)