Hello,
I have a question about using “jit_compile = True” with the NUTS kernel and a NumPy model.
My NumPy model calculates the chlorine-36 concentration for a seismic scenario whose parameters are sampled by Pyro. Basically, I detach the tensors containing the samples and cast them to NumPy arrays; then I turn the NumPy array containing the resulting chlorine-36 concentration back into a torch tensor.
For now my code works, but it is very slow (it takes days to weeks to complete). I noticed that the jit argument speeds up the inference but returns errors and wrong results. Are there settings that would let jit_compile work with this setup?
I hope I am being clear, thanks in advance.
I’m sorry, but I don’t follow. You can’t arbitrarily mix Pyro, NumPy, and PyTorch, detach tensors, and expect things to work, so you’ll need to be much more specific.
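To illustrate the problem this reply is pointing at: NUTS needs gradients of the log-density, and a round-trip through NumPy severs the autograd graph, so the sampler sees no gradient through the detached section. A minimal sketch with a toy computation (not the actual forward model):

```python
import numpy as np
import torch

x = torch.tensor(2.0, requires_grad=True)

# Round-tripping through NumPy leaves the autograd graph entirely:
y_np = np.square(x.detach().numpy())          # plain NumPy, no graph recorded
y = torch.as_tensor(y_np, dtype=torch.float32)

print(y.requires_grad)  # False: y has no path back to x, so no gradient flows
```

Any sampled parameter that only enters the likelihood through such a detached section contributes zero gradient, which is exactly the situation NUTS cannot handle correctly.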
Ok, here’s the main code:
import forward_function as forward
import geometric_scaling_factors
from constants import constants
import torch
import pyro
import numpy as np
import pyro.distributions as dist
import post_process as fig
from seismic_scenario import seismic_scenario as true_scenario
from pyro.infer import MCMC, NUTS
import parameters
""" Input seismic scenario """
seismic_scenario={}
erosion_rate = 0 # Erosion rate (mm/yr)
number_of_events = 3
seismic_scenario['erosion_rate'] = erosion_rate
""" Input parameters"""
param=parameters.param()
cl36AMS = param.cl36AMS
height = param.h
Hfinal = param.Hfinal
Data = torch.tensor(cl36AMS)
""" Geometric scaling """
scaling_depth_rock, scaling_depth_coll, scaling_surf_rock, scaling_factors = geometric_scaling_factors.neutron_scaling(param, constants, number_of_events+1)
""" MCMC parameters """
pyro.set_rng_seed(50)
w_step = 50 # number of warmup steps
nb_sample = 100 # number of samples
""" MCMC model """
def model(obs):
    ages = torch.zeros(number_of_events)
    ages[0] = pyro.sample('age1', dist.Uniform(2.0, 30*1e3))
    for i in range(1, number_of_events):
        max_age = ages[i-1]
        ages[i] = pyro.sample('age'+str(i+1), dist.Uniform(2.0, max_age))
    slips = torch.zeros(number_of_events)
    slips[0] = pyro.sample('slip1', dist.Uniform(0.0, Hfinal-100))
    for i in range(1, number_of_events-1):
        max_slip = Hfinal - torch.sum(slips[0:i])
        slips[i] = pyro.sample('slip'+str(i+1), dist.Uniform(0.0, max_slip))
    slips[number_of_events-1] = Hfinal - torch.sum(slips)
    seismic_scenario['ages'] = ages
    seismic_scenario['slips'] = slips
    seismic_scenario['SR'] = true_scenario['SR']
    seismic_scenario['preexp'] = true_scenario['preexp']
    seismic_scenario['quiescence'] = true_scenario['quiescence']
    sigma = pyro.sample('sigma', dist.Uniform(0, 10000))
    t = forward.mds_torch(seismic_scenario, scaling_factors, constants, parameters, 500, 200)
    return pyro.sample('obs', dist.Normal(t, sigma), obs=obs)
""" usage MCMC """
kernel = NUTS(model) # choose kernel (NUTS, HMC, ...)
mcmc = MCMC(kernel, warmup_steps=w_step, num_samples=nb_sample)
mcmc.run(obs=Data)
posterior_samples = mcmc.get_samples()
print('MCMC done \n')
Inside the mds_torch module I detach the tensors to perform the calculations with NumPy (which is much faster) and then return a torch tensor. My question is: is it possible to do the same thing with jit_compile=True in the kernel (maybe with a different method)?
I tested this on synthetic data without jit_compile: it works, but slowly, and I am looking for ways to speed up the process. I have also tried doing the computation in torch inside the mds_torch module, but it is even slower and, as of today, has never reached the end of the first iteration.
I hope it is clearer, thanks again.
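One pattern worth noting (a suggestion, not something from this thread): a NumPy computation can be kept inside the model without breaking NUTS by wrapping it in a torch.autograd.Function with a hand-written backward, so autograd still receives a gradient. Note that this still cannot be traced by jit_compile, since the NumPy section is opaque to torch.jit. A toy sketch with y = x**2:

```python
import numpy as np
import torch

class NumpySquare(torch.autograd.Function):
    """Toy example: forward computed in NumPy, gradient supplied by hand."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y_np = np.square(x.detach().cpu().numpy())  # NumPy does the work
        return torch.as_tensor(y_np, dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2.0 * x  # d(x**2)/dx, written in torch

x = torch.tensor(3.0, requires_grad=True)
y = NumpySquare.apply(x)
y.backward()
print(x.grad)  # tensor(6.)
```

For the real mds_torch forward model this would require deriving the gradient of the chlorine-36 concentration with respect to the ages and slips by hand, which may or may not be practical.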