Can't run code from introductory tutorial

Hi all,

I cannot seem to successfully run the introductory code from the "Getting Started with NumPyro" page of the NumPyro documentation. I'm sorry if there is a very basic solution to this, but could anyone help? Every time, I get the error: `TypeError: Argument 'None' of type '<class 'NoneType'>' is not a valid JAX type`

import numpy as np
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS

# Observed values (y) and their fixed observation scales (sigma), one entry
# per school; J is the number of schools and sizes the model's plate.
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
J = len(y)

def eight_schools(J, sigma, y=None):
    """Hierarchical "eight schools" model.

    A shared location ``mu ~ Normal(0, 5)`` and scale ``tau ~ HalfCauchy(5)``
    are drawn once. Inside a plate of size ``J``, each group gets its own
    ``theta ~ Normal(mu, tau)``, and the data site ``obs`` is modelled as
    ``Normal(theta, sigma)``.

    Args:
        J: Number of groups (plate size).
        sigma: Per-group observation scales.
        y: Per-group observations to condition on. When ``None``, the
            ``'obs'`` site is not conditioned on data.
    """
    group_loc = numpyro.sample('mu', dist.Normal(0, 5))
    group_scale = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        # One latent effect per group, all sharing the hyperpriors above.
        effect = numpyro.sample('theta', dist.Normal(group_loc, group_scale))
        numpyro.sample('obs', dist.Normal(effect, sigma), obs=y)

# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(0)

num_warmup, num_samples = 1000, 2000

# Run NUTS.
kernel = NUTS(eight_schools)
# Pass num_warmup/num_samples by keyword: in newer NumPyro releases only the
# sampler is positional, so the positional form fails with
# "TypeError: __init__() takes 2 positional arguments but 4 were given".
# The keyword form also works on older releases.
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, J, sigma, y=y)
mcmc.print_summary()
samples_1 = mcmc.get_samples()

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-81-491bcece6fbc> in <module>
     19 kernel = NUTS(eight_schools)
     20 mcmc = MCMC(kernel, num_warmup, num_samples)
---> 21 mcmc.run(rng_key, J, sigma, y=y)
     22 mcmc.print_summary()
     23 samples_1 = mcmc.get_samples()

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    496         map_args = (rng_key, init_state, init_params)
    497         if self.num_chains == 1:
--> 498             states_flat, last_state = partial_map_fn(map_args)
    499             states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    500         else:

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    331         rng_key, init_state, init_params = init
    332         if init_state is None:
--> 333             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
    334                                            model_args=args, model_kwargs=kwargs)
    335         sample_fn, postprocess_fn = self._get_cached_fns()

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    503         )
    504         if rng_key.ndim == 1:
--> 505             init_state = hmc_init_fn(init_params, rng_key)
    506         else:
    507             # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/numpyro/infer/hmc.py in <lambda>(init_params, rng_key)
    486                              ' `potential_fn`.')
    487 
--> 488         hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    489             init_params,
    490             num_warmup=num_warmup,

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/numpyro/infer/hmc.py in init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, model_args, model_kwargs, rng_key)
    209         """
    210         step_size = lax.convert_element_type(step_size, jnp.result_type(float))
--> 211         trajectory_length = lax.convert_element_type(trajectory_length, jnp.result_type(float))
    212         nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad
    213         forward_mode_ad = forward_mode_differentiation

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/_src/lax/lax.py in convert_element_type(operand, new_dtype)
    423   if hasattr(operand, '__jax_array__'):
    424     operand = operand.__jax_array__()
--> 425   return _convert_element_type(operand, new_dtype, weak_type=False)
    426 
    427 def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/_src/lax/lax.py in _convert_element_type(operand, new_dtype, weak_type)
    452     return operand
    453   else:
--> 454     return convert_element_type_p.bind(operand, new_dtype=new_dtype,
    455                                        weak_type=new_weak_type)
    456 

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **params)
    262         args, used_axis_names(self, params) if self._dispatch_on_params else None)
    263     tracers = map(top_trace.full_raise, args)
--> 264     out = top_trace.process_primitive(self, tracers, params)
    265     return map(full_lower, out) if self.multiple_results else full_lower(out)
    266 

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
    601 
    602   def process_primitive(self, primitive, tracers, params):
--> 603     return primitive.impl(*tracers, **params)
    604 
    605   def process_call(self, primitive, f, tracers, params):

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    246 def apply_primitive(prim, *args, **params):
    247   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 248   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
    249   return compiled_fun(*args)
    250 

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/interpreters/xla.py in arg_spec(x)
    238 
    239 def arg_spec(x: Any) -> ArgSpec:
--> 240   aval = abstractify(x)
    241   try:
    242     return aval, x._device

/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/interpreters/xla.py in abstractify(x)
    184   if hasattr(x, '__jax_array__'):
    185     return abstractify(x.__jax_array__())
--> 186   raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
    187 
    188 def _make_abstract_python_scalar(typ, val):

TypeError: Argument 'None' of type '<class 'NoneType'>' is not a valid JAX type

Thank you!

Hi @brendan, could you install the development version of NumPyro from GitHub with `pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro`? We will release a new version soon, which should hopefully fix the issue.

Thanks for your reply, @fehiepsi! That fixed the error I was having, but now I am getting the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-5-1f4bc4b68d27> in <module>
     27 # Run NUTS.
     28 kernel = NUTS(eight_schools)
---> 29 mcmc = MCMC(kernel, num_warmup, num_samples)
     30 mcmc.run(rng_key, J, sigma, y=y)
     31 mcmc.print_summary()

TypeError: __init__() takes 2 positional arguments but 4 were given

Any idea what is going on? This is my first time using NumPyro (or pyro for that matter).

With the new version, you need to pass these as keyword arguments: `num_warmup=num_warmup` and `num_samples=num_samples`. The error says that `self` and `sampler` are the only positional arguments of the `MCMC` constructor.

It works. Thank you!