In the tutorial Dirichlet Process Mixture Models in Pyro, the model draws \beta_i \sim Beta(1, \alpha). It then uses the “stick-breaking” function mix_weights to get the probabilities of the Categorical distribution for the latent variable z_i. However, in the guide the link between \beta and z is lost, and z is optimized independently. As a result, it seems that \beta has no influence on the posterior and can be omitted from both the model and the guide.
Even the truncation function truncate does not take the \beta parameters into account.
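For reference, here is a minimal plain-Python sketch of the stick-breaking step, as I understand it: T-1 fractions \beta_k are turned into T weights via \pi_k = \beta_k \prod_{j<k} (1 - \beta_j), with the last weight taking whatever is left of the stick. The function name stick_breaking_weights is my own, for illustration only; it is not the tutorial's mix_weights, which operates on tensors.

```python
def stick_breaking_weights(betas):
    """Map T-1 stick fractions in (0, 1) to T mixture weights that sum to 1."""
    weights = []
    remaining = 1.0  # length of the stick not yet assigned to a component
    for b in betas:
        weights.append(b * remaining)  # break off a fraction b of what is left
        remaining *= 1.0 - b
    weights.append(remaining)  # the last component gets the leftover piece
    return weights


weights = stick_breaking_weights([0.5, 0.5, 0.5])
print(weights)  # [0.5, 0.25, 0.125, 0.125]
```

This is the only place where \beta enters the model: it fixes the prior probabilities that the Categorical distribution over z_i uses.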
The model can be simplified to:
def model(data):
    # with pyro.plate("beta_plate", T-1):
    #     beta = pyro.sample("beta", Beta(1, torch.tensor(.1).to(cuda_device)))
    with pyro.plate("mu_plate", T):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(2).cuda(), 5 * torch.eye(2).cuda()))
    with pyro.plate("data", N):
        # z = pyro.sample("z", Categorical(mix_weights(beta)))
        z = pyro.sample("z", Categorical(mix_weights(Beta(1, torch.tensor(.1)).sample([T-1]).to(cuda_device))))
        pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(2).cuda()), obs=data)
The guide becomes:
def guide(data):
    # kappa = pyro.param('kappa', lambda: Uniform(0, 2).sample([T-1]).cuda(), constraint=constraints.positive)
    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(2).to(cuda_device), 3 * torch.eye(2).to(cuda_device)).sample([T]))
    phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T).to(cuda_device)).sample([N]), constraint=constraints.simplex)
    # with pyro.plate("beta_plate", T-1):
    #     q_beta = pyro.sample("beta", Beta(torch.ones(T-1).to(cuda_device), kappa))
    with pyro.plate("mu_plate", T):
        q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(2).to(cuda_device)))
    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(phi))
Is my line of thought correct?