Recently I learned how to use MCMC to infer the input parameters of a neural network. After running the code I noticed that MCMC takes a very long time to finish: 1-3 hours (or even more) for 40 warmup steps and 100 samples, depending on the kernel parameters. This is unacceptable to me, especially since the neural network model I'm using can process thousands of images per minute. I then learned that this is probably because MCMC is estimating a distribution over all of my sample sites, which have very large dimensions (see the shapes in the code below). However, I am only interested in the posterior distribution of one variable, v_q, after conditioning the model's output on the observation x_q (an image). So I wonder: is there a way to tell MCMC that I'm only interested in the posterior of v_q, so that it somehow ignores the other variables?
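(To make precise what "ignoring" the z sites would have to mean: they can't just be dropped, they have to be marginalized out of the likelihood of v_q. One classical way to do that is pseudo-marginal MCMC, where the marginal likelihood is replaced by an unbiased Monte Carlo estimate averaged over draws of the nuisance latents. Here is a minimal toy sketch in plain Python; the model in it, a scalar v with one nuisance latent z and observation x, is invented for illustration and is not the GQN-style model below.)

```python
import math
import random

random.seed(0)

def log_normal(x, mu, var):
    # log density of N(mu, var)
    return -0.5 * math.log(2 * math.pi * var) - (x - mu) ** 2 / (2 * var)

def est_log_lik(v, x_obs, n_inner=64):
    # Unbiased Monte Carlo estimate of p(x_obs | v) with the nuisance
    # latent z ~ N(0, 1) averaged out: p(x | v) = E_z[ N(x; v + z, 1) ]
    w = [math.exp(log_normal(x_obs, v + random.gauss(0.0, 1.0), 1.0))
         for _ in range(n_inner)]
    return math.log(sum(w) / n_inner)

x_obs = 3.0
v = 0.0
log_p = log_normal(v, 0.0, 1.0) + est_log_lik(v, x_obs)  # prior N(0, 1) on v
samples = []
for _ in range(5000):
    v_prop = v + random.gauss(0.0, 0.8)
    log_p_prop = log_normal(v_prop, 0.0, 1.0) + est_log_lik(v_prop, x_obs)
    # crucially, the current state's likelihood estimate is reused, not redrawn
    if math.log(random.random()) < log_p_prop - log_p:
        v, log_p = v_prop, log_p_prop
    samples.append(v)

post_mean = sum(samples[1000:]) / len(samples[1000:])
# analytic check: marginally x | v ~ N(v, 2), so with x_obs = 3 the
# posterior is N(1, 2/3) and post_mean should be close to 1.0
```

The important detail is that the chain's state carries the likelihood estimate of the current point; only the proposal's estimate is freshly computed. Redrawing the current point's estimate at every step is exactly what breaks acceptance (see my rsample attempt below).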
Here's my attempt: I replaced the z sample sites with z = torch.distributions.Normal(mu_q, std_q).rsample(), and MCMC then runs very fast, but acc. prob stays very close to 0 (e.g. 0.01) and step_size keeps shrinking (it got down to 1e-140 after 200 warmup steps). So I assume this is not the right thing to do.
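(My current understanding of why this fails, which may be wrong: an un-tracked rsample() inside the model makes the potential energy stochastic, so the sampler sees a different log density every time it evaluates the same point, the accept ratio collapses, and step-size adaptation chases the noise toward zero. A tiny random-walk Metropolis toy in plain Python, not Pyro, illustrating the effect of a noisy log density:)

```python
import math
import random

def run_chain(noise_sd, n_steps=2000):
    # Random-walk Metropolis on a standard normal target. noise_sd adds
    # fresh Gaussian noise to the log density at every evaluation, which
    # roughly mimics an un-tracked rsample() inside the potential energy.
    random.seed(1)
    x = 0.0
    log_p = -0.5 * x * x + random.gauss(0.0, noise_sd)
    accepted = 0
    for _ in range(n_steps):
        x_prop = x + random.gauss(0.0, 0.5)
        log_p_prop = -0.5 * x_prop * x_prop + random.gauss(0.0, noise_sd)
        if math.log(random.random()) < log_p_prop - log_p:
            x, log_p = x_prop, log_p_prop
            accepted += 1
    return accepted / n_steps

acc_exact = run_chain(0.0)  # clean log density: healthy acceptance
acc_noisy = run_chain(5.0)  # noisy log density: acceptance collapses
```

The noisy chain sticks whenever it happens to draw a large positive noise value for its current state, so its acceptance rate drops far below the clean chain's.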
Also, inference is much faster if I set adapt_step_size to False, but the results are really bad (acc. prob is around 0.1), even when I manually set step_size to a value that a previous run with adapt_step_size=True had found.
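(A side note on why a copied step size may go stale: as I understand it, NUTS tunes the step during warmup with a stochastic approximation that targets a fixed acceptance rate, so a value tuned for one run or one model variant need not fit another. A toy Robbins-Monro-style adaptation for a random-walk sampler, just to illustrate the mechanism; this is not NUTS's actual dual-averaging code:)

```python
import math
import random

random.seed(2)
target_acc = 0.44  # classic tuning target for a 1-D random walk
log_step = 0.0     # start at step size 1.0
x = 0.0
accepts = []
for t in range(1, 8001):
    step = math.exp(log_step)
    x_prop = x + random.gauss(0.0, step)
    log_alpha = 0.5 * x * x - 0.5 * x_prop * x_prop  # N(0, 1) target
    alpha = min(1.0, math.exp(log_alpha))
    if random.random() < alpha:
        x = x_prop
        accepts.append(1)
    else:
        accepts.append(0)
    # stochastic approximation: nudge the step toward the target
    # acceptance rate, with a decaying gain so adaptation settles down
    log_step += t ** -0.6 * (alpha - target_acc)

acc_final = sum(accepts[-2000:]) / 2000.0  # should hover near target_acc
```

Freezing log_step at an arbitrary value instead of adapting it leaves the acceptance rate wherever that value happens to put it, which matches the bad acc. prob I see with adapt_step_size=False.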
And here's the code:
class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        ...  # define model layers

    def infer_v_q(self, x, v, x_q, dataset):
        ...  # do some other operations here to get c_e, h_e, h_g, u and some other variables

        # x.shape = (batch_dim, 3, 64, 64), batch_dim=36
        with pyro.plate("data", x.shape[0]):
            # v_q.shape = (batch_dim, 7), batch_dim=36
            v_q = pyro.sample("v_q", pyro.distributions.Uniform(v_q_min, v_q_max).to_event(1))
            for l in range(self.L):
                # the following runs L times, so there are L separate "z" sample sites
                c_e, h_e = self.inference_network(x_q, v_q, r, c_e, h_e, h_g, u)
                mu_q, logvar_q = torch.split(self.hidden_state_network(h_e), 1, dim=1)
                std_q = torch.exp(0.5 * logvar_q)
                # z.shape = (batch_dim, 1, 16, 16)
                z = pyro.sample("z" + str(l), pyro.distributions.Normal(mu_q, std_q).to_event(3))
                c_g, h_g, u = self.generation_network(v_q, r, c_g, h_g, u, z)
            # x.shape = (batch_dim, 3, 64, 64)
            return pyro.sample("x", pyro.distributions.Normal(self.image_generator_network(u), 0.001).to_event(3), obs=x_q)
# Here's how I run MCMC inference
nuts_kernel = pyro.infer.NUTS(model.infer_v_q, adapt_step_size=True, step_size=1e-9, jit_compile=True, ignore_jit_warnings=True)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=200, warmup_steps=100, num_chains=1)
mcmc.run(x, v, x_q)
# get samples
mcmc.get_samples()["v_q"] # I only need this
# I don't want MCMC to infer a posterior for the "z" variables, but at the moment it does:
mcmc.get_samples()["z0"]
.
.
.
mcmc.get_samples()["z7"]
And here are my sample sites:
Sample Sites:
 data dist    |
      value 36 |
  v_q dist 36 | 7
      value 36 | 7
   z0 dist 36 | 1 16 16
      value 36 | 1 16 16
   z1 dist 36 | 1 16 16
      value 36 | 1 16 16
   z2 dist 36 | 1 16 16
      value 36 | 1 16 16
   z3 dist 36 | 1 16 16
      value 36 | 1 16 16
   z4 dist 36 | 1 16 16
      value 36 | 1 16 16
   z5 dist 36 | 1 16 16
      value 36 | 1 16 16
   z6 dist 36 | 1 16 16
      value 36 | 1 16 16
   z7 dist 36 | 1 16 16
      value 36 | 1 16 16