Hi all,
I can't seem to run the introductory code from the "Getting Started with NumPyro" page of the NumPyro documentation. I'm sorry if there is a very basic solution to this, but could anyone help? Every time I run it, I get the error: TypeError: Argument 'None' of type '<class 'NoneType'>' is not a valid JAX type
import numpy as np
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
# Eight-schools data: one observed value in `y` per school, and in `sigma`
# the matching known scale used as the Normal scale of the 'obs' site.
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
# Number of schools; equals 8 for the data above.
J = len(y)
def eight_schools(J, sigma, y=None):
    """Hierarchical eight-schools model with a shared mean and scale.

    A global location 'mu' ~ Normal(0, 5) and scale 'tau' ~ HalfCauchy(5)
    are drawn once.  Inside a plate of size ``J``, each group effect
    'theta' ~ Normal(mu, tau), and the observation site 'obs' is
    Normal(theta, sigma).

    Args:
        J: number of groups (size of the 'J' plate).
        sigma: per-group observation scales.
        y: observed values to condition 'obs' on; when None, 'obs' is
            left unconditioned (e.g. for predictive sampling).
    """
    global_mean = numpyro.sample('mu', dist.Normal(0, 5))
    global_scale = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        group_effect = numpyro.sample('theta', dist.Normal(global_mean, global_scale))
        likelihood = dist.Normal(group_effect, sigma)
        numpyro.sample('obs', likelihood, obs=y)
# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(0)
num_warmup, num_samples = 1000, 2000

# Run NUTS.
# `num_warmup` / `num_samples` are passed by keyword: recent NumPyro releases
# make them keyword-only in `MCMC`, and older releases accept them this way too.
#
# NOTE(review): the reported "Argument 'None' ... is not a valid JAX type"
# error is raised inside NumPyro's HMC initialisation, where
# `lax.convert_element_type` is applied to `trajectory_length=None`. That is
# library code, not this script, which suggests the installed JAX is newer
# than the installed NumPyro supports. Upgrading NumPyro (or pinning a JAX
# version it supports) is the likely fix — confirm against the NumPyro
# release notes.
kernel = NUTS(eight_schools)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, J, sigma, y=y)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
Error message:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-81-491bcece6fbc> in <module>
19 kernel = NUTS(eight_schools)
20 mcmc = MCMC(kernel, num_warmup, num_samples)
---> 21 mcmc.run(rng_key, J, sigma, y=y)
22 mcmc.print_summary()
23 samples_1 = mcmc.get_samples()
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
496 map_args = (rng_key, init_state, init_params)
497 if self.num_chains == 1:
--> 498 states_flat, last_state = partial_map_fn(map_args)
499 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
500 else:
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
331 rng_key, init_state, init_params = init
332 if init_state is None:
--> 333 init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
334 model_args=args, model_kwargs=kwargs)
335 sample_fn, postprocess_fn = self._get_cached_fns()
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
503 )
504 if rng_key.ndim == 1:
--> 505 init_state = hmc_init_fn(init_params, rng_key)
506 else:
507 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/numpyro/infer/hmc.py in <lambda>(init_params, rng_key)
486 ' `potential_fn`.')
487
--> 488 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
489 init_params,
490 num_warmup=num_warmup,
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/numpyro/infer/hmc.py in init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, model_args, model_kwargs, rng_key)
209 """
210 step_size = lax.convert_element_type(step_size, jnp.result_type(float))
--> 211 trajectory_length = lax.convert_element_type(trajectory_length, jnp.result_type(float))
212 nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad
213 forward_mode_ad = forward_mode_differentiation
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/_src/lax/lax.py in convert_element_type(operand, new_dtype)
423 if hasattr(operand, '__jax_array__'):
424 operand = operand.__jax_array__()
--> 425 return _convert_element_type(operand, new_dtype, weak_type=False)
426
427 def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/_src/lax/lax.py in _convert_element_type(operand, new_dtype, weak_type)
452 return operand
453 else:
--> 454 return convert_element_type_p.bind(operand, new_dtype=new_dtype,
455 weak_type=new_weak_type)
456
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **params)
262 args, used_axis_names(self, params) if self._dispatch_on_params else None)
263 tracers = map(top_trace.full_raise, args)
--> 264 out = top_trace.process_primitive(self, tracers, params)
265 return map(full_lower, out) if self.multiple_results else full_lower(out)
266
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
601
602 def process_primitive(self, primitive, tracers, params):
--> 603 return primitive.impl(*tracers, **params)
604
605 def process_call(self, primitive, f, tracers, params):
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
246 def apply_primitive(prim, *args, **params):
247 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 248 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
249 return compiled_fun(*args)
250
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/interpreters/xla.py in arg_spec(x)
238
239 def arg_spec(x: Any) -> ArgSpec:
--> 240 aval = abstractify(x)
241 try:
242 return aval, x._device
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/jax/interpreters/xla.py in abstractify(x)
184 if hasattr(x, '__jax_array__'):
185 return abstractify(x.__jax_array__())
--> 186 raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
187
188 def _make_abstract_python_scalar(typ, val):
TypeError: Argument 'None' of type '<class 'NoneType'>' is not a valid JAX type
Thank you!