How to infer inputs of a neural network via MCMC?

I know I can use gradient descent to infer the input [parameters] of a neural network model, but for some reason I need to infer those parameters via MCMC. To give you more context, the input has a small number of parameters (e.g. 10), and they do not come from a specific distribution (e.g. Gaussian, Uniform, etc.) but are bounded in [0, 1]. At test time, the neural network model produces an image given an input. My goal is to infer the input parameters conditioned on an output. I am pretty new to Pyro and MCMC. I looked at an example here but cannot figure out how to translate what the example is doing to my problem: pyro.sample seems to require a specific distribution (maybe I can assume Uniform for my parameters?), and the example uses SVI rather than MCMC. In addition, the VAE model in the example is defined differently than the model I am using, and I'm not sure how to modify the example to make it work in my case. So how can I condition my model on an [expected] output and infer its input parameters with Pyro via MCMC?

Here’s a simplified version of the neural network model I’m using:

import torch
import torch.nn as nn
from torch.distributions import Normal

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        # ... define model layers ...

    def forward(self, x):
        # ... do transposed convolutional operations on x and get
        # predicted_means, predicted_sigma and reconsted_image ...

        elbo = torch.sum(Normal(predicted_means, predicted_sigma).log_prob(reconsted_image), dim=[1, 2, 3])

        return elbo

    def reconstruct(self, x):
        # This is the function I use during test time and MCMC needs to use this function

        # ... do transposed convolutional operations on x and get mu ...

        return torch.clamp(mu, 0, 1)

Hi, I’m not sure I understand your model description - what is reconstruct supposed to do? Do you mean something like this generative model for data x?

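import torch.nn as nn
import pyro
import pyro.distributions as dist

class MyModel(nn.Module):
    ...
    def forward(self, x):  # generative model
        with pyro.plate("data", x.shape[0]):
            z = pyro.sample("z", dist.ImproperUniform(dist.constraints.unit_interval, (), (10,)))
            predicted_means, predicted_scales = self.my_neural_net(z)
            # treat all non-batch dims of x as one event so the plate only claims dim 0
            return pyro.sample("x", dist.Normal(predicted_means, predicted_scales).to_event(x.dim() - 1), obs=x)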

Note the use of pyro.distributions.ImproperUniform for your latent parameters, since you don’t want to place a proper prior distribution on them.

The MCMC docs demonstrate how to use the MCMC API. In this case:

nuts_kernel = pyro.infer.NUTS(MyModel(...), adapt_step_size=True)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=500, warmup_steps=300)
mcmc.run(data)
zs = mcmc.get_samples()["z"]

Thanks for your response. reconstruct() behaves very similarly to forward(): it takes an input x and transforms it into an output. However, reconstruct() doesn't update some of the model's internal variables and doesn't apply any loss function to the outputs (e.g. predicted images). So I was initially thinking I should use reconstruct() when trying to infer the inputs, but since MCMC requires a likelihood function, I think I should use forward() instead.

Also, could you elaborate on what you meant by ... before def forward(self, x)? Do you mean I should do something like pyro.module('my_model', self.my_neural_net)? Also, do I still have to do return pyro.sample("x", dist.Normal(predicted_means, predicted_scales), obs=x) if my model outputs a likelihood score and I want MCMC to operate on that?

After reading some more posts and spending some time with Pyro's documentation, I think I now understand a bit better what is going on. I have another question: when returning a sample in the forward() function you wrote, how can I clamp my samples to be between [0, 1]?

Also, could you elaborate on what you meant by ... before def forward(self, x) ? do you mean I should do something like pyro.module('my_model', self.my_neural_net )?

That depends on your problem. If your neural network’s weights are fixed, there’s no need to do anything Pyro-specific, just define your PyTorch __init__ method that builds self.my_neural_net.

my model is outputting a likelihood score and I want MCMC to operate on that

If you want to do MCMC directly on a custom unnormalized log-density rather than a Pyro model with pyro.sample statements, you can pass a function that computes your density given parameter values to the potential_fn argument of the MCMC kernel, as illustrated in this forum thread.

When returning a sample in the forward() function you wrote, how can I clamp my samples to be between [0, 1]?

Do you mean samples of "x"? If so, you can apply a SigmoidTransform to the distribution of "x" via pyro.distributions.TransformedDistribution. Alternatively, you could use a Bernoulli likelihood and return the mean rather than binarized samples as in the VAE example.


Thanks for your prompt response. Regarding samples of "x": during training, here's how I compute the log-likelihood of my neural network's outputs in my model class's original forward() function:

elbo = torch.sum(Normal(predicted_means, predicted_sigma).log_prob(reconsted_image), dim=[1,2,3])

But at test time I need to make sure the generated pixel values are within [0, 1], so I sample from a Normal and clamp the output pixel values with clamp() as follows:

return torch.clamp(mu, 0, 1)

So maybe I should have asked my question another way: should MCMC use a truncated Gaussian distribution as its likelihood function or not? If so, would using SigmoidTransform do the job? To me, it seems that I should simply truncate the values and not smooth things out (as a sigmoid does).
And if the answer is no, then would using pyro.sample("x", dist.Normal(predicted_means, predicted_scales), obs=x) tell MCMC to basically do what torch.sum(Normal(predicted_means, predicted_sigma).log_prob(reconsted_image), dim=[1,2,3]) does?
Sorry if my questions don’t make that much sense; I’m pretty new to Pyro.

Also, I provided the "guide" function to NUTS and ran inference via mcmc.run() but got the following error:
File “/usr/local/lib/python3.6/dist-packages/pyro/poutine/broadcast_messenger.py”, line 59, in _pyro_sample
f.name, msg[‘name’], f.dim, f.size, target_batch_shape[f.dim]))
ValueError: Shape mismatch inside plate(‘data’) at site x dim -1, 36 vs 64
Trace Shapes:
Param Sites:
encoder$$$conv1.weight 256 3 2 2
encoder$$$conv2.weight 256 256 2 2
encoder$$$conv3.weight 256 263 3 3
encoder$$$conv4.weight 128 263 3 3
encoder$$$conv5.weight 256 128 3 3
encoder$$$conv6.weight 81 256 1 1
inference$$$downsample_x.weight 3 3 4 4
inference$$$upsample_v.weight 7 7 16 16
inference$$$upsample_r.weight 81 81 16 16
inference$$$downsample_u.weight 128 128 4 4
inference$$$core.forget.weight 128 475 5 5
inference$$$core.input.weight 128 475 5 5
inference$$$core.output.weight 128 475 5 5
inference$$$core.state.weight 128 475 5 5
generator$$$upsample_v.weight 7 7 16 16
generator$$$upsample_r.weight 81 81 16 16
generator$$$core.forget.weight 128 217 5 5
generator$$$core.input.weight 128 217 5 5
generator$$$core.output.weight 128 217 5 5
generator$$$core.state.weight 128 217 5 5
generator$$$upsample_h.weight 128 128 4 4
Sample Sites:
data dist |
value 36 |
v_q dist 36 | 7
value 36 | 7

I checked the dimensions of my samples and they seem correct. What could I be doing wrong?

Here’s how I’m doing the computations inside my "guide" function:

    pyro.module("inference", self.inference_network)
    pyro.module("generator", self.generation_network)
    pyro.module("hidden_state", self.hidden_state_network)
    pyro.module("image_generator", self.image_generator_network)
    with pyro.plate("data", x.shape[0]): # x.shape = (batch_dim, 3, 64, 64), batch_dim=36
            v_q = pyro.sample('v_q', pyro.distributions.Uniform(v_q_min, v_q_max).to_event(1)) # v_q.shape = (batch_dim, 7), batch_dim=36

            for l in range(self.L): # do the following process L times
                c_e, h_e = self.inference_network(x_q, v_q, r, c_e, h_e, h_g, u)
                
                mu_q, logvar_q = torch.split(self.hidden_state_network(h_e), 1, dim=1)
                std_q = torch.exp(0.5*logvar_q)
                q = Normal(mu_q, std_q)
                
                z = q.rsample()

                c_g, h_g, u = self.generation_network(v_q, r, c_g, h_g, u, z)
            return pyro.sample("x", pyro.distributions.Normal(self.image_generator_network(u), 0.001).to_event(1), obs=x_q) # x.shape = (batch_dim, 3, 64, 64), batch_dim=36

This gives me the error I posted above. However, if I use with pyro.plate("data"): I don't get an error and MCMC runs successfully, but I'm pretty sure I shouldn't be doing this because the images in each batch are independent. So I need to use with pyro.plate("data", x.shape[0]):, but this currently gives me an error. What should I do to resolve this issue?
Just FYI: I am using Pyro 1.5

Update: if I just use with pyro.plate("data"): I get weird results after running MCMC with 100 samples and 300 warmup steps:
Sample: 100%|█████████████████████████████████████████| 300/300 [01:18, 3.84it/s, step size=1.17e-132, acc. prob=0.000]

Sorry, I don’t understand what you mean by this, and I’m generally having trouble following your code and problem descriptions. MCMC doesn’t require a guide, and Pyro’s MCMC kernels don’t take a guide argument. I also see that in your latest code snippet, you’re sampling from q directly (q.rsample()) rather than wrapping it in pyro.sample, which is not compatible with any of Pyro’s inference algorithms, MCMC or otherwise. Can you try explaining what you want to accomplish with this code? What do you mean when you say this snippet is part of your guide function?

I think you might find it helpful to read through Pyro’s introductory tutorials, especially part 1 and part 2 of the Bayesian regression tutorial, which discuss the conceptual and API differences between variational inference and MCMC.


I just realized that by "guide" I meant "conditioned_model", as in the MCMC example. In a nutshell, my goal is to infer v_q, and the code you see above is part of the infer_v_q() function in my model class. To initialize and run MCMC, I ran the following lines:

nuts_kernel = pyro.infer.NUTS(model.infer_v_q, adapt_step_size=True) # i was initially thinking infer_v_q is called "guide" in the context of NUTS
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=100, warmup_steps=300)
mcmc.run(x, v, x_q)

I also replaced the lines for sampling "z" as you recommended, as follows, but I still get a dimension mismatch error (a different one, though):
z = pyro.sample("z", pyro.distributions.Normal(mu_q, std_q).to_event(1))  # z.shape = (batch_dim, 1, 16, 16)

And here's the new error I'm getting on the above line where z is sampled:

Shape mismatch inside plate(‘data’) at site z dim -1, 36 vs 16

I'm not sure why I get these dimension mismatch errors. Could it be because I do upsampling in some of the modules in my model? One thing I noticed in the examples I've seen is that tensors usually have two dimensions. For example, a tensor containing a batch of 28 x 28 images has shape (batch_dim, 784), but my tensors are 4-dimensional (batch, x, y, z). Could this be the issue?

I also replaced the lines for sampling “z” as you recommended as follow but I still get a dimension mismatch error (a different one though):

You need to declare the three rightmost dimensions of z as dependent with .to_event(3), not just one:

z = pyro.sample("z", pyro.distributions.Normal(mu_q, std_q).to_event(3))

See the tensor shape tutorial for background.


Yes I also did that (forgot to mention it here) but I get another error saying:

RuntimeError: Multiple sample sites named ‘z’
Trace Shapes:
Param Sites:
encoder$$$conv1.weight 256 3 2 2
encoder$$$conv2.weight 256 256 2 2
encoder$$$conv3.weight 128 256 3 3
encoder$$$conv4.weight 256 128 2 2
encoder$$$conv5.weight 256 263 3 3
encoder$$$conv6.weight 128 263 3 3
encoder$$$conv7.weight 256 128 3 3
encoder$$$conv8.weight 81 256 1 1
inference$$$downsample_x.weight 3 3 4 4
inference$$$upsample_v.weight 7 7 16 16
inference$$$upsample_r.weight 81 81 16 16
inference$$$downsample_u.weight 128 128 4 4
inference$$$core.forget.weight 128 475 5 5
inference$$$core.input.weight 128 475 5 5
inference$$$core.output.weight 128 475 5 5
inference$$$core.state.weight 128 475 5 5
generator$$$upsample_v.weight 7 7 16 16
generator$$$upsample_r.weight 81 81 16 16
generator$$$core.forget.weight 128 217 5 5
generator$$$core.input.weight 128 217 5 5
generator$$$core.output.weight 128 217 5 5
generator$$$core.state.weight 128 217 5 5
generator$$$upsample_h.weight 128 128 4 4
hidden_state$$$weight 2 128 5 5
image_generator$$$weight 3 128 1 1
Sample Sites:
data dist |
value 36 |
v_q dist 36 | 7
value 36 | 7
z dist 36 | 1 16 16
value 36 | 1 16 16

The code crashes on the line z = pyro.sample("z", pyro.distributions.Normal(mu_q, std_q).to_event(3)). Could it be crashing because I have a for loop that samples z L times? If so, how can I resolve this, given that z depends on the updated values of mu_q and std_q in each iteration of the for loop?

I made the following change and I don’t get the error saying “multiple sample sites named ‘z’” anymore. However, the code now crashes on the last line where I do return pyro.sample("x", pyro.distributions.Normal(self.image_generator_network(u), 0.001).to_event(1), obs=x_q) # x.shape = (batch_dim, 3, 64, 64), batch_dim=36:

for l in range(self.L):
    # ...
    z = pyro.sample("z"+str(l), pyro.distributions.Normal(mu_q, std_q).to_event(3))

And here’s the error:

Shape mismatch inside plate(‘data’) at site x dim -1, 36 vs 64
Sample Sites:
data dist |
value 36 |
v_q dist 36 | 7
value 36 | 7
z0 dist 36 | 1 16 16
value 36 | 1 16 16
z1 dist 36 | 1 16 16
value 36 | 1 16 16
z2 dist 36 | 1 16 16
value 36 | 1 16 16
z3 dist 36 | 1 16 16
value 36 | 1 16 16
z4 dist 36 | 1 16 16
value 36 | 1 16 16
z5 dist 36 | 1 16 16
value 36 | 1 16 16
z6 dist 36 | 1 16 16
value 36 | 1 16 16
z7 dist 36 | 1 16 16
value 36 | 1 16 16

Should I also do to_event(3) for the last line given that the last line is basically a reconstruction of x?

I also changed to_event(1) to to_event(3) in return pyro.sample("x", pyro.distributions.Normal(self.image_generator_network(u), 0.001).to_event(1), obs=x_q), and things seem to work fine. However, it takes an extremely long time to run a few warmup steps (2-3 minutes for 20 steps, and much longer after that), and afterwards the acceptance probability is pretty low (0.05-0.06). I'm not sure whether this is related to the way I set the dimensions or whether I should simply run more warmup steps or draw more samples. I would appreciate any advice on this. I also apologize for going off topic in this thread.
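One likely contributor, offered as a hypothesis rather than a diagnosis, is the observation scale of 0.001. For a Normal likelihood, the second derivative of log N(x | mu, sigma) with respect to mu is -1/sigma^2, so each pixel adds curvature of about -10^6 with respect to its predicted mean, summed over 3 x 64 x 64 pixels per image. A posterior that sharp tends to force NUTS into tiny step sizes and deep trees (a step size of 1.17e-132 was reported above for an earlier configuration). The model also samples eight z sites of shape (36, 1, 16, 16), i.e. 73,728 latent dimensions on top of v_q. The snippet only verifies the curvature formula numerically:

```python
import torch
from torch.distributions import Normal

def curvature(sigma, x=0.5, mu=0.4):
    # Second derivative of log N(x | mu, sigma) with respect to mu, computed with autograd.
    mu = torch.tensor(mu, requires_grad=True)
    lp = Normal(mu, sigma).log_prob(torch.tensor(x))
    (g,) = torch.autograd.grad(lp, mu, create_graph=True)  # (x - mu) / sigma^2
    (h,) = torch.autograd.grad(g, mu)                       # -1 / sigma^2
    return h.item()

for sigma in (0.1, 0.01, 0.001):
    print(sigma, curvature(sigma))  # approximately -1 / sigma**2
```

If this is the bottleneck, a larger observation scale, or NUTS's existing max_tree_depth, target_accept_prob and jit_compile arguments, are natural things to experiment with; none of this has been tested on the model in this thread.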