Customize potential_fn in MCMC with a PyTorch nn model


I am trying to write a custom potential function to pass to the NUTS sampler. My question is how to make `z` appear in the computation graph of my nn model, so that Pyro can compute `grad(potential_energy, z_nodes)` in `potential_grad`.

    def potential_fn(z):
        # z is never used below, so it is not part of the autograd graph
        output = model(train_x)
        logp = loss(output, train_y)
        return logp

The version above raises a runtime error:

    File ".../anaconda3/lib/python3.7/site-packages/torch/autograd/", line 204, in grad
        inputs, allow_unused)
    RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Thanks in advance.
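The error can be reproduced without Pyro. `potential_grad` calls `torch.autograd.grad` on the potential energy with respect to the tensors in `z`, and autograd refuses when one of those inputs was never used to compute the output. In the sketch below, `x` and `w` are hypothetical stand-ins for a sample site in `z` and for whatever the computation actually uses.

```python
import torch

# Hypothetical tensors: x plays the role of a value in z,
# w stands for something the computation does depend on.
x = torch.tensor(1.0, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)

message = ""
try:
    # The output never touches x, just as potential_fn never touches z.
    out = w ** 2
    torch.autograd.grad(out, [x])
except RuntimeError as err:
    message = str(err)

print(message)

# allow_unused=True silences the error, but the "gradient" for x is then
# None, so it does not give the sampler anything useful.
out = w ** 2
(grad_x,) = torch.autograd.grad(out, [x], allow_unused=True)
print(grad_x)  # None
```

This is why setting `allow_unused=True` is not a fix here: the potential has to actually depend on `z`.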

`fill_` is an in-place op and is not differentiable. Ideally, you would rewrite `model` so that you can pass `z` as an argument and use it in a way that maintains differentiability.
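A minimal sketch of that suggestion, assuming PyTorch 2.0 or later for `torch.func.functional_call`. The data, the `nn.Linear` network, and the Gaussian likelihood and prior are hypothetical stand-ins for the question's `train_x`, `train_y`, `model` and `loss`. `functional_call` runs the module with the tensors in `z` substituted for its own parameters, so the energy depends on `z` and every entry receives a gradient. The function returns a negative log joint density (negative log-likelihood plus negative log-prior), since that is what a potential energy is.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical data and network standing in for train_x, train_y and model.
train_x = torch.randn(20, 3)
train_y = torch.randn(20, 1)
model = nn.Linear(3, 1)

def potential_fn(z):
    # z maps parameter names ("weight", "bias") to tensors. functional_call
    # runs the module with these tensors in place of its own parameters,
    # so the output is a differentiable function of z.
    output = functional_call(model, z, (train_x,))
    # Gaussian negative log-likelihood with unit noise (up to a constant)...
    nll = 0.5 * ((output - train_y) ** 2).sum()
    # ...plus a standard-normal negative log-prior on every parameter.
    neg_log_prior = sum(0.5 * (v ** 2).sum() for v in z.values())
    return nll + neg_log_prior

# Initial values with the same names and shapes as the module's parameters.
z = {name: p.detach().clone().requires_grad_(True)
     for name, p in model.named_parameters()}

energy = potential_fn(z)
# This is the kind of call potential_grad makes; it now succeeds because
# every tensor in z was used to compute the energy.
grads = torch.autograd.grad(energy, list(z.values()))
```

To sample, the same function could be given to the sampler as `pyro.infer.NUTS(potential_fn=potential_fn)` and run with `pyro.infer.MCMC(kernel, num_samples=..., initial_params=...)`, where `initial_params` is a dict like `z` (detached). That step is not exercised in the sketch above.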