DPMM guide-side enumeration doesn't work?

I have the following hierarchical DPMM for continuous data. When I fit a MAP estimate using my custom guide and enumerate out the labels in the model, it fits quite well. However, when I enumerate out the labels in the guide, the model doesn’t fit properly: it tends to dump all of the observations into the same category regardless of the signal in the data.

Why is this the case? I realize that marginalizing out the categories explicitly in the model is mathematically different from marginalizing them out in the guide, but enumeration removes the sampling variance of the discrete sites from the ELBO gradient in either case, so it seems like both should produce similar results.

I guess I’m also wondering, more generally, what the ramifications of guide-side vs. model-side enumeration are for how effectively models can be fit in Pyro.

import torch
from torch.nn.functional import pad

import pyro
import pyro.distributions as dist
from torch.distributions import constraints

n_components = 10
n_obs, n_features = data.shape  # data is an (n_obs, n_features) float tensor

def stickbreak(v):
    cumprod_one_minus_v = torch.cumprod(1 - v, dim=-1)
    v_one = pad(v, (0, 1), value=1)
    one_c = pad(cumprod_one_minus_v, (1, 0), value=1)
    return v_one * one_c
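
# e.g. (hypothetical values) stickbreak(torch.tensor([0.5, 0.5]))
# -> tensor([0.5000, 0.2500, 0.2500]): shape (n_components,), sums to 1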

def gen_uniform_stick_weights(n_components):
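    # stick proportions v_i = 1 / (n_components - i), chosen so that
    # stickbreak(v) yields uniform mixture weights of 1 / n_components each;
    # e.g. gen_uniform_stick_weights(4) -> tensor([0.2500, 0.3333, 0.5000])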
    weights = torch.zeros(n_components - 1)
    weights[0] = 1/n_components

    for i in range(1, n_components - 1):
        stick_remainder = 1 - weights[i - 1]
        weights[i] = (1/(1+n_components-i)) / stick_remainder 

    return weights

def model(data):
    alpha = pyro.sample('alpha', dist.Gamma(1, 1))
    
    with pyro.plate('mixture_weights', n_components - 1):
        v = pyro.sample('v', dist.Beta(1, alpha))

    with pyro.plate('features', n_features):
        with pyro.plate('components', n_components):
            mu = pyro.sample('mu', dist.Normal(0, 1))
            sigma_sq = pyro.sample('sigma_sq', dist.InverseGamma(1, 1)) 

    with pyro.plate("data", n_obs):
        label = pyro.sample("cat", dist.Categorical(stickbreak(v)))

        for i in pyro.plate('features_', n_features):
            # Normal takes a scale (stddev), so convert the variance to a stddev
            pyro.sample(f'obs_{i}', dist.Normal(mu[label, i], sigma_sq[label, i].sqrt()), obs=data[:, i])

def guide(data):
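    # an all-Delta guide, i.e. a MAP / point-estimate guide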
    alpha_loc = pyro.param("alpha_loc", torch.tensor(1.), constraint=constraints.positive)
    alpha = pyro.sample("alpha", dist.Delta(alpha_loc))

    with pyro.plate('mixture_weights', n_components - 1):
        v_loc = pyro.param("v_loc", gen_uniform_stick_weights(n_components), constraint=constraints.unit_interval)
        v = pyro.sample("v", dist.Delta(v_loc))

    with pyro.plate('features', n_features):         
        with pyro.plate('components', n_components):
            mu_loc = pyro.param("mu_loc", dist.Normal(0, 1).sample([n_components, n_features]))
            sigma_sq_loc = pyro.param("sigma_sq_loc", dist.InverseGamma(1, 1).sample([n_components, n_features]), constraint=constraints.positive)

            mu = pyro.sample('mu', dist.Delta(mu_loc))
            sigma_sq = pyro.sample('sigma_sq', dist.Delta(sigma_sq_loc))

    with pyro.plate("data", n_obs):
        cat = pyro.sample('cat', dist.Categorical(stickbreak(v)), infer={'enumerate': 'parallel'})

it’s hard to say in general. guide-side enumeration involves an additional variational relaxation. how exactly that impacts the optimization problem is probably pretty problem-specific. model-side enumeration is probably easier to get working in most cases since it requires fewer guide parameters. guide-side enumeration, like all elbo optimization problems, may work poorly if variational parameters are poorly initialized. one potential upside of guide-side enumeration is that it opens the door to learning amortized guides.
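
for concreteness, here is a minimal sketch of the model-side setup for the model above (assuming a standard SVI loop; the learning rate and step count are illustrative, and the guide would omit the cat site entirely):

from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

# mark the model's discrete 'cat' site for parallel enumeration
enum_model = config_enumerate(model, "parallel")

# max_plate_nesting=2 covers the nested components/features plates
svi = SVI(enum_model, guide, Adam({"lr": 0.01}),
          loss=TraceEnum_ELBO(max_plate_nesting=2))
for step in range(1000):
    svi.step(data)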


also, in your code, what is the shape of stickbreak(v)? it looks like you’re not learning individual categorical probabilities for each data point

That’s correct. stickbreak(v) takes the stick proportions to sequentially break off and returns a tensor of length n_components containing the mixture weights of the clusters (the length of the last stick is always determined by the length of the second-to-last stick). If I replace the data plate in the guide with something like this to learn the categorical probs individually,

    with pyro.plate("data", n_obs):
        assignment_probs = pyro.param('assignment_probs', 
                                       torch.ones(len(data), n_components) / n_components,
                                       constraint=constraints.unit_interval)        
        cat = pyro.sample('cat', dist.Categorical(assignment_probs), infer={'enumerate': 'parallel'})

the same phenomenon occurs when I increase the number of components in the model beyond the “true” number of components.

I went down this rabbit hole while working on a simpler version of CrossCat, in which the “view assignments” (one of the latent discrete parameters) cannot be enumerated in the model. I was hoping to learn something about how guide-side enumeration affects inference by working with a simpler case, but since the behavior is not generalizable, I’ll open up a separate thread to keep this on topic 🙂.

i forgot how crosscat works in detail, but generally the posterior (and therefore the guide) has fewer statistical independencies than the model, so e.g. you would expect the marginal posterior over cat to depend on the data point in question
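
e.g. a minimal sketch of an amortized guide along those lines (the encoder network, its hidden size, and the amortized_guide name are all hypothetical, and the continuous guide sites are elided):

import torch.nn as nn

# hypothetical encoder mapping each data point to its own categorical logits
encoder = nn.Sequential(nn.Linear(n_features, 32), nn.ReLU(),
                        nn.Linear(32, n_components))

def amortized_guide(data):
    pyro.module("encoder", encoder)  # register the network's parameters with pyro
    # ... alpha, v, mu, sigma_sq sites as in the original guide ...
    with pyro.plate("data", n_obs):
        # assignment probabilities now depend on the data point in question
        pyro.sample("cat", dist.Categorical(logits=encoder(data)),
                    infer={"enumerate": "parallel"})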
