How to use Pyro to solve the model inversion problem

I currently have a fairly complex model that takes 6 input variables and generates a vector. I tried to use Pyro to invert it. The problem I have now is that after training the MCMC model, when I try to predict on new data I cannot get the inversion results (i.e., the values of the six variables for the new validation set). Do you know how to solve this problem?

###_____________###
import numpy as onp
from scipy.optimize import minimize
from scipy.stats import gaussian_kde
import jax.numpy as np
from jax import random, lax
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import torch

NUM_WARMUP, NUM_SAMPLES = 1000, 5

class Laplace(dist.Distribution):
    arg_constraints = {'loc': dist.constraints.real, 'scale': dist.constraints.positive}
    support = dist.constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, loc=0., scale=1., validate_args=None):
        self.loc, self.scale = dist.util.promote_shapes(loc, scale)
        batch_shape = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        eps = random.laplace(key, shape=sample_shape + self.batch_shape + self.event_shape)
        return self.loc + eps * self.scale

    def log_prob(self, value):
        normalize_term = np.log(1 / (2 * self.scale))
        value_scaled = np.abs(value - self.loc) / self.scale
        return -value_scaled + normalize_term

    @property
    def mean(self):
        return np.broadcast_to(self.loc, self.batch_shape)

    @property
    def variance(self):
        return np.broadcast_to(2 * self.scale ** 2, self.batch_shape)

def forward(x1, x2, x3, x4, x5, x6, s):
    results = s * torch.sqrt(x1**2 + x2**2 + x3**2 + x4**2 + x5**2 + x6**2)
    return torch.stack([results, results, results])  ## stack keeps gradients intact for NUTS

def model(obs):  ## just an example
    x1 = pyro.sample('X1', dist.Uniform(-2, 2))
    x2 = pyro.sample('X2', dist.Uniform(-2, 2))
    x3 = pyro.sample('X3', dist.Uniform(-2, 2))
    x4 = pyro.sample('X4', dist.Uniform(-2, 2))
    x5 = pyro.sample('X5', dist.Uniform(-2, 2))
    x6 = pyro.sample('X6', dist.Uniform(-2, 2))
    s = pyro.sample('S', dist.Normal(19.5, .5))
    t = forward(x1, x2, x3, x4, x5, x6, s)
    pyro.sample('obs', dist.Normal(t, 3/2), obs=obs)

if __name__ == '__main__':

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=50, warmup_steps=5)
    mcmc.run(torch.tensor([[19, 19, 19], [20, 20, 20], [19, 19, 19]]).to(torch.float32))

    from pyro.infer import Predictive
    sample_list = mcmc.get_samples()
    cc = Predictive(model, mcmc.get_samples())(torch.tensor([[21, 21, 21], [21, 21, 21], [21, 21, 21]]).to(torch.float32))

###_____________###

I'm not sure why you included this Laplace distribution in your post (it is never actually used), but if I understand your question correctly, what you want is access to X1, X2, etc. in cc.
To get this, you will have to specify return_sites when you construct the Predictive object:

## ...
sample_list = mcmc.get_samples()
pred = Predictive(
    model, 
    mcmc.get_samples(),
    return_sites=("X1", "X2", "X3", "X4", "X5", "X6", "S", "obs") ## add this to get X1 etc.
)

cc = pred(torch.tensor([[21,21,21],[21,21,21],[21,21,21]]).to(torch.float32))
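
Each entry of cc is then a tensor of posterior samples for that site, so you can read off the inverted variables directly. A minimal sketch (summarizing with mean/std here is just one option):

for name in ("X1", "X2", "X3", "X4", "X5", "X6", "S"):
    samples = cc[name]  ## tensor of posterior samples, e.g. shape (num_samples,)
    print(f"{name}: mean={samples.mean():.3f}, std={samples.std():.3f}")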

If you want to record other intermediate values, you can use a pyro.deterministic statement. For instance, if you want to record t, you would change the model to

def model(obs): ## just an example
    ## sample x1, ..., x6 and s ...
    t = forward(x1, x2, x3, x4, x5, x6, s)
    pyro.deterministic("t", t) ## records t
    pyro.sample("obs", dist.Normal(t, 3/2), obs=obs)
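
Deterministic sites can be returned by Predictive just like sample sites, so adding "t" to return_sites should expose it in the output as well:

pred = Predictive(
    model,
    mcmc.get_samples(),
    return_sites=("X1", "X2", "X3", "X4", "X5", "X6", "S", "t", "obs")
)
cc = pred(torch.tensor([[21, 21, 21], [21, 21, 21], [21, 21, 21]]).to(torch.float32))
print(cc["t"].shape)  ## posterior samples of the intermediate value t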