I have the following hierarchical DPMM for continuous data. When I fit the MAP using my custom guide by enumerating out the labels in the model, it fits quite well. However, when I enumerate out the labels in the guide the model doesn’t fit properly - it tends to dump all of the observations in the same category regardless of the signal in the data.
Why is this the case? I realize that marginalizing out the categories explicitly in the model is different mathematically than marginalizing them out in the guide, but the ELBO gradient does not have any variance in either case so it seems like both should produce similar results.
I guess I’m also wondering generally what the ramifications are for guide-side vs. model-side enumeration in terms of the efficacy of fitting models in Pyro.
n_components = 10
n_obs, n_features = data.shape
def stickbreak(v):
cumprod_one_minus_v = torch.cumprod(1 - v, dim=-1)
v_one = pad(v, (0, 1), value=1)
one_c = pad(cumprod_one_minus_v, (1, 0), value=1)
return v_one * one_c
def gen_uniform_stick_weights(n_components):
weights = torch.zeros(n_components - 1)
weights[0] = 1/n_components
for i in range(1, n_components - 1):
stick_remainder = 1 - weights[i - 1]
weights[i] = (1/(1+n_components-i)) / stick_remainder
return weights
def model(data):
alpha = pyro.sample('alpha', dist.Gamma(1, 1))
with pyro.plate('mixture_weights', n_components - 1):
v = pyro.sample('v', dist.Beta(1, alpha))
with pyro.plate('features', n_features):
with pyro.plate(f'components', n_components):
mu = pyro.sample('mu', dist.Normal(0, 1))
sigma_sq = pyro.sample('sigma_sq', dist.InverseGamma(1, 1))
with pyro.plate("data", n_obs):
label = pyro.sample(f"cat", dist.Categorical(stickbreak(v)))
for i in pyro.plate('features_', n_features):
pyro.sample(f'obs_{i}', dist.Normal(mu[label, i], sigma_sq[label, i]), obs=data[:, i])
def guide(data):
alpha_loc = pyro.param("alpha_loc", torch.tensor(1.), constraint=constraints.positive)
alpha = pyro.sample("alpha", dist.Delta(alpha_loc))
with pyro.plate('mixture_weights', n_components - 1):
v_loc = pyro.param("v_loc", gen_uniform_stick_weights(n_components), constraint=constraints.unit_interval)
v = pyro.sample("v", dist.Delta(v_loc))
with pyro.plate('features', n_features):
with pyro.plate('components', n_components):
mu_loc = pyro.param("mu_loc", dist.Normal(0, 1).sample([n_components, n_features]))
sigma_sq_loc = pyro.param("sigma_sq_loc", dist.InverseGamma(1, 1).sample([n_components, n_features]), constraint=constraints.positive)
mu = pyro.sample('mu', dist.Delta(mu_loc))
sigma_sq = pyro.sample('sigma_sq', dist.Delta(sigma_sq_loc))
with pyro.plate("data", n_obs):
cat = pyro.sample('cat', dist.Categorical(stickbreak(v)), infer={'enumerate': 'parallel'})