Hello All,
I am trying to port the Pyro importance sampling over to NumPyro for my own study, but I am having a problem with getting the model to run. I don’t think this is so much a NumPyro difficulty but a Jax issue (I think?)
When I run the code, my computer just spins using 100% CPU for the core yet, there is no progress bar so I don’t think the model has started sampling. I have successfully run the importance sampling with Pyro and noticed the tutorial took a very long time to run, but that might be due to the CPU-only install of Pyro and PyTorch.
I feel as if the tutorial is a relatively easy inference problem, but the “simulation” part of the code below, named jax_simulate
is what is bogging down the model.
Here is the link for the code repo, but below is the relevant code:
import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import matplotlib.pyplot as plt
import arviz as az
import numpyro
from numpyro.distributions import Normal, Uniform, LogNormal
from numpyro.infer import MCMC, NUTS
key = jax.random.PRNGKey(2)
# +
def _body(info):
displacement, length, velocity, dt, acceleration, T = info
displacement += velocity * dt
velocity += acceleration * dt
T += dt
return displacement, length, velocity, dt, acceleration, T
def _conf(info):
displacement, length, _, _, _, _ = info
return displacement < length
def slide(displacement, length, velocity, dt, acceleration, T):
info = (displacement, length, velocity, dt, acceleration, T)
res = jax.lax.while_loop(_conf, _body, info)
return res[-1]
def jax_simulate(mu, key, noise_sigma, length=2.0, phi=jnp.pi / 6.0, dt=0.005):
T = jnp.zeros(())
velocity = jnp.zeros(())
displacement = jnp.zeros(())
acceleration = (little_g * jnp.sin(phi)) - (little_g * jnp.cos(phi)) * mu
T = slide(displacement, length, velocity, dt, acceleration, T)
return T + noise_sigma * jax.random.normal(key, ())
N_obs = 20
keys = jax.random.split(key, N_obs)
observed_data = jnp.array([jax.jit(jax_simulate)(mu0, key, time_measurement_sigma) for key in keys])
observed_mean = jnp.mean(observed_data)
def numpyro_model(observed_data, measurment_sigma, length=2.0, phi=jnp.pi / 6.0, dt=0.005):
mu = numpyro.sample("mu", LogNormal(0, 0.5))
with numpyro.plate("data_loop", len(observed_data)):
T = jnp.zeros(())
velocity = jnp.zeros(())
displacement = jnp.zeros(())
acceleration = (little_g * jnp.sin(phi)) - (little_g * jnp.cos(phi)) * mu
T_simulated = slide(displacement, length, velocity, dt, acceleration, T)
numpyro.sample("obs", Normal(T_simulated, measurment_sigma), obs=observed_data)
return mu
numpyro.render_model(numpyro_model, model_args=(observed_data,time_measurement_sigma), render_distributions=True)
nuts_kernel = NUTS(numpyro_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=10, num_chains=1, chain_method='parallel', progress_bar=True)
mcmc.run(key, observed_data, time_measurement_sigma, extra_fields=('potential_energy',))