Pyro Example: importance sampling port to NumPyro

Hello All,

I am trying to port the Pyro importance sampling example over to NumPyro for my own study, but I am having trouble getting the model to run. I don’t think this is so much a NumPyro problem as a JAX issue (I think?).

When I run the code, my computer just spins at 100% CPU on one core, and there is no progress bar, so I don’t think the model has started sampling. I have successfully run the importance sampling with Pyro and noticed the tutorial took a very long time to run, but that might be due to the CPU-only install of Pyro and PyTorch.

I feel the tutorial is a relatively easy inference problem, but the “simulation” part of the code below, named jax_simulate, is what is bogging down the model.

Here is the link to the code repo; the relevant code is below:

import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import matplotlib.pyplot as plt
import arviz as az

import numpyro
from numpyro.distributions import Normal, Uniform, LogNormal
from numpyro.infer import MCMC, NUTS

key = jax.random.PRNGKey(2)

little_g = 9.8  # m/s/s
mu0 = 0.12  # actual coefficient of friction in the experiment
time_measurement_sigma = 0.02  # descent-time measurement noise (value from the Pyro tutorial)
def _body(info):
    # one explicit-Euler step of the sliding box
    displacement, length, velocity, dt, acceleration, T = info
    displacement += velocity * dt
    velocity += acceleration * dt
    T += dt

    return displacement, length, velocity, dt, acceleration, T

def _conf(info):
    # loop condition: keep integrating while the box has not reached the end
    displacement, length, _, _, _, _ = info
    return displacement < length

def slide(displacement, length, velocity, dt, acceleration, T):
    # integrate until the box has slid the full length; res[-1] is the elapsed time T
    info = (displacement, length, velocity, dt, acceleration, T)
    res = jax.lax.while_loop(_conf, _body, info)
    return res[-1]

def jax_simulate(mu, key, noise_sigma, length=2.0, phi=jnp.pi / 6.0, dt=0.005):
    T = jnp.zeros(())
    velocity = jnp.zeros(())
    displacement = jnp.zeros(())
    acceleration = (little_g * jnp.sin(phi)) - (little_g * jnp.cos(phi)) * mu
    
    T = slide(displacement, length, velocity, dt, acceleration, T)

    return T + noise_sigma * jax.random.normal(key, ())

N_obs = 20

keys = jax.random.split(key, N_obs)

observed_data = jnp.array([jax.jit(jax_simulate)(mu0, key, time_measurement_sigma) for key in keys])
observed_mean = jnp.mean(observed_data)

def numpyro_model(observed_data, measurement_sigma, length=2.0, phi=jnp.pi / 6.0, dt=0.005):
    mu = numpyro.sample("mu", LogNormal(0, 0.5))
    
    with numpyro.plate("data_loop", len(observed_data)):
        T = jnp.zeros(())
        velocity = jnp.zeros(())
        displacement = jnp.zeros(())
        acceleration = (little_g * jnp.sin(phi)) - (little_g * jnp.cos(phi)) * mu
        T_simulated = slide(displacement, length, velocity, dt, acceleration, T)
        numpyro.sample("obs", Normal(T_simulated, measurment_sigma), obs=observed_data)
        
    return mu


numpyro.render_model(numpyro_model, model_args=(observed_data, time_measurement_sigma), render_distributions=True)

nuts_kernel = NUTS(numpyro_model)

mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=10, num_chains=1, chain_method='parallel', progress_bar=True)

mcmc.run(key, observed_data, time_measurement_sigma, extra_fields=('potential_energy',))

I think that your while_loop is not terminating. You might want to use a custom while_loop to debug:

from numpyro.util import while_loop, optional, control_flow_prims_disabled

with optional(False, control_flow_prims_disabled()):
    # change False -> True to turn while_loop into a Python loop
    res = while_loop(_conf, _body, info)

In addition, you’ll need to enable forward_mode_differentiation because your model has a while_loop (JAX does not support reverse-mode differentiation through lax.while_loop).

Ah, yes, I will iterate and refactor. I am still new to JAX, so I still need to wrap my head around the functional style.

The while loop doesn’t look infinite to me =/; I could be overlooking something obvious since it’s late.

When I add forward_mode_differentiation, i.e. nuts_kernel = NUTS(numpyro_model, forward_mode_differentiation=True), it still results in the same scenario: 100% CPU usage and no output from mcmc.run().

Below is the while-loop debug snippet:

import jax.numpy as jnp
from numpyro.util import while_loop, optional, control_flow_prims_disabled

little_g = 9.8  # m/s/s
mu0 = 0.12  # actual coefficient of friction in the experiment

def _body(info):
    displacement, length, velocity, dt, acceleration, T = info
    displacement += velocity * dt
    velocity += acceleration * dt
    T += dt
    print("_body | displacement = %s | length = %s | velocity = %s | acceleration = %s" %(displacement, length, velocity, acceleration))
    
    return displacement, length, velocity, dt, acceleration, T

def _conf(info):
    displacement, length, velocity, dt, acceleration, T = info
    print("_conf | displacement = %s | length = %s | velocity = %s | acceleration = %s" %(displacement, length, velocity, acceleration))
    return displacement < length

with optional(True, control_flow_prims_disabled()):
    phi=jnp.pi / 6.0
    length = 2.0
    dt = 0.5
    T = jnp.zeros(())
    velocity = jnp.zeros(())
    displacement = jnp.zeros(())
    acceleration = (little_g * jnp.sin(phi)) - (little_g * jnp.cos(phi)) * mu0
    info = (displacement, length, velocity, dt, acceleration, T)
    res = while_loop(_conf, _body, info)
    print(res)
    print(len(res))

_conf | displacement = 0.0 | length = 2.0 | velocity = 0.0 | acceleration = 3.8815541
_body | displacement = 0.0 | length = 2.0 | velocity = 1.9407771 | acceleration = 3.8815541
_conf | displacement = 0.0 | length = 2.0 | velocity = 1.9407771 | acceleration = 3.8815541
_body | displacement = 0.97038853 | length = 2.0 | velocity = 3.8815541 | acceleration = 3.8815541
_conf | displacement = 0.97038853 | length = 2.0 | velocity = 3.8815541 | acceleration = 3.8815541
_body | displacement = 2.9111657 | length = 2.0 | velocity = 5.8223314 | acceleration = 3.8815541
_conf | displacement = 2.9111657 | length = 2.0 | velocity = 5.8223314 | acceleration = 3.8815541
(DeviceArray(2.9111657, dtype=float32), 2.0, DeviceArray(5.8223314, dtype=float32), 0.5, DeviceArray(3.8815541, dtype=float32), DeviceArray(1.5, dtype=float32))
6

I had a little more time to debug, and there was indeed an infinite while loop: I was missing the case where the acceleration is negative, i.e. the box never moves because the friction is too large … :grimacing:

I fixed this issue using jax.lax.cond to return 1.0e5 when the box never slides, as in the original Pyro example. You can see the port in the notebook found in the repo.
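For reference, roughly what the fix looks like (a sketch reusing the names from the snippets above; the notebook has the exact code):

def simulate(mu, key, noise_sigma, length=2.0, phi=jnp.pi / 6.0, dt=0.005):
    acceleration = (little_g * jnp.sin(phi)) - (little_g * jnp.cos(phi)) * mu
    T = jnp.zeros(())
    velocity = jnp.zeros(())
    displacement = jnp.zeros(())
    # if the net acceleration is non-positive the box never moves, so return
    # the 1.0e5 sentinel from the Pyro example instead of looping forever
    T = jax.lax.cond(
        acceleration <= 0.0,
        lambda _: jnp.full((), 1.0e5),
        lambda _: slide(displacement, length, velocity, dt, acceleration, T),
        None,
    )
    return T + noise_sigma * jax.random.normal(key, ())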

I did notice a possible concern with a very low n_eff, but the r_hat was 1.00 for both priors, and I am not sure what’s happening. The trace looks okay, I think? I am still new to debugging HMC.

Yeah, glad that you found the issue. For the low n_eff, I think you can increase the max_tree_depth parameter of NUTS, but I guess it is fine. You might want to set num_chains to 4 to be more confident that your chains converge to the true posterior. The trace shows that the samples are somewhat correlated, which explains why n_eff is ~100. I guess setting max_tree_depth to 13 or 15 can remedy that issue.
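Something like this (max_tree_depth and num_chains are regular NUTS/MCMC arguments; the values are just the ones suggested above):

nuts_kernel = NUTS(numpyro_model, forward_mode_differentiation=True, max_tree_depth=13)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=4)
mcmc.run(key, observed_data, time_measurement_sigma)
mcmc.print_summary()  # n_eff and r_hat are computed over all chains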

i don’t think the model log density is differentiable. it probably makes more sense to use SA: Markov Chain Monte Carlo (MCMC) — NumPyro documentation

Using the Uniform(0., 1.) prior with {num_warmup: 500, num_samples: 1000, num_chains: 4}:

max_tree_depth    n_eff     r_hat
10                547.18    1.00
13                406.83    1.01
15                462.92    1.01

Using the Gamma(2., 2.) prior with {num_warmup: 500, num_samples: 1000, num_chains: 4}:

max_tree_depth    n_eff     r_hat
10                2.63      4.39
13                354.46    1.01
15                521.95    1.01

When you include more chains, do n_eff and r_hat summarize all the chains? The n_eff even for max_tree_depth = 15 feels low, but I don’t have a good sense right now. I am always tempted to convert n_eff into a percentage of the total num_samples, but I think this is wrong, because n_eff can be greater than num_samples.
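One way to check how the chains are pooled would be ArviZ (imported at the top of my script); for example:

idata = az.from_numpyro(mcmc)
# ess_bulk and r_hat are computed from all chains jointly, which is also why
# n_eff can come out larger than the per-chain num_samples
print(az.summary(idata, var_names=["mu"]))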

The chains using NUTS for both priors with the largest max_tree_depth seem good:

Params - {prior: Uniform(0, 1), kernel: NUTS, max_tree_depth: 15, num_warmup: 500, num_samples: 1000, num_chains: 4}

Params - {prior: Gamma(2, 2), kernel: NUTS, max_tree_depth: 15, num_warmup: 500, num_samples: 1000, num_chains: 4}

I tried out the SA kernel, and I get some odd jumps between chains. The prior choice is interesting: I thought Uniform(0., 1.) was entirely uninformative, but using this prior for the coefficient of friction seems to actually constrain the value more than my Gamma prior. This is obvious in hindsight, haha. The posterior mu using either prior is still way off using SA.

Params - {prior: Uniform(0, 1), kernel: SA, num_warmup: 2000, num_samples: 4000, num_chains: 10}

the coefficient of friction inferred by pyro is 0.129 +- 0.010
the coefficient of friction inferred by pyro is 0.799 +- 0.000
the coefficient of friction inferred by pyro is 0.951 +- 0.000
the coefficient of friction inferred by pyro is 0.126 +- 0.004
the coefficient of friction inferred by pyro is 0.127 +- 0.004
the coefficient of friction inferred by pyro is 0.273 +- 0.022
the coefficient of friction inferred by pyro is 0.125 +- 0.004
the coefficient of friction inferred by pyro is 0.126 +- 0.004
the coefficient of friction inferred by pyro is 0.543 +- 0.000
the coefficient of friction inferred by pyro is 0.398 +- 0.004

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      0.36      0.29      0.20      0.12      0.95      5.01     54.65

Number of divergences: 5

Params - {prior: Gamma(2, 2), kernel: SA, num_warmup: 2000, num_samples: 4000, num_chains: 10}

the coefficient of friction inferred by pyro is 0.181 +- 0.031
the coefficient of friction inferred by pyro is 3.983 +- 0.000
the coefficient of friction inferred by pyro is 19.347 +- 0.000
the coefficient of friction inferred by pyro is 0.126 +- 0.004
the coefficient of friction inferred by pyro is 0.127 +- 0.004
the coefficient of friction inferred by pyro is 1.114 +- 0.000
the coefficient of friction inferred by pyro is 0.126 +- 0.004
the coefficient of friction inferred by pyro is 0.126 +- 0.004
the coefficient of friction inferred by pyro is 1.747 +- 0.000
the coefficient of friction inferred by pyro is 0.995 +- 0.000

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      2.79      5.64      0.62      0.12     19.35      5.00    882.02

Number of divergences: 0

@bdatko SA generally requires a lot of samples, e.g. 500k. if you use progress_bar=False it should nevertheless be fast
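e.g. something like this (illustrative numbers; SA is numpyro.infer.SA):

from numpyro.infer import SA

sa_kernel = SA(numpyro_model)
# progress_bar=False lets numpyro run the whole chain inside a compiled loop,
# which is much faster for large sample counts
mcmc = MCMC(sa_kernel, num_warmup=10_000, num_samples=500_000, num_chains=4,
            progress_bar=False)
mcmc.run(key, observed_data, time_measurement_sigma)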

incidentally for what it’s worth the more elegant way to do this would be to use odeint:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.odeint.html

the pyro code basically implements a rudimentary ode solver

  • If odeint is used, then there is no need to set forward_mode_differentiation=True, because grad is implemented for odeint in JAX. See the predator-prey example.
  • If cond and while_loop are used, then forward_mode_differentiation=True is needed. Be aware of the discontinuity that happens at acceleration=0.
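A rough sketch of the odeint route (illustrative only: the time grid, its length, and the jnp.interp read-off are my assumptions, not code from the thread):

from jax.experimental.ode import odeint

def simulate_ode(mu, length=2.0, phi=jnp.pi / 6.0):
    acceleration = little_g * jnp.sin(phi) - little_g * jnp.cos(phi) * mu

    def dynamics(state, t, acceleration):
        displacement, velocity = state
        return jnp.stack([velocity, acceleration])

    ts = jnp.linspace(0.0, 10.0, 2001)  # assumed long enough for the box to reach the end
    states = odeint(dynamics, jnp.zeros(2), ts, acceleration)
    # linearly interpolate the first time displacement reaches `length`;
    # if the box never moves, jnp.interp clamps to ts[-1]
    return jnp.interp(length, states[:, 0], ts)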