# Questions about the "Dirichlet Process Mixture Models in Pyro" in the examples

Hi everyone
I have read the "Dirichlet Process Mixture Models in Pyro" section and I am a little confused about the first inference example. Below are the code and my questions.

``````def model(data):
with pyro.plate("beta_plate", T-1):
beta = pyro.sample("beta", Beta(1, alpha))

with pyro.plate("mu_plate", T):
mu = pyro.sample("mu", MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))

with pyro.plate("data", N):
z = pyro.sample("z", Categorical(mix_weights(beta)))
pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(2)), obs=data)
def guide(data):
    """Mean-field variational guide matching `model`.

    Learnable parameters:
      kappa — positive second parameters of the q(beta) Beta factors (T-1 of them)
      tau   — variational means of the T cluster locations
      phi   — per-datapoint assignment probabilities over the T components
    Relies on globals T and N.
    """
    # Initialize each variational parameter with a random draw (lambda defers
    # sampling until the param is first created in the param store).
    kappa = pyro.param('kappa', lambda: Uniform(0, 2).sample([T-1]), constraint=constraints.positive)
    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(2), 3 * torch.eye(2)).sample([T]))
    phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T)).sample([N]), constraint=constraints.simplex)

    # q(beta): Beta(1, kappa) factors, one per stick segment.
    with pyro.plate("beta_plate", T-1):
        q_beta = pyro.sample("beta", Beta(torch.ones(T-1), kappa))

    # q(mu): Gaussian centered at the learned tau, unit covariance.
    with pyro.plate("mu_plate", T):
        q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(2)))

    # q(z): categorical over components with per-point probabilities phi.
    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(phi))

# Truncation level: the DP mixture is approximated with at most T components.
T = 6
optim = Adam({"lr": 0.05})
# Stochastic variational inference pairing `model` and `guide` under the ELBO.
svi = SVI(model, guide, optim, loss=Trace_ELBO())
losses = []

def train(num_iterations):
    """Run `num_iterations` SVI steps on the global `data`, recording the
    per-step ELBO loss in the global `losses` list.

    Clears the Pyro param store first so each call trains from scratch
    (needed because the script retrains with a different alpha below).
    """
    pyro.clear_param_store()
    for _ in tqdm(range(num_iterations)):  # loop index unused
        loss = svi.step(data)
        losses.append(loss)

def truncate(alpha, centers, weights):
threshold = alpha**-1 / 100.
true_centers = centers[weights > threshold]
true_weights = weights[weights > threshold] / torch.sum(weights[weights > threshold])
return true_centers, true_weights

# First experiment: small concentration parameter (few expected clusters).
alpha = 0.1
train(1000)

# We make a point-estimate of our model parameters using the posterior means
# of tau (cluster centers) and phi averaged over data points (cluster weights),
# then prune negligible components with `truncate`.
Bayes_Centers_01, Bayes_Weights_01 = truncate(alpha, pyro.param("tau").detach(), torch.mean(pyro.param("phi").detach(), dim=0))

# Second experiment: larger concentration parameter (more expected clusters).
# `train` clears the param store, so this run starts from scratch.
alpha = 1.5
train(1000)

# We make a point-estimate of our model parameters using the posterior means of tau and phi for the centers and weights
Bayes_Centers_15, Bayes_Weights_15 = truncate(alpha, pyro.param("tau").detach(), torch.mean(pyro.param("phi").detach(), dim=0))

# Side-by-side comparison of the inferred cluster centers (red) over the
# data (blue) for the two concentration settings.
plt.figure(figsize=(15, 5))
# Left panel: alpha = 0.1.
plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], color="blue")
plt.scatter(Bayes_Centers_01[:, 0], Bayes_Centers_01[:, 1], color="red")

# Right panel: alpha = 1.5.
plt.subplot(1, 2, 2)
plt.scatter(data[:, 0], data[:, 1], color="blue")
plt.scatter(Bayes_Centers_15[:, 0], Bayes_Centers_15[:, 1], color="red")
plt.tight_layout()
plt.show()
``````

The questions are

1. If I understand correctly, the base distribution here is MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)); the elements sampled from the base distribution that constitute G are the mu values, and the corresponding weights come from beta. However, tau is taken as the mean when plotting the cluster centers. I wonder whether it is correct to use mu as the cluster centers, since the data are sampled from the distribution parameterized by mu.
2. In the guide function, it seems we generate a list of tau values, which causes q_mu to be sampled from different base distributions. Should it instead be a single vector, e.g. a tensor of shape (1, 2), so that all q_mu are sampled from the same base distribution in each iteration?

I would greatly appreciate your help!