 # Pyro Example: importance sampling port to NumPyro

Hello All,

I am trying to port the Pyro importance sampling example to NumPyro for my own study, but I am having trouble getting the model to run. I don’t think this is a NumPyro difficulty so much as a JAX issue (I think?).

When I run the code, my computer just spins at 100% CPU on one core, yet no progress bar appears, so I don’t think the model has started sampling. I have successfully run the importance sampling example with Pyro and noticed that the tutorial took a very long time to run, but that might be due to my CPU-only install of Pyro and PyTorch.

I feel as if the tutorial is a relatively easy inference problem, but the “simulation” part of the code below, named `jax_simulate`, is what is bogging down the model.

Here is the link for the code repo, but below is the relevant code:

```python
import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import matplotlib.pyplot as plt
import arviz as az

import numpyro
from numpyro.distributions import Normal, Uniform, LogNormal
from numpyro.infer import MCMC, NUTS

key = jax.random.PRNGKey(2)


def _body(info):
    displacement, length, velocity, dt, acceleration, T = info
    displacement += velocity * dt
    velocity += acceleration * dt
    T += dt

    return displacement, length, velocity, dt, acceleration, T


def _conf(info):
    displacement, length, _, _, _, _ = info
    return displacement < length


def slide(displacement, length, velocity, dt, acceleration, T):
    info = (displacement, length, velocity, dt, acceleration, T)
    res = jax.lax.while_loop(_conf, _body, info)
    return res[-1]


def jax_simulate(mu, key, noise_sigma, length=2.0, phi=jnp.pi / 6.0, dt=0.005):
    T = jnp.zeros(())
    velocity = jnp.zeros(())
    displacement = jnp.zeros(())
    acceleration = (little_g * jnp.sin(phi)) - (little_g * jnp.cos(phi)) * mu

    T = slide(displacement, length, velocity, dt, acceleration, T)

    return T + noise_sigma * jax.random.normal(key, ())


N_obs = 20

keys = jax.random.split(key, N_obs)

observed_data = jnp.array([jax.jit(jax_simulate)(mu0, key, time_measurement_sigma) for key in keys])
observed_mean = jnp.mean(observed_data)


def numpyro_model(observed_data, measurment_sigma, length=2.0, phi=jnp.pi / 6.0, dt=0.005):
    mu = numpyro.sample("mu", LogNormal(0, 0.5))

    with numpyro.plate("data_loop", len(observed_data)):
        T = jnp.zeros(())
        velocity = jnp.zeros(())
        displacement = jnp.zeros(())
        acceleration = (little_g * jnp.sin(phi)) - (little_g * jnp.cos(phi)) * mu
        T_simulated = slide(displacement, length, velocity, dt, acceleration, T)
        numpyro.sample("obs", Normal(T_simulated, measurment_sigma), obs=observed_data)

    return mu


numpyro.render_model(numpyro_model, model_args=(observed_data, time_measurement_sigma), render_distributions=True)

nuts_kernel = NUTS(numpyro_model)

mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=10, num_chains=1, chain_method='parallel', progress_bar=True)

mcmc.run(key, observed_data, time_measurement_sigma, extra_fields=('potential_energy',))
```

I think that your `while_loop` does not terminate. You might want to swap in a custom `while_loop` to debug.

```python
from numpyro.util import while_loop, optional, control_flow_prims_disabled

with optional(False, control_flow_prims_disabled()):
    # change False -> True to turn while_loop into a Python loop
    res = while_loop(_conf, _body, info)
```

In addition, you’ll need to enable `forward_mode_differentiation` because your model has a `while_loop`.
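To illustrate the point about differentiation modes, here is a minimal sketch in plain JAX (not the model from the post). `final_displacement` is a hypothetical helper that runs the same kind of Euler loop and returns the final displacement. Forward mode (`jax.jacfwd`) goes through `jax.lax.while_loop`, while reverse mode (`jax.grad`) is expected to raise, which is why NUTS needs `forward_mode_differentiation=True` for a model like this.

```python
import jax
import jax.numpy as jnp


def final_displacement(acceleration, length=2.0, dt=0.005):
    """Hypothetical helper: Euler steps until the box passes `length`."""

    def cond_fn(state):
        displacement, _ = state
        return displacement < length

    def body_fn(state):
        displacement, velocity = state
        return displacement + velocity * dt, velocity + acceleration * dt

    init = (jnp.zeros(()), jnp.zeros(()))
    displacement, _ = jax.lax.while_loop(cond_fn, body_fn, init)
    return displacement


# Forward-mode differentiation works through while_loop.
forward_derivative = jax.jacfwd(final_displacement)(3.88)

# Reverse-mode differentiation is not supported for while_loop.
try:
    jax.grad(final_displacement)(3.88)
    reverse_mode_ok = True
except Exception as err:
    reverse_mode_ok = False
    print("reverse mode failed:", type(err).__name__)

print(forward_derivative, reverse_mode_ok)
```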

Ah, yes, I will iterate and refactor. I am still new to JAX, so I still need to wrap my head around the functional style.

The while loop doesn’t look infinite to me =/, but I could be overlooking something obvious since it’s late.

When I add `forward_mode_differentiation`, as in `nuts_kernel = NUTS(numpyro_model, forward_mode_differentiation=True)`, I still get the same scenario: 100% CPU usage and no output from `mcmc.run()`.

Below is the `while_loop` debug snippet.

```python
import jax.numpy as jnp
from numpyro.util import while_loop, optional, control_flow_prims_disabled

little_g = 9.8  # m/s/s
mu0 = 0.12  # actual coefficient of friction in the experiment


def _body(info):
    displacement, length, velocity, dt, acceleration, T = info
    displacement += velocity * dt
    velocity += acceleration * dt
    T += dt
    print("_body | displacement = %s | length = %s | velocity = %s | acceleration = %s" % (displacement, length, velocity, acceleration))

    return displacement, length, velocity, dt, acceleration, T


def _conf(info):
    displacement, length, velocity, dt, acceleration, T = info
    print("_conf | displacement = %s | length = %s | velocity = %s | acceleration = %s" % (displacement, length, velocity, acceleration))
    return displacement < length


with optional(True, control_flow_prims_disabled()):
    phi = jnp.pi / 6.0
    length = 2.0
    dt = 0.5
    T = jnp.zeros(())
    velocity = jnp.zeros(())
    displacement = jnp.zeros(())
    acceleration = (little_g * jnp.sin(phi)) - (little_g * jnp.cos(phi)) * mu0
    info = (displacement, length, velocity, dt, acceleration, T)
    res = while_loop(_conf, _body, info)

print(res)
print(len(res))

# Output:
_conf | displacement = 0.0 | length = 2.0 | velocity = 0.0 | acceleration = 3.8815541
_body | displacement = 0.0 | length = 2.0 | velocity = 1.9407771 | acceleration = 3.8815541
_conf | displacement = 0.0 | length = 2.0 | velocity = 1.9407771 | acceleration = 3.8815541
_body | displacement = 0.97038853 | length = 2.0 | velocity = 3.8815541 | acceleration = 3.8815541
_conf | displacement = 0.97038853 | length = 2.0 | velocity = 3.8815541 | acceleration = 3.8815541
_body | displacement = 2.9111657 | length = 2.0 | velocity = 5.8223314 | acceleration = 3.8815541
_conf | displacement = 2.9111657 | length = 2.0 | velocity = 5.8223314 | acceleration = 3.8815541
(DeviceArray(2.9111657, dtype=float32), 2.0, DeviceArray(5.8223314, dtype=float32), 0.5, DeviceArray(3.8815541, dtype=float32), DeviceArray(1.5, dtype=float32))
6
```

I had a little more time to debug, and there was indeed an infinite while loop: I had missed the case where the acceleration is negative, i.e. the box does not move because the friction is too large. I fixed this by using `jax.lax.cond` to return `1.0e5`, as in the original Pyro example. You can see the port in the notebook found in the repo.
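For readers who do not want to open the notebook, the guard described above could look roughly like the sketch below. This is not the code from the repo; `descent_time` and the simplified three-element loop state are assumptions made for illustration, and only the `jax.lax.cond` guard with the `1.0e5` sentinel comes from the description.

```python
import jax
import jax.numpy as jnp

little_g = 9.8  # m/s/s, as in the debug snippet


def slide(acceleration, length, dt):
    """Euler-integrate until the box has travelled `length`; return the elapsed time."""

    def cond_fn(state):
        displacement, _, _ = state
        return displacement < length

    def body_fn(state):
        displacement, velocity, T = state
        return displacement + velocity * dt, velocity + acceleration * dt, T + dt

    init = (jnp.zeros(()), jnp.zeros(()), jnp.zeros(()))
    return jax.lax.while_loop(cond_fn, body_fn, init)[-1]


def descent_time(mu, length=2.0, phi=jnp.pi / 6.0, dt=0.005):
    """Hypothetical wrapper: skip the loop when the box cannot move."""
    acceleration = little_g * jnp.sin(phi) - little_g * jnp.cos(phi) * mu
    return jax.lax.cond(
        acceleration <= 0,
        lambda a: jnp.full((), 1.0e5, dtype=a.dtype),  # box never moves: return a large sentinel time
        lambda a: slide(a, length, dt),  # box slides: run the Euler loop
        acceleration,
    )


print(descent_time(0.12))  # friction low enough for the box to slide
print(descent_time(1.0))   # friction too high, sentinel branch is taken
```

Only the selected branch is executed at run time, so the loop is never entered with a non-positive acceleration.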

I did notice a possible concern: `n_eff` is very low, yet `r_hat` was 1.00 for both priors, and I am not sure what’s happening. The trace looks okay, I think? I am still new to debugging HMC.

Yeah, glad that you found the issue. For the low `n_eff`, I think you can increase the `max_tree_depth` parameter of NUTS, but I guess it is fine as is. You might want to set `num_chains` to 4 to be more confident that your chains converge to the true posterior. The trace shows that the samples are somewhat correlated, which explains why `n_eff` is ~100. I guess setting `max_tree_depth` to 13 or 15 can remedy that issue.

i don’t think the model log density is differentiable. it probably makes more sense to use SA (see the MCMC page of the NumPyro documentation).
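One way to see the concern, as a sketch under the assumption that the simulated time is just a step count times `dt`: nearby values of `mu` give exactly the same time, and the forward-mode derivative of the time with respect to `mu` is zero. `descent_time` is a hypothetical helper that assumes `mu` is small enough for the box to slide.

```python
import jax
import jax.numpy as jnp

little_g = 9.8  # m/s/s


def descent_time(mu, length=2.0, phi=jnp.pi / 6.0, dt=0.005):
    """Hypothetical helper; assumes mu is small enough that the box slides."""
    acceleration = little_g * jnp.sin(phi) - little_g * jnp.cos(phi) * mu

    def cond_fn(state):
        return state[0] < length

    def body_fn(state):
        displacement, velocity, T = state
        return displacement + velocity * dt, velocity + acceleration * dt, T + dt

    init = (jnp.zeros(()), jnp.zeros(()), jnp.zeros(()))
    return jax.lax.while_loop(cond_fn, body_fn, init)[-1]


# Two nearby friction coefficients need the same number of Euler steps.
t_a = descent_time(0.1200)
t_b = descent_time(0.1205)

# Derivative of the simulated time with respect to mu (forward mode).
dT_dmu = jax.jacfwd(descent_time)(0.12)

print(t_a, t_b, dT_dmu)
```

A gradient that is zero almost everywhere gives a gradient-based sampler little to work with, which is presumably the motivation for suggesting a gradient-free kernel.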

Using the Uniform(0., 1.) prior with `{num_warmup: 500, num_samples: 1000, num_chains: 4}`:

| `max_tree_depth` | `n_eff` | `r_hat` |
|---|---|---|
| 10 | 547.18 | 1.00 |
| 13 | 406.83 | 1.01 |
| 15 | 462.92 | 1.01 |

Using the Gamma(2., 2.) prior with `{num_warmup: 500, num_samples: 1000, num_chains: 4}`:

| `max_tree_depth` | `n_eff` | `r_hat` |
|---|---|---|
| 10 | 2.63 | 4.39 |
| 13 | 354.46 | 1.01 |
| 15 | 521.95 | 1.01 |

When you include more chains, do `n_eff` and `r_hat` summarize all the chains? The `n_eff` even for `max_tree_depth = 15` feels low, but I don’t have a good sense of this right now. I am always tempted to convert `n_eff` into a percentage of the total `num_samples`, but I think this is wrong, because `n_eff` can be greater than `num_samples`.

The chains using NUTS for both priors with the largest `max_tree_depth` seem good:

Params - `{prior: Uniform(0, 1), kernel: NUTS, max_tree_depth: 15, num_warmup: 500, num_samples: 1000, num_chains: 4}`

Params - `{prior: Gamma(2, 2), kernel: NUTS, max_tree_depth: 15, num_warmup: 500, num_samples: 1000, num_chains: 4}`

I tried out the SA kernel, and I get some odd jumps between chains. The prior choice is interesting because I thought Uniform(0., 1.) is entirely uninformative, but using this prior for the coefficient of friction seems to actually constrain the value more than my Gamma prior. This is obvious in hindsight, haha. The posterior for `mu` is still way off with either prior when using SA.
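A bit of arithmetic may help compare the two priors. The computed acceleration `g * (sin(phi) - mu * cos(phi))` is non-positive once `mu >= tan(phi)`, which is the region where the simulator returns the constant `1.0e5`. The sketch below computes how much prior mass each prior puts there, assuming `Gamma(2., 2.)` means concentration 2 and rate 2 (NumPyro's parameterization) and `phi = pi / 6` as in the code above.

```python
import math

phi = math.pi / 6.0
mu_stuck = math.tan(phi)  # at or above this value the box does not move

# Uniform(0, 1): P(mu > mu_stuck)
p_uniform = 1.0 - mu_stuck

# Gamma(concentration=2, rate=2): survival function exp(-rate * x) * (1 + rate * x)
rate = 2.0
p_gamma = math.exp(-rate * mu_stuck) * (1.0 + rate * mu_stuck)

print(f"mu_stuck = {mu_stuck:.3f}")
print(f"P(mu > mu_stuck) under Uniform(0, 1): {p_uniform:.3f}")
print(f"P(mu > mu_stuck) under Gamma(2, 2):   {p_gamma:.3f}")
```

Under these assumptions both priors put a large share of their mass where the simulated time does not depend on `mu`, and the Gamma prior puts more there; that may be related to some chains sitting at a single large value, though this is only a guess.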

Params - `{prior: Uniform(0, 1), kernel: SA, num_warmup: 2000, num_samples: 4000, num_chains: 10}`

```
the coefficient of friction inferred by pyro is 0.129 +- 0.010
the coefficient of friction inferred by pyro is 0.799 +- 0.000
the coefficient of friction inferred by pyro is 0.951 +- 0.000
the coefficient of friction inferred by pyro is 0.126 +- 0.004
the coefficient of friction inferred by pyro is 0.127 +- 0.004
the coefficient of friction inferred by pyro is 0.273 +- 0.022
the coefficient of friction inferred by pyro is 0.125 +- 0.004
the coefficient of friction inferred by pyro is 0.126 +- 0.004
the coefficient of friction inferred by pyro is 0.543 +- 0.000
the coefficient of friction inferred by pyro is 0.398 +- 0.004

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      0.36      0.29      0.20      0.12      0.95      5.01     54.65

Number of divergences: 5
```

Params - `{prior: Gamma(2, 2), kernel: SA, num_warmup: 2000, num_samples: 4000, num_chains: 10}`

```
the coefficient of friction inferred by pyro is 0.181 +- 0.031
the coefficient of friction inferred by pyro is 3.983 +- 0.000
the coefficient of friction inferred by pyro is 19.347 +- 0.000
the coefficient of friction inferred by pyro is 0.126 +- 0.004
the coefficient of friction inferred by pyro is 0.127 +- 0.004
the coefficient of friction inferred by pyro is 1.114 +- 0.000
the coefficient of friction inferred by pyro is 0.126 +- 0.004
the coefficient of friction inferred by pyro is 0.126 +- 0.004
the coefficient of friction inferred by pyro is 1.747 +- 0.000
the coefficient of friction inferred by pyro is 0.995 +- 0.000

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      2.79      5.64      0.62      0.12     19.35      5.00    882.02

Number of divergences: 0
```

@bdatko SA generally requires a lot of samples, e.g. 500k. if you use `progress_bar=False` it should nevertheless be fast

incidentally for what it’s worth the more elegant way to do this would be to use odeint:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.odeint.html

the pyro code basically implements a rudimentary ode solver

• If `odeint` is used, then there is no need to set `forward_mode_differentiation=True`, because `grad` is implemented for `odeint` in JAX. See the predator-prey example.
• If `cond` and `while_loop` are used, then `forward_mode_differentiation=True` is needed. Be aware of the discontinuity that happens at `acceleration=0`.
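To make the first bullet concrete, here is a sketch that assumes the JAX `odeint` meant here is `jax.experimental.ode.odeint`. The helper names (`dynamics`, `displacement_at`), the fixed end time, and the solver tolerances are choices made for this illustration; the sketch integrates the box's motion and takes a reverse-mode gradient of the final displacement with respect to `mu`.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

little_g = 9.8  # m/s/s


def dynamics(state, t, acceleration):
    """d(displacement)/dt = velocity, d(velocity)/dt = acceleration."""
    velocity = state[1]
    return jnp.stack([velocity, acceleration])


def displacement_at(mu, t_end=1.0, phi=jnp.pi / 6.0):
    """Hypothetical helper: displacement of the box after `t_end` seconds."""
    acceleration = little_g * jnp.sin(phi) - little_g * jnp.cos(phi) * mu
    y0 = jnp.zeros(2)  # start at rest at the top of the incline
    ts = jnp.array([0.0, t_end])
    ys = odeint(dynamics, y0, ts, acceleration, rtol=1e-5, atol=1e-5, mxstep=1000)
    return ys[-1, 0]


mu0 = jnp.float32(0.12)
x_end = displacement_at(mu0)

# Reverse-mode gradient works here, unlike with lax.while_loop.
dx_dmu = jax.grad(displacement_at)(mu0)

print(x_end, dx_dmu)
```

For constant acceleration the displacement after one second is half the acceleration, so the gradient should be close to `-0.5 * little_g * cos(phi)`. Turning this into a descent time would still need a way to find when the displacement reaches `length`, which this sketch leaves out.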