I currently have a very complex model that takes 6 input variables and generates a vector. I tried to use Pyro to solve it. However, the problem I have now is that after I train the MCMC model and try to work with new data, I cannot get the inversion results (i.e., the values of the six variables for the new validation set). Do you know how to solve this problem?

###_____________###

import numpy as onp

from scipy.optimize import minimize

from scipy.stats import gaussian_kde

import jax.numpy as np

from jax import random, lax

import pyro

import pyro.distributions as dist

from pyro.infer import MCMC, NUTS

import torch

# Warmup and sample counts for MCMC.
# NOTE(review): these constants are defined but never used — the MCMC call in
# __main__ hard-codes num_samples=50, warmup_steps=5; presumably they were
# meant to be passed there. Confirm intent.
NUM_WARMUP, NUM_SAMPLES = 1000, 5

class Laplace(dist.Distribution):
    """Custom Laplace distribution with reparameterized sampling.

    NOTE(review): this class follows the *NumPyro* distribution interface
    (``sample`` takes an explicit JAX PRNG ``key``; shapes come from
    ``jax.numpy``/``lax``), but the file imports ``pyro.distributions``, which
    is the PyTorch-based Pyro library. The two stacks are not interchangeable:
    either subclass ``numpyro.distributions.Distribution`` and use NumPyro's
    MCMC, or write a Torch-based distribution for Pyro. Confirm which stack
    is intended before using this class.
    """

    arg_constraints = {'loc': dist.constraints.real, 'scale': dist.constraints.positive}
    support = dist.constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, loc=0., scale=1., validate_args=None):
        # Broadcast loc and scale to a common batch shape.
        self.loc, self.scale = dist.util.promote_shapes(loc, scale)
        batch_shape = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # Reparameterized draw: standard Laplace noise, then shift/scale.
        eps = random.laplace(key, shape=sample_shape + self.batch_shape + self.event_shape)
        return self.loc + eps * self.scale

    def log_prob(self, value):
        # log p(x) = -|x - loc| / scale - log(2 * scale)
        normalize_term = np.log(1 / (2 * self.scale))
        value_scaled = np.abs(value - self.loc) / self.scale
        return -value_scaled + normalize_term

    @property
    def mean(self):
        return np.broadcast_to(self.loc, self.batch_shape)

    @property
    def variance(self):
        # Var[Laplace(loc, scale)] = 2 * scale**2
        return np.broadcast_to(2 * self.scale ** 2, self.batch_shape)

def forward(x1, x2, x3, x4, x5, x6, s):
    """Toy forward model: a length-3 vector whose entries all equal s * ||x||_2.

    Args:
        x1..x6: scalar torch tensors (the latent inputs).
        s: scalar torch tensor (scale factor).

    Returns:
        A shape-(3,) tensor that stays attached to the autograd graph.

    Bug fixed: the original wrapped the scalar as
    ``torch.tensor([results, results, results])``, which copies the *value*
    and detaches it from autograd — NUTS/HMC then sees zero gradients through
    the forward model. ``expand`` broadcasts without copying or detaching.
    """
    result = s * torch.sqrt(x1**2 + x2**2 + x3**2 + x4**2 + x5**2 + x6**2)
    # Broadcast the 0-dim result to a 3-vector, keeping the graph intact.
    return result.expand(3)

def model(obs):  # just an example
    """Pyro generative model: 6 uniform latents + a scale, pushed through `forward`.

    Args:
        obs: observed float tensor (each row compared against the 3-vector
             produced by `forward`) — conditions the likelihood.
    """
    # Six latent inputs with Uniform(-2, 2) priors.
    x1 = pyro.sample('X1', dist.Uniform(-2, 2))
    x2 = pyro.sample('X2', dist.Uniform(-2, 2))
    x3 = pyro.sample('X3', dist.Uniform(-2, 2))
    x4 = pyro.sample('X4', dist.Uniform(-2, 2))
    x5 = pyro.sample('X5', dist.Uniform(-2, 2))
    x6 = pyro.sample('X6', dist.Uniform(-2, 2))
    # Scale factor with an informative normal prior.
    s = pyro.sample('S', dist.Normal(19.5, .5))
    t = forward(x1, x2, x3, x4, x5, x6, s)
    # Gaussian likelihood (sd = 1.5) around the forward-model output.
    pyro.sample('obs', dist.Normal(t, 3 / 2), obs=obs)

if __name__ == '__main__':
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=50, warmup_steps=5)
    # Condition on the training observations and draw posterior samples.
    mcmc.run(torch.tensor([[19, 19, 19], [20, 20, 20], [19, 19, 19]]).to(torch.float32))
    sample_list = mcmc.get_samples()

    # NOTE(review): `Predictive` only *replays* the posterior samples obtained
    # above through the model — it does NOT infer X1..X6 for new observations.
    # To "invert" a new validation vector you must run inference again,
    # conditioned on the new data:
    #     mcmc.run(new_obs)                 # re-run MCMC on the new data
    #     new_latents = mcmc.get_samples()  # posterior over X1..X6, S
    # Predictive is for posterior-predictive simulation, not inversion.
    from pyro.infer import Predictive
    new_obs = torch.tensor([[21, 21, 21], [21, 21, 21], [21, 21, 21]]).to(torch.float32)
    cc = Predictive(model, sample_list)(new_obs)

###_____________###