I am new to probabilistic programming, so I may be running into an obvious problem here.
I am studying LDA with the code below and want to get the 'document-topic' distribution `theta_d`,
which is sampled in the following model and guide:
@pyro.poutine.broadcast
def model(data):
    """Generative LDA model: topics, per-document mixtures, per-word topics.

    Relies on the module-level names ``K``, ``V``, ``D`` and ``N``, which are
    not defined in this snippet (presumably the number of topics, vocabulary
    size, number of documents and per-document word counts -- confirm).

    Args:
        data: Indexable by document number; ``data[d]`` is passed as the
            observation for the words of document ``d`` (assumed to hold
            ``N[d]`` word indices -- TODO confirm).
    """
    # torch.ones([K, V]) gives one Dirichlet row per topic; to_event(1) folds
    # the K rows into a single joint sample site named "phi".
    phi = pyro.sample("phi", dist.Dirichlet(torch.ones([K, V])).to_event(1))
    for d in pyro.plate("documents", D):
        # Per-document topic proportions with a flat Dirichlet prior.
        theta_d = pyro.sample("theta_%d" % d, dist.Dirichlet(torch.ones([K])))
        with pyro.plate("words_%d" % d, N[d]):
            # Topic assignment for every word slot of document d.
            z = pyro.sample("z_%d" % d, dist.Categorical(theta_d))
            # Each observed word is drawn from the row of phi picked by z.
            pyro.sample("w_%d" % d, dist.Categorical(phi[z]), obs=data[d])
@pyro.poutine.broadcast
@config_enumerate(default='parallel')
def guide(data):
    """Variational guide for the LDA model, with one sample site per model site.

    Registers the learnable parameters ``beta_q`` (topic-word concentrations),
    ``alpha_q_<d>`` (document-topic concentrations for document ``d``) and
    ``q_<d>`` (per-word topic probabilities for document ``d``).

    Relies on the module-level names ``K``, ``V``, ``D`` and ``N``, which are
    not defined in this snippet.

    Args:
        data: Unused here; kept so the guide has the same signature as the
            model.
    """
    beta_q = pyro.param(
        "beta_q", torch.ones([K, V]), constraint=constraints.positive
    )
    pyro.sample("phi", dist.Dirichlet(beta_q).to_event(1))
    for d in pyro.plate("documents", D):
        alpha_q = pyro.param(
            "alpha_q_%d" % d, torch.ones([K]), constraint=constraints.positive
        )
        pyro.sample("theta_%d" % d, dist.Dirichlet(alpha_q))
        with pyro.plate("words_%d" % d, N[d]):
            # The initial value of a simplex-constrained parameter must itself
            # lie on the simplex, so normalise each row to sum to one instead
            # of passing the raw positive values.
            q_init = torch.randn([N[d], K]).exp()
            q_init = q_init / q_init.sum(-1, keepdim=True)
            q_i = pyro.param(
                "q_%d" % d, q_init, constraint=constraints.simplex
            )
            pyro.sample("z_%d" % d, dist.Categorical(q_i))
It seems that, after training with
# Train the guide with SVI and plot the per-step loss returned by svi.step.
num_steps = 3000  # single source of truth for the loop and the plot x-axis

adam_params = {"lr": 0.01, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# Drop parameters left over from earlier runs so training starts from the
# initial values declared in the guide.
pyro.clear_param_store()
svi = SVI(model, guide, optimizer, loss=TraceEnum_ELBO(max_iarange_nesting=1))

losses = []
for _ in range(num_steps):
    loss = svi.step(data)
    losses.append(loss)

plt.plot(list(range(num_steps)), losses)
plt.title('ELBO')
plt.xlabel('step')
plt.show()
we can get `alpha_q`
with
# Look up the learned concentration that the guide registers under the name
# "alpha_q_<d>"; `d` must already be bound to a document index here.
param_name = "alpha_q_%d" % d
alpha_q = pyro.param(param_name)
But how do I get each `theta`,
which represents the document-topic distribution?
Do I need to do something like this?
# For every document, draw one theta from the learned Dirichlet posterior and
# report the index of its largest component.
for i in range(D):
    site_name = "theta_%d" % i
    alpha_q_i = pyro.param("alpha_q_%d" % i)
    q_theta_d = pyro.sample(site_name, dist.Dirichlet(alpha_q_i))
    # detach() is the supported way to drop the autograd graph before
    # converting to NumPy (preferred over the legacy `.data` attribute).
    print("[%s]: %s" % (site_name, q_theta_d.detach().numpy()))
    pos = torch.argmax(q_theta_d)
    print(pos)
I also have one more question: can we do minibatch training instead of using the whole dataset,
and how would I implement it?
If I pass `subsample_size`
to the plate in `for d in pyro.plate("documents", D):`
will it work?
Thanks for replying!