I currently have a very complex model, which needs to input 6 variables to generate a vector. I tried to use pyro to solve him. However, the problem I have now is that when I train the mcmc model and try to predict new data, I cannot get the inversion results (i.e., the values of six variables in the new validation set). Do you know how to solve this problem?
###_____________###
import numpy as onp
from scipy.optimize import minimize
from scipy.stats import gaussian_kde
import jax.numpy as np
from jax import random, lax
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import torch
NUM_WARMUP, NUM_SAMPLES = 1000, 5
class Laplace(dist.Distribution):
arg_constraints = {‘loc’: dist.constraints.real, ‘scale’: dist.constraints.positive}
support = dist.constraints.real
reparametrized_params = [‘loc’, ‘scale’]
def __init__(self, loc=0., scale=1., validate_args=None):
self.loc, self.scale = dist.util.promote_shapes(loc, scale)
batch_shape = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
super().__init__(batch_shape=batch_shape, validate_args=validate_args)
def sample(self, key, sample_shape=()):
eps = random.laplace(key, shape=sample_shape + self.batch_shape + self.event_shape)
return self.loc + eps * self.scale
def log_prob(self, value):
normalize_term = np.log(1/(2*self.scale))
value_scaled = np.abs(value - self.loc) / self.scale
return -1*value_scaled + normalize_term
@property
def mean(self):
return np.broadcast_to(self.loc, self.batch_shape)
@property
def variance(self):
return np.broadcast_to(2 * self.scale ** 2, self.batch_shape)
def forward(x1, x2,x3, x4,x5, x6, s):
results = s * torch.sqrt(x12 + x22+ x32+ x42+ x52+ x62)
return torch.tensor([results,results,results])
def model(obs): ## just an example
x1 = pyro.sample(‘X1’, dist.Uniform(-2, 2))
x2 = pyro.sample(‘X2’, dist.Uniform(-2, 2))
x3 = pyro.sample(‘X3’, dist.Uniform(-2, 2))
x4 = pyro.sample(‘X4’, dist.Uniform(-2, 2))
x5 = pyro.sample(‘X5’, dist.Uniform(-2, 2))
x6 = pyro.sample(‘X6’, dist.Uniform(-2, 2))
s = pyro.sample(‘S’, dist.Normal(19.5, .5))
t = forward(x1, x2,x3, x4,x5, x6, s)
pyro.sample(‘obs’, dist.Normal(t, 3/2), obs=obs)
if name == ‘main’:
kernel = NUTS(model)
mcmc = MCMC(kernel,num_samples=50, warmup_steps=5)
mcmc.run(torch.tensor([[19,19,19],[20,20,20],[19,19,19]]).to(torch.float32))
from pyro.infer import MCMC, NUTS, Predictive
sample_list = mcmc.get_samples()
cc = Predictive(model, mcmc.get_samples())(torch.tensor([[21,21,21],[21,21,21],[21,21,21]]).to(torch.float32))
###_____________###