How to use Pyro to solve a model inversion problem

I currently have a very complex forward model that takes 6 input variables and produces a vector, and I am trying to use Pyro to invert it. The problem is that after I run MCMC on my training data and then try to predict on new data, I cannot get the inversion results (i.e., the values of the six variables for the new validation data). Do you know how to solve this problem?

import numpy as onp
from scipy.optimize import minimize
from scipy.stats import gaussian_kde
import jax.numpy as np
from jax import random, lax
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import torch


class Laplace(dist.Distribution):
    """Laplace (double-exponential) distribution with location ``loc`` and scale ``scale``.

    NOTE(review): this class follows the NumPyro API (``sample`` takes a JAX PRNG
    ``key``; ``dist.util.promote_shapes`` is a NumPyro helper), but ``dist`` is bound to
    the PyTorch-based ``pyro.distributions`` in this file. It is also not used by
    ``model``. Confirm the intended backend before relying on it.
    """

    arg_constraints = {"loc": dist.constraints.real, "scale": dist.constraints.positive}
    support = dist.constraints.real
    reparametrized_params = ["loc", "scale"]

    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
        self.loc, self.scale = dist.util.promote_shapes(loc, scale)
        # The batch shape is the broadcast of the two parameter shapes.
        batch_shape = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        ``key`` is a JAX PRNG key; samples are ``loc + scale * eps`` with eps ~ Laplace(0, 1).
        """
        shape = tuple(sample_shape) + self.batch_shape + self.event_shape
        eps = random.laplace(key, shape=shape)
        return self.loc + eps * self.scale

    def log_prob(self, value):
        """Return ``log p(value) = -|value - loc| / scale - log(2 * scale)``."""
        return -np.abs(value - self.loc) / self.scale - np.log(2 * self.scale)

    # The Distribution base class exposes ``mean``/``variance`` as properties; plain
    # methods here would make ``d.mean`` return a bound method instead of an array.
    @property
    def mean(self):
        """Mean of the distribution (equal to ``loc``), broadcast to ``batch_shape``."""
        return np.broadcast_to(self.loc, self.batch_shape)

    @property
    def variance(self):
        """Variance ``2 * scale**2``, broadcast to ``batch_shape``."""
        return np.broadcast_to(2 * self.scale ** 2, self.batch_shape)

def forward(x1, x2, x3, x4, x5, x6, s):
    """Toy forward model: map six parameters and a scale ``s`` to a length-3 vector.

    Every output component equals ``s * sqrt(x1**2 + x2**2 + ... + x6**2)``.

    Args:
        x1..x6, s: torch tensors (as returned by ``pyro.sample``) with broadcastable
            shapes; for scalar inputs the result has shape ``(3,)``, for inputs of
            broadcast shape ``B`` it has shape ``B + (3,)``.

    Returns:
        A tensor whose last dimension holds the three (identical) model outputs.
    """
    radius = torch.sqrt(x1**2 + x2**2 + x3**2 + x4**2 + x5**2 + x6**2)
    results = s * radius
    # torch.stack keeps the outputs in the autograd graph so NUTS receives gradients
    # w.r.t. x1..x6 and s; torch.tensor([...]) would copy the values and detach them.
    return torch.stack([results, results, results], dim=-1)

def model(obs=None, num_obs=1):  ## just an example
    """Toy generative model used to invert ``forward`` with MCMC.

    Latent sites: ``X1``..``X6`` ~ Uniform(-2, 2) and ``S`` ~ Normal(19.5, 0.5), shared by
    all observations. Each observation row is ``forward(X1, ..., X6, S)`` plus independent
    Normal noise with standard deviation 1.5.

    Args:
        obs: observed data with one row of 3 values per observation (shape ``(N, 3)``;
            a single 1-D row of 3 values is also accepted), or ``None`` to sample new
            observations.
        num_obs: number of rows to generate when ``obs`` is ``None``.
    """
    xs = [pyro.sample(f"X{i}", dist.Uniform(-2.0, 2.0)) for i in range(1, 7)]
    s = pyro.sample("S", dist.Normal(19.5, 0.5))
    t = forward(*xs, s)
    if obs is not None and obs.dim() == 1:
        obs = obs.unsqueeze(0)  # treat a single row as a batch of one observation
    n = num_obs if obs is None else obs.shape[0]
    # Rows are conditionally independent given the latents. The plate makes that
    # explicit, so the likelihood sums over rows and obs=None yields (n, 3) samples.
    with pyro.plate("data", n):
        pyro.sample("obs", dist.Normal(t, 3 / 2).to_event(1), obs=obs)

if __name__ == "__main__":
    from pyro.infer import MCMC, NUTS, Predictive

    # Fit the latent parameters to the training observations.
    train_obs = torch.tensor([[19, 19, 19], [20, 20, 20], [19, 19, 19]], dtype=torch.float32)
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=50, warmup_steps=5)
    mcmc.run(train_obs)
    sample_list = mcmc.get_samples()

    # Predictive only returns sites that are not in posterior_samples unless they are
    # listed in return_sites; listing the latents exposes X1..X6 and S in cc.
    latent_sites = ("X1", "X2", "X3", "X4", "X5", "X6", "S")
    new_obs = torch.tensor([[21, 21, 21], [21, 21, 21], [21, 21, 21]], dtype=torch.float32)
    predictive = Predictive(
        model,
        posterior_samples=sample_list,
        return_sites=latent_sites + ("obs",),
    )
    cc = predictive(new_obs)

    # The latents in cc are replayed from sample_list, i.e. conditioned on train_obs.
    # Inverting *new* data means conditioning on it, so run MCMC again on new_obs.
    mcmc_new = MCMC(NUTS(model), num_samples=50, warmup_steps=5)
    mcmc_new.run(new_obs)
    inverted = mcmc_new.get_samples()  # keys "X1".."X6", "S": posterior draws for new_obs


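I'm not sure why you included this Laplace distribution in your post (your model does not use it), but if I understand your question correctly, what you want is access to `X1`, `X2`, etc. in `cc`.
To get these, you have to specify `return_sites` when you construct the `Predictive` object (see below). Note that the `X1`…`X6` values returned this way are simply the posterior samples you passed in, so they are conditioned on the training data, not on the new observations. To invert a *new* observation, run MCMC again with the new data as `obs`: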

## ... (model and mcmc defined and run as above)
sample_list = mcmc.get_samples()
pred = Predictive(
    model,
    posterior_samples=sample_list,
    return_sites=("X1", "X2", "X3", "X4", "X5", "X6", "S", "obs"),  ## add this to get X1 etc.
)

cc = pred(torch.tensor([[21, 21, 21], [21, 21, 21], [21, 21, 21]], dtype=torch.float32))

If you want to record other intermediate values, you can use a `pyro.deterministic` statement. For instance, to record `t`, you would change the model to:

def model(obs):  ## just an example
    """Same toy model as above, additionally recording the forward output ``t``.

    ``pyro.deterministic`` adds ``t`` to the trace, so ``Predictive`` can return it
    when ``"t"`` is listed in ``return_sites``.
    """
    xs = [pyro.sample(f"X{i}", dist.Uniform(-2.0, 2.0)) for i in range(1, 7)]
    s = pyro.sample("S", dist.Normal(19.5, 0.5))
    t = forward(*xs, s)
    pyro.deterministic("t", t)  ## records t
    pyro.sample("obs", dist.Normal(t, 3 / 2), obs=obs)