Using jit_compile with a NumPy model in Pyro

Hello,
I have a question about using “jit_compile=True” in the NUTS kernel with a NumPy model.
My NumPy model computes the chlorine-36 concentration for a seismic scenario sampled by Pyro. Basically, I detach the tensors containing the samples and convert them to NumPy arrays. Then I turn the NumPy array containing the resulting chlorine-36 concentrations back into a torch tensor.
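To make the effect of that round trip concrete, here is a minimal sketch with a hypothetical stand-in function (numpy_forward, using an arbitrary exponential instead of the chlorine-36 model). The tensor that comes back from NumPy has no autograd history, so NUTS, which relies on the gradient of the log density with respect to the sampled ages and slips, gets no gradient contribution from the likelihood through that path.

```python
import numpy as np
import torch


def numpy_forward(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the NumPy forward model: detach, compute in NumPy, re-wrap
    y = np.exp(-x.detach().numpy() / 1000.0)
    return torch.tensor(y)


age = torch.tensor(5000.0, requires_grad=True)

y_numpy = numpy_forward(age)
print(y_numpy.requires_grad)  # False: the autograd graph is cut at .detach()

# The same expression written with torch operations keeps the graph
y_torch = torch.exp(-age / 1000.0)
y_torch.backward()
print(age.grad)  # d/d(age) of exp(-age / 1000), i.e. -exp(-5) / 1000
```

Both versions compute the same number; only the torch version lets a gradient flow back to age.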
For now, my code works, but it is very slow (a run takes days to weeks to complete). I noticed that setting jit_compile=True speeds up inference, but it produces errors and wrong results. Is there a way to configure jit_compile so that it works with this setup?
I hope this is clear. Thanks in advance.
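A likely reason for the wrong results, based on my understanding of Pyro rather than a test of this model: with jit_compile=True, Pyro traces the potential energy function with torch.jit.trace, and the tracer cannot see inside NumPy, so whatever NumPy computed during tracing is stored as a constant and reused for every later parameter value. The TracerWarnings that torch emits while tracing (which NUTS's ignore_jit_warnings=True would silence) point at exactly this. The sketch below reproduces the effect outside Pyro with a hypothetical neg_log_lik function; check_trace=False just skips torch.jit.trace's built-in re-tracing check.

```python
import warnings

import numpy as np
import torch


def neg_log_lik(theta: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    # Hypothetical potential: detached NumPy forward model plus a torch Gaussian misfit
    pred = torch.tensor(np.exp(-theta.detach().numpy() / 1000.0))
    return 0.5 * ((obs - pred) ** 2).sum()


theta_trace = torch.tensor([5000.0, 8000.0])
obs = torch.tensor([0.01, 0.0003])

with warnings.catch_warnings():
    # torch emits TracerWarnings here: the NumPy result becomes a constant in the trace
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(neg_log_lik, (theta_trace, obs), check_trace=False)

theta_new = torch.tensor([20000.0, 1000.0])
print(neg_log_lik(theta_new, obs))  # eager: prediction recomputed for theta_new
print(traced(theta_new, obs))       # traced: still uses the prediction made for theta_trace
```

Only the torch part of the computation (obs - pred, the square and the sum) is recorded, so the traced function ignores changes in theta entirely.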

I’m sorry, but I don’t follow. You can’t arbitrarily mix Pyro, NumPy, PyTorch and .detach() and expect things to work, so you’ll need to be much more specific.

OK, here’s the main code:

```python
import forward_function as forward
import geometric_scaling_factors
from constants import constants
import torch
import pyro
import numpy as np
import pyro.distributions as dist
import post_process as fig
from seismic_scenario import seismic_scenario as true_scenario
from pyro.infer import MCMC, NUTS
import parameters


# Input seismic scenario
seismic_scenario = {}
erosion_rate = 0  # erosion rate (mm/yr)
number_of_events = 3
seismic_scenario['erosion_rate'] = erosion_rate

# Input parameters
param = parameters.param()
cl36AMS = param.cl36AMS
height = param.h
Hfinal = param.Hfinal
Data = torch.tensor(cl36AMS)

# Geometric scaling
scaling_depth_rock, scaling_depth_coll, scaling_surf_rock, scaling_factors = geometric_scaling_factors.neutron_scaling(param, constants, number_of_events + 1)

# MCMC parameters
pyro.set_rng_seed(50)
w_step = 50      # number of warmup steps
nb_sample = 100  # number of samples


# MCMC model
def model(obs):
    # Event ages: age1 in (2, 30e3), each later event younger than the previous one
    ages = torch.zeros(number_of_events)
    ages[0] = pyro.sample('age1', dist.Uniform(2.0, 30 * 1e3))
    for i in range(1, number_of_events):
        max_age = ages[i - 1]
        ages[i] = pyro.sample(f'age{i + 1}', dist.Uniform(2.0, max_age))
    # print('\n age', ages)

    # Slips: the last slip takes the remaining height so that the slips sum to Hfinal
    slips = torch.zeros(number_of_events)
    slips[0] = pyro.sample('slip1', dist.Uniform(0.0, Hfinal - 100))
    for i in range(1, number_of_events - 1):
        max_slip = Hfinal - torch.sum(slips[0:i])
        slips[i] = pyro.sample(f'slip{i + 1}', dist.Uniform(0.0, max_slip))
    slips[number_of_events - 1] = Hfinal - torch.sum(slips)

    seismic_scenario['ages'] = ages
    seismic_scenario['slips'] = slips
    seismic_scenario['SR'] = true_scenario['SR']
    seismic_scenario['preexp'] = true_scenario['preexp']
    seismic_scenario['quiescence'] = true_scenario['quiescence']
    sigma = pyro.sample('sigma', dist.Uniform(0, 10000))

    # Forward model: predicted chlorine-36 concentrations
    t = forward.mds_torch(seismic_scenario, scaling_factors, constants, parameters, 500, 200)
    return pyro.sample('obs', dist.Normal(t, sigma), obs=obs)


# Run MCMC
kernel = NUTS(model)  # choose kernel (NUTS, HMC, ...)
mcmc = MCMC(kernel, warmup_steps=w_step, num_samples=nb_sample)
mcmc.run(obs=Data)
posterior_samples = mcmc.get_samples()
print('MCMC done \n')
```

Inside the mds_torch function, I detach the tensors to perform the calculations with NumPy (which is much faster) and then return a torch tensor. My question is: is it possible to do the same thing with jit_compile=True in the kernel (maybe with a different approach)?
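One way to keep the NumPy code and still give NUTS a usable gradient (a general PyTorch technique, not a Pyro feature) is to wrap the NumPy call in a torch.autograd.Function whose backward computes the vector-Jacobian product by central finite differences. In this sketch, numpy_model is a hypothetical stand-in for the chlorine-36 forward model (1-D parameter array in, 1-D array out), and the relative step size of 1e-4 is an arbitrary choice. Each backward pass costs 2n extra forward calls for n parameters (five in the model above: three ages and two free slips), and NUTS needs a gradient at every leapfrog step. I have not checked how this interacts with jit_compile=True.

```python
import numpy as np
import torch


def numpy_model(theta: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for the NumPy forward model
    return np.array([np.sum(np.exp(-theta / 1000.0)), theta[0] * theta[1]])


class NumpyForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, theta):
        # Run the NumPy model on a float64 copy of the parameters
        theta_np = theta.detach().cpu().numpy().astype(np.float64)
        ctx.save_for_backward(theta)
        return torch.as_tensor(numpy_model(theta_np), dtype=theta.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        (theta,) = ctx.saved_tensors
        theta_np = theta.detach().cpu().numpy().astype(np.float64)
        eps = 1e-4 * np.maximum(1.0, np.abs(theta_np))  # relative step per parameter
        columns = []
        for j in range(theta_np.size):
            step = np.zeros_like(theta_np)
            step[j] = eps[j]
            # Central difference: column j of the Jacobian
            diff = numpy_model(theta_np + step) - numpy_model(theta_np - step)
            columns.append(diff / (2.0 * eps[j]))
        jac = np.stack(columns, axis=1)  # shape (n_outputs, n_params)
        jac_t = torch.as_tensor(jac, dtype=grad_out.dtype, device=grad_out.device)
        return jac_t.T @ grad_out  # vector-Jacobian product, shape (n_params,)


theta = torch.tensor([500.0, 300.0], dtype=torch.float64, requires_grad=True)
y = NumpyForward.apply(theta)
y.sum().backward()
print(theta.grad)  # approximately [300 - exp(-0.5)/1000, 500 - exp(-0.3)/1000]
```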

I tested this with synthetic data without jit_compile: it works, but slowly, so I am looking for ways to speed it up.
I have also tried doing the calculations in torch inside mds_torch, but that is even slower and, as of today, has never reached the end of the first iteration.
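Since mds_torch is not shown, this is only a guess: a frequent reason a pure-torch port runs slower than NumPy is Python-level loops that index a tensor one element at a time, which launch many tiny operations (and, when the inputs require grad, add one autograd node per element). The sketch compares both styles on a simple exponential attenuation profile; the function names, the 160 g/cm² attenuation length and the 2.7 g/cm³ density are illustrative values, not taken from the post.

```python
import time

import torch


def profile_loop(depths, attenuation_length=160.0, density=2.7):
    # One Python iteration per element: many tiny torch ops
    out = torch.zeros_like(depths)
    for k in range(depths.shape[0]):
        out[k] = torch.exp(-depths[k] * density / attenuation_length)
    return out


def profile_vectorized(depths, attenuation_length=160.0, density=2.7):
    # The same formula applied to the whole tensor in a single expression
    return torch.exp(-depths * density / attenuation_length)


depths = torch.linspace(0.0, 1000.0, 20000)  # illustrative depths (cm)

start = time.perf_counter()
slow = profile_loop(depths)
t_loop = time.perf_counter() - start

start = time.perf_counter()
fast = profile_vectorized(depths)
t_vec = time.perf_counter() - start

print(f"loop: {t_loop:.4f} s, vectorized: {t_vec:.6f} s, "
      f"max diff: {(slow - fast).abs().max().item():.2e}")
```

Both functions return the same values; the difference is only in how many separate operations torch has to dispatch.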

I hope this is clearer. Thanks again.