NumPyro modelling that requires printing/dumping parameters to a file

I am trying to run spectral modelling code under NumPyro. All the input parameters for the model are stored in a dictionary, which is dumped to a YAML file that the modelling code then reads.

```python
import astropy.units as u
import numpyro
import numpyro.distributions as dist
import yaml

## par_list is the parameter dictionary; `run` is the spectral modelling
## function I am importing at the beginning (both defined elsewhere)

def model_spec(wave, flux):
    vel_start = numpyro.sample("vel_start", dist.Uniform(low=10000., high=20000.))
    vel_stop = numpyro.sample("vel_stop", dist.Uniform(low=15000., high=40000.))
    O = numpyro.sample("O", dist.Uniform(low=0., high=1.))

    # Write the sampled values into the parameter dictionary
    par_list['model']['structure']['velocity']['start'] = vel_start * u.km / u.s
    par_list['model']['structure']['velocity']['stop'] = vel_stop * u.km / u.s
    par_list['model']['abundances']['O'] = O

    # Dump the parameters to the YAML file read by the modelling code
    with open('parameters_02.yml', 'w') as file2:
        yaml.dump(par_list, file2)

    with numpyro.plate('data', len(wave)):
        # Run the external spectral model on the dumped parameter file
        sim_01 = run('parameters_02.yml')
        mu = sim_01.flux
        # Gaussian likelihood (NumPyro's Normal defaults to scale=1.0)
        numpyro.sample("obs", dist.Normal(mu), obs=flux)
```

When I run this code, I get some errors.

My primary concern, though, is whether this kind of operation, dumping a tracer object to a file, is permitted in NumPyro at all, given that it runs on JAX.

I’m not sure I understand the approach here. You might first test the dump logic with plain JAX to see whether it works under jax.jit. Typically, when the model is compiled, the random variables are tracers that do not have concrete values. If you want to execute some CPU code outside of jit, you can use host_callback, but I’m not sure it will work properly here, because we typically require gradients with respect to the random variables (which host_callback might not offer).
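To make the tracer point concrete, here is a minimal sketch in plain JAX (no NumPyro), assuming a recent JAX version. It uses jax.pure_callback, which JAX offers in place of the now-deprecated host_callback. The JSON file and the hypothetical write_and_run helper with its toy "model" (the width of the velocity range) are stand-ins for the YAML dump and the real run(...). It shows three things: a direct conversion of a traced value fails under jit, a callback receives concrete values and can do file I/O, and differentiating through that callback fails.

```python
import json
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

# 1) Inside jax.jit the argument is an abstract tracer, so converting it to a
#    Python float (which writing it to a file requires) fails at trace time.
def dump_directly(vel_start):
    return float(vel_start)

try:
    jax.jit(dump_directly)(jnp.float32(12000.0))
    direct_dump_failed = False
except TypeError:  # jax.errors.ConcretizationTypeError subclasses TypeError
    direct_dump_failed = True

# 2) jax.pure_callback passes concrete NumPy values to ordinary Python code,
#    so file I/O works there. JSON stands in for YAML; the "model" is a toy.
PARAM_FILE = os.path.join(tempfile.gettempdir(), "parameters_02.json")

def write_and_run(vel_start, vel_stop):
    params = {"velocity": {"start": float(vel_start), "stop": float(vel_stop)}}
    with open(PARAM_FILE, "w") as fh:
        json.dump(params, fh)
    with open(PARAM_FILE) as fh:  # a black-box model would read the file here
        p = json.load(fh)
    width = p["velocity"]["stop"] - p["velocity"]["start"]
    return np.asarray(width, dtype=np.float32)

@jax.jit
def simulate(vel_start, vel_stop):
    out_spec = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(write_and_run, out_spec, vel_start, vel_stop)

width = simulate(jnp.float32(12000.0), jnp.float32(30000.0))

# 3) pure_callback has no differentiation rule, so jax.grad through it fails.
try:
    jax.grad(simulate)(jnp.float32(12000.0), jnp.float32(30000.0))
    grad_failed = False
except ValueError:
    grad_failed = True
```

Point 3 is why gradient-based samplers such as NUTS cannot use such a callback unless a gradient is supplied separately.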

Thanks for replying. Basically, I am trying to do something like what is shown in these PyMC examples:
https://www.pymc.io/projects/examples/en/latest/case_studies/blackbox_external_likelihood_numpy.html#blackbox_external_likelihood_numpy
and
https://www.pymc.io/projects/examples/en/latest/case_studies/wrapping_jax_function.html

I’m not sure which part of those examples you mean. If you want to define a potential function with a custom gradient, you can use jax.custom_jvp.
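A minimal sketch of that suggestion, assuming the model can be called as a plain NumPy function. The hypothetical blackbox_loglike below (a toy quadratic) is wrapped with jax.pure_callback, and jax.custom_jvp supplies a gradient estimated by central finite differences, similar in spirit to the finite-difference gradient in the PyMC black-box example. Finite differences are a swapped-in choice, not something the black box provides; they cost two model evaluations per parameter for every gradient.

```python
import jax
import jax.numpy as jnp
import numpy as np

def blackbox_loglike(theta):
    # Hypothetical black box: plain NumPy, its gradient is unknown to JAX.
    theta = np.asarray(theta, dtype=np.float64)
    return -0.5 * np.sum((theta - 1.0) ** 2)

def _loglike_host(theta):
    return np.asarray(blackbox_loglike(theta), dtype=np.float32)

def _grad_host(theta, eps=1e-4):
    # Central finite differences: two black-box calls per parameter.
    theta = np.asarray(theta, dtype=np.float64)
    g = np.empty_like(theta)
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = eps
        g[i] = (blackbox_loglike(theta + step) - blackbox_loglike(theta - step)) / (2 * eps)
    return g.astype(np.float32)

@jax.custom_jvp
def loglike(theta):
    out_spec = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(_loglike_host, out_spec, theta)

@loglike.defjvp
def loglike_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    value = loglike(theta)
    g_est = jax.pure_callback(
        _grad_host, jax.ShapeDtypeStruct(theta.shape, jnp.float32), theta
    )
    # JVP of a scalar function: inner product of gradient and tangent
    return value, jnp.dot(g_est, theta_dot)

theta0 = jnp.array([0.5, 2.0])
value, g = jax.value_and_grad(loglike)(theta0)
```

With the gradient defined this way, loglike can be differentiated (and therefore used in a potential function for gradient-based samplers), although the gradient is only as accurate as the finite-difference step allows.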

Thanks for your suggestions!
I’m attempting Bayesian inference and parameter estimation.
My model is a black-box function that expects floating-point numbers as parameters rather than NumPyro distributions.
Also, because my model is essentially a black box, I have no idea what its gradients are.
So my question is whether it is possible to use NumPyro in such a situation.

The only sampler we have that does not require gradient computation is the Sample Adaptive (SA) sampler. I think you can use it with your black-box potential function (through JAX host_callback).
