Well, something like this kind of works. Note, however, that this is a terrible way to attack the problem: discrete latent variables and black-box variational inference generally don't get along very well, which is why enumeration is preferred whenever possible (see the enumeration sketch after the guide below).
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

K = 2  # number of mixture components; the [0.0, 10.0] init below assumes K == 2

def guide(data):
    # variational posterior over the mixture weights
    dirichlet_param = pyro.param("dirichlet_param", torch.ones(K) / K,
                                 constraint=constraints.simplex)
    weights = pyro.sample('weights', dist.Dirichlet(dirichlet_param))
    # variational posterior over the shared observation scale
    scale_para_loc = pyro.param('scale_para_loc', torch.tensor(0.))
    scale_para_scale = pyro.param('scale_para_scale', torch.tensor(0.001),
                                  constraint=constraints.positive)
    scale = pyro.sample('scale', dist.LogNormal(scale_para_loc, scale_para_scale))
    # variational posterior over the per-component locations
    locs_para_loc = pyro.param('locs_para_loc', torch.tensor([0.0, 10.0]))
    locs_para_scale = pyro.param('locs_para_scale', 0.001 * torch.ones(K),
                                 constraint=constraints.positive)
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(locs_para_loc, locs_para_scale))
    # one categorical assignment per data point, parameterized by free logits
    with pyro.plate('data', len(data)):
        logits = pyro.param("logits", 0.1 * torch.randn(len(data), K))
        assignment = pyro.sample('assignment', dist.Categorical(logits=logits))
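For comparison, here is a minimal, self-contained sketch of the enumeration route; the model body, hyperparameters, AutoNormal guide, and toy data are illustrative stand-ins rather than your actual setup. Decorating the model with @config_enumerate and training with TraceEnum_ELBO sums 'assignment' out exactly, so the guide only covers the continuous sites and no high-variance score-function gradients are needed for the discrete variable:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

K = 2  # number of mixture components (illustrative)

@config_enumerate
def model(data):
    # globals: mixture weights, shared scale, per-component locations
    weights = pyro.sample('weights', dist.Dirichlet(torch.ones(K) / K))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(0., 10.))
    with pyro.plate('data', len(data)):
        # enumerated in parallel by TraceEnum_ELBO, never sampled by the guide
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

# guide over the continuous latents only; 'assignment' is marginalized out
auto_guide = AutoNormal(poutine.block(model, expose=['weights', 'scale', 'locs']))

svi = SVI(model, auto_guide, Adam({'lr': 0.01}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))
data = torch.tensor([0., 1., 10., 11., 12.])  # toy data, purely illustrative
for step in range(200):
    svi.step(data)

If you actually need per-datapoint assignments afterwards, pyro.infer.infer_discrete can recover them from the trained model and guide.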