How can I make Pyro compute the posterior for ONLY a subset of the sample-site variables?

Recently I learned how to use MCMC to infer input parameters of a neural network. After running the code, I noticed that MCMC takes a very long time to finish: depending on the kernel parameters, it can take 1-3 hours (or even more) to run with 40 warmup steps and 100 samples. This is unacceptable to me, especially since the neural network model I'm using can process thousands of images per minute. I then learned that this is probably because MCMC is in fact estimating a distribution for all of my sample sites, which have very large dimensions (see the dimensions in the code below). However, I am only interested in inferring a posterior distribution for one variable, v_q, after conditioning the model's output on the observation x_q (an image). Is there a way to tell MCMC that I'm only interested in the posterior of v_q, so that it somehow ignores the other variables?

Here's my attempt: I replaced the z sample statement with z = torch.distributions.Normal(mu_q, std_q).rsample(). MCMC inference then runs very fast, but the acceptance probability is very close to 0 (e.g. 0.01) and step_size keeps going down (it reached 1e-140 after 200 warmup steps). So I assume this is not the right thing to do.

Also, inference is much faster if I set adapt_step_size to False, but the results are really bad (acceptance probability around 0.1) when I manually set step_size to a value that a previous run with adapt_step_size enabled had found.

Here's the code:

import torch
import torch.nn as nn
import pyro

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        ...  # define model layers

    def infer_v_q(self, x, v, x_q, dataset):
        ...  # do some other operations here to get c_e, h_e, h_g, u and some other variables

        # x.shape = (batch_dim, 3, 64, 64), batch_dim=36
        with pyro.plate("data", x.shape[0]):
            # v_q.shape = (batch_dim, 7), batch_dim=36
            v_q = pyro.sample('v_q', pyro.distributions.Uniform(v_q_min, v_q_max).to_event(1))

            for l in range(self.L):
                # the following is done L times, so there are L sample sites z0, ..., z{L-1}
                c_e, h_e = self.inference_network(x_q, v_q, r, c_e, h_e, h_g, u)
                mu_q, logvar_q = torch.split(self.hidden_state_network(h_e), 1, dim=1)
                std_q = torch.exp(0.5*logvar_q)
                # z.shape = (batch_dim, 1, 16, 16)
                z = pyro.sample("z"+str(l), pyro.distributions.Normal(mu_q, std_q).to_event(3))

                c_g, h_g, u = self.generation_network(v_q, r, c_g, h_g, u, z)
            # x.shape = (batch_dim, 3, 64, 64)
            return pyro.sample("x", pyro.distributions.Normal(self.image_generator_network(u), 0.001).to_event(3), obs=x_q)

# Here's how I run MCMC inference
nuts_kernel = pyro.infer.NUTS(model.infer_v_q, adapt_step_size=True, step_size=1e-9, jit_compile=True, ignore_jit_warnings=True)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=200, warmup_steps=100, num_chains=1)
mcmc.run(x, v, x_q, dataset)

# get samples
mcmc.get_samples()["v_q"] # I only need this

# I don't want MCMC to get a posterior for the "z" variables but it's doing it at the moment:

And here are my sample sites:

Sample Sites:
   data dist     |
       value  36 |
    v_q dist  36 | 7
       value  36 | 7
     z0 dist  36 | 1 16 16
       value  36 | 1 16 16
     z1 dist  36 | 1 16 16
       value  36 | 1 16 16
     z2 dist  36 | 1 16 16
       value  36 | 1 16 16
     z3 dist  36 | 1 16 16
       value  36 | 1 16 16
     z4 dist  36 | 1 16 16
       value  36 | 1 16 16
     z5 dist  36 | 1 16 16
       value  36 | 1 16 16
     z6 dist  36 | 1 16 16
       value  36 | 1 16 16
     z7 dist  36 | 1 16 16
       value  36 | 1 16 16
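
A quick count based on the shapes in the table above shows how lopsided the latent space is. NUTS has to simulate Hamiltonian dynamics over every latent dimension, not just v_q. This sketch assumes self.L = 8, since the table lists z0 through z7.

```python
# Latent-dimension count implied by the sample-site shapes above.
# Assumption: self.L == 8 (the table lists z0 ... z7).
batch_dim = 36
num_z_sites = 8

v_q_dims = batch_dim * 7                    # v_q: 36 | 7
z_dims_per_site = batch_dim * 1 * 16 * 16   # each z_l: 36 | 1 16 16
total_latent_dims = v_q_dims + num_z_sites * z_dims_per_site

print(v_q_dims)           # 252
print(z_dims_per_site)    # 9216
print(total_latent_dims)  # 73980
print(f"{100 * v_q_dims / total_latent_dims:.2f}%")  # 0.34% of the state is v_q
```

So about 99.7% of the state that NUTS explores belongs to the z sites, which is consistent with the slowdown described above.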

I would recommend instead using variational inference and defining a guide. For the parameters you want to estimate with uncertainty, sample from some approximate posterior distribution with learnable parameters. For parameters you want to point estimate, sample from a Delta distribution with learnable parameters. We have some machinery to automate this in easyguide.

Thanks for the suggestion. However, for my use case I need to use MCMC to infer the parameters.

Is it a limitation of Pyro that there is no flag to mark variables so that MCMC skips computing a posterior for them?

As you probably saw in my question, using z = torch.distributions.Normal(mu_q, std_q).rsample() instead of z = pyro.sample("z"+str(l), pyro.distributions.Normal(mu_q, std_q).to_event(3)) sped things up a lot. However, the inference logs suggested that the result would be really bad, because the step size was tiny and the acceptance probability was almost 0. I'm still not sure whether that also means the posterior I'd get this way for my variable of interest is incorrect. Could you clarify whether this approach is correct, or whether every random variable inside with pyro.plate() must be defined with the pyro.sample() primitive?

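@Warrior HMC (and therefore NUTS) is not compatible with random loss functions. It assumes the model's log density is a deterministic function of the sampled latent variables, and drawing z with rsample() outside pyro.sample makes that log density change on every evaluation. You'll need to either freeze non-inferred parameters or use some other inference method.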

By "freeze", do you mean those variables must not be random variables drawn from a distribution such as torch.distributions.Normal (i.e., they must be the result of a deterministic computation)? If not, how can I freeze those parameters in Pyro? In particular, how can I freeze the original z = torch.distributions.Normal(mu_q, std_q).rsample() so that Pyro ignores it during inference? Also, I was using NUTS in this example, in case that makes a difference.

By “freeze parameters” I simply mean ensure they are constant across model invocations. You can create constant tensors any way you like. Just make sure you don’t call random number generators during model invocations. For example if you had a model like

import pyro
import pyro.distributions as dist

def model():
    frozen_param = pyro.sample("frozen_param", dist.Normal(0, 1))
    inferred_param = pyro.sample("inferred_param",
                                 dist.Normal(frozen_param, 1))

you could replace it with

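import pyro
import pyro.distributions as dist

class Model:
    def __init__(self):
        # Drawn once at construction, so it is constant across model invocations.
        self.frozen_param = dist.Normal(0, 1).sample()
    def __call__(self):
        inferred_param = pyro.sample("inferred_param",
                                     dist.Normal(self.frozen_param, 1))
model = Model()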

But Python is a big language and I’m sure you’ll want to freeze your parameter in a different way. Just ensure the model is deterministic conditioned on inferred parameters.

I tried this, but in my problem z_i depends on v_q, so I cannot freeze z_i before I have an estimate of v_q. That leaves SVI as my only option for getting a posterior for v_q, which I'd rather avoid at the moment. Do you think it is a limitation of Pyro that it cannot compute a posterior for only a subset of the sample sites (i.e., treating everything else as a black box)? If so, I can open a GitHub issue to request this feature.

As I understand it, your issue is a fundamental limitation of MCMC, not something specific to Pyro. If VI is not suitable for your needs, I can only recommend searching the literature for more exotic inference algorithms beyond MCMC and VI.