Mixture model with subsampling predicting membership only for subsampled points

I’ve built a classifier from a mixture model, following the GMM tutorial, for a model that uses subsampling and masks missing data. Inference runs and produces output, but the classifier only predicts cluster membership for subsample_size data points rather than for the full dataset. Does anyone have an idea what could be driving this?

I am using an AutoDelta guide, seeded by trying a number of initializations as per the tutorials. Below are my model, the classifier, and the call to the classifier after running SVI. These all mirror the GMM tutorial, extended to matrix rather than vector data and allowing missingness and subsampling.

Model

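# Imports assumed for the snippets in this post (as in the GMM tutorial setup).
import jax.numpy as jnp
from jax.random import PRNGKey
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.distributions import constraints
from numpyro.infer import init_to_value
from numpyro.infer.autoguide import AutoDelta
from numpyro.ops.indexing import Vindex

@config_enumerate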
def betaMixtureModel(
    data, 
    num_clusters, 
    data_mask=None, 
    subsample_size=None,
):
    batch_mask = None
    data_mask_passed = data_mask is not None
    N, P = data.shape
    K = num_clusters

    measurement_plate = numpyro.plate("measurements", P, dim=-1)
    component_plate = numpyro.plate("components", K, dim=-2)
    data_plate = numpyro.plate("data", size=N, dim=-2, subsample_size=subsample_size)

    # Global variables.
    concentration = numpyro.param(
        "concentration", 
        1e0, 
        constraint=constraints.greater_than(1e-6),  # keyword is `constraint`, not `constraints`
    )
    weights = numpyro.sample(
        "weights", 
        dist.Dirichlet(concentration * jnp.ones(K) / K),
    )

    # Component plates
    with component_plate as k:
        with measurement_plate as p:
            locs = numpyro.sample(
                "locs", 
                dist.Beta(1., 1.),
            )
            precisions = numpyro.sample(
                "precisions", 
                dist.Gamma(1., 1.),
            )

    with data_plate as ind:
        batch = data[ind]
        if data_mask_passed:
            batch_mask = data_mask[ind]

        # Local variables.
        assignment = numpyro.sample(
            "assignment", 
            dist.Categorical(weights),
            infer={"enumerate": "parallel"},
        )

        with measurement_plate as p:
            numpyro.sample(
                "obs", 
                dist.BetaProportion(
                    Vindex(locs)[..., assignment, p], 
                    Vindex(precisions)[..., assignment, p]
                ),
                obs=batch,
                obs_mask=batch_mask,
            )

Classifier function

Function as per the GMM tutorial with arguments updated to match my model.


def classifier(
        trained_model,
        data, 
        num_clusters, 
        data_mask=None, 
        subsample_size=None,
        temperature=0, 
        rng_key=PRNGKey(1024),
        first_available_dim=-3
    ):
    inferred_model = infer_discrete(
        trained_model, 
        temperature=temperature, 
        first_available_dim=first_available_dim, 
        rng_key=rng_key
    )  # set first_available_dim to avoid conflict with data plate
    seeded_inferred_model = handlers.seed(inferred_model, rng_key)
    trace = handlers.trace(seeded_inferred_model).get_trace(data, num_clusters, data_mask, subsample_size)
    return trace["assignment"]["value"]

Calling of Classifier

trained_global_guide = handlers.substitute(
    handlers.seed(global_guide, PRNGKey(seed)),
    global_svi_result.params
)  # substitute trained params
guide_trace = handlers.trace(trained_global_guide).get_trace(
    data, num_clusters, data_mask
)  # record the globals

trained_model = handlers.replay(betaMixtureModel, trace=guide_trace)  # replay the globals
temp_0_assignments = classifier(trained_model, data, num_clusters, data_mask, None, rng_key=PRNGKey(0))

You might want to provide create_plates for AutoDelta.

Thank you for the suggestion!

After adopting this, I hit an error saying that not all of the created plates are instances of numpyro.plate. Based on print statements, create_plates is called several times within the elbo.loss call (below), and the third call fails: none of the returned objects are numpyro.plate instances, although in the earlier calls they are. Do you have any idea what could be driving this?

Elbo loss call

elbo.loss(
        PRNGKey(0), 
        {}, 
        betaMixtureModel, 
        global_guide, 
        data, 
        num_clusters, 
        data_mask, 
        subsample_size
    )

Create plates function

def create_plates(
        data, 
        num_clusters, 
        data_mask=None, 
        subsample_size=None,
    ):
    """
    Create plates for Auto Guide; required for subsampling
    """
    N, P = data.shape
    K = num_clusters
    measurement_plate = numpyro.plate("measurements", P, dim=-1)
    component_plate = numpyro.plate("components", K, dim=-2)
    data_plate = numpyro.plate("data", size=N, dim=-2, subsample_size=subsample_size)
    
    plates = [measurement_plate, component_plate, data_plate]
    return plates

Guide call

global_guide = AutoDelta(
        global_model, 
        init_loc_fn=init_to_value(values=init_values),
        create_plates=create_plates,
        # init_scale=0.01
    )

I don’t know what the error is, so I’m not sure what’s going on. We probably don’t support the combination of subsample + enumerate + autoguide.

For debugging, I would suggest starting with a simpler model, e.g. one with only a single subsample plate and one discrete latent variable, and adding complexity later.

My apologies, for clarity and any future reference the error is:

AssertionError: create_plates() returned a non-plate

and occurs on the third call to create_plates() within the elbo.loss function (based on print statements), passing the first two times.

Thanks for your help - I’ll move to a manual guide in that case.

Good to know that a custom guide might work for your model. If you hit the same issue, please post a small reproducible example or share more information about the error (e.g. since you are already printing things out, what those plates are when the error happens, or more detail about the “third call”).
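
A minimal sketch of how that information could be collected, assuming create_plates is the function passed to the guide: a small wrapper (the name log_plate_types is hypothetical) that counts calls and prints the type of every returned object before handing the result back unchanged. It uses only the standard library, so it does not alter how the plates themselves are built.

```python
import functools


def log_plate_types(create_plates_fn):
    """Wrap a create_plates function to report what each call returns."""
    calls = {"count": 0}

    @functools.wraps(create_plates_fn)
    def wrapper(*args, **kwargs):
        calls["count"] += 1
        plates = create_plates_fn(*args, **kwargs)
        # create_plates may return a single plate or an iterable of plates.
        items = plates if isinstance(plates, (list, tuple)) else [plates]
        for p in items:
            cls = type(p)
            print(
                f"create_plates call {calls['count']}: "
                f"{cls.__module__}.{cls.__qualname__} "
                f"name={getattr(p, 'name', None)!r}"
            )
        return plates  # returned unchanged

    wrapper.calls = calls  # exposed so the call count can be inspected afterwards
    return wrapper


# Intended usage with the guide from this thread (not executed here):
# global_guide = AutoDelta(
#     global_model,
#     init_loc_fn=init_to_value(values=init_values),
#     create_plates=log_plate_types(create_plates),
# )
```

The printed module and class names for the failing call would show what type the returned objects have instead of numpyro.plate, which is the detail asked for above.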