I am trying to run a spectral modelling code under numpyro. All the input parameters for the model are stored in a dictionary, which is dumped to a YAML file that the modelling code then reads.
def model_spec(wave, flux):
    vel_start = numpyro.sample("vel_start", dist.Uniform(low=10000., high=20000.))
    vel_stop = numpyro.sample("vel_stop", dist.Uniform(low=15000., high=40000.))
    O = numpyro.sample("O", dist.Uniform(low=0., high=1.))

    # par_list is the parameter dictionary
    par_list['model']['structure']['velocity']['start'] = vel_start * u.km/u.s
    par_list['model']['structure']['velocity']['stop'] = vel_stop * u.km/u.s
    par_list['model']['abundances']['O'] = O

    with open('parameters_02.yml', 'w') as file2:
        documents = yaml.dump(par_list, file2)

    with numpyro.plate('data', len(wave)):
        # 'run' is the spectral modelling function I am importing at the beginning
        sim_01 = run('parameters_02.yml')
        mu = sim_01.flux
        numpyro.sample("obs", dist.Normal(mu), obs=flux)
When I run this code, I get some errors.
My primary concern, though, is whether dumping a tracer object to a file like this is permitted in numpyro at all, given that it runs on JAX.
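
To make the concern concrete, here is a minimal sketch that is independent of my modelling code. It assumes that during inference the sampled values reach the model as JAX tracers, much like the arguments of a function wrapped in `jax.jit`. The helper `dump_value` is hypothetical and only stands in for the YAML dump, since both need a concrete Python number.

```python
import jax
import jax.numpy as jnp


def dump_value(x):
    # Hypothetical stand-in for yaml.dump: it needs a concrete Python float.
    return "value: {}".format(float(x))


def traced_dump(x):
    # Under jax.jit, x is an abstract tracer rather than a concrete number.
    return dump_value(x)


# Outside of tracing the value is concrete, so the conversion works.
concrete_result = dump_value(jnp.asarray(1.5))

# Inside jit the same conversion is expected to raise a TypeError subclass,
# because a tracer has no concrete value to write out.
try:
    jax.jit(traced_dump)(jnp.asarray(1.5))
    traced_failed = False
except TypeError as err:
    traced_failed = True
    error_name = type(err).__name__

print(concrete_result, traced_failed)
```

If my model hits the same restriction, the file dump (and the external `run` call that reads it) would be the step that breaks, not the priors or the likelihood.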