Following up with some data to try to set up a concrete example.
Dependencies:
from math import inf
import networkx as nx
import numpyro # 0.9.2
from jax import numpy as jnp, random
from numpyro import distributions as dist, sample
from numpyro.infer.mcmc import MCMC
from numpyro.infer.hmc import NUTS
from copy import deepcopy
def graph(states, edges):
    """Build a directed graph whose nodes carry their discrete state space.

    Args:
        states: mapping node name -> list of states the node can take; each
            list is stored on its node under the ``'states'`` attribute.
        edges: iterable of ``(parent, child)`` pairs.

    Returns:
        A ``networkx.DiGraph`` with the nodes of ``states`` (inserted in the
        mapping's order) followed by the given edges.
    """
    dag = nx.DiGraph()
    # (node, attribute-dict) pairs attach the 'states' attribute while adding
    # the nodes, in the same order as the keys of `states`.
    dag.add_nodes_from(
        (node, {'states': node_states}) for node, node_states in states.items()
    )
    dag.add_edges_from(edges)
    return dag
Data:
# Simulate a dataset from the DAG Z -> A, Z -> B, (A, B) -> Y.
simkeys = random.split(random.PRNGKey(0), 10)
nsim = 5_000
# Probability tables: the last axis indexes the child's state, the leading
# axes index the parents' states. Every row must sum to 1.
p = {
    'Z': jnp.array([2/3, 1/3]),
    'A': jnp.array([[1 - 0.87, 0.87], [1 - 0.43, 0.43]]),
    'B': jnp.array([[0.78, 1 - 0.78], [0.46, 1 - 0.46]]),
    'Y': jnp.array([[[0.78, 1 - 0.78],
                     [0.46, 1 - 0.46]],
                    # Fixed copy/paste error: complements now match their row
                    # (was `1 - 0.78` / `1 - 0.46`, so rows did not sum to 1).
                    [[0.49, 1 - 0.49],
                     [0.61, 1 - 0.61]]]),
}
g = graph(states={n: [0, 1] for n in ['Z', 'A', 'B', 'Y']},
          edges=[('Z', 'A'), ('Z', 'B'), ('A', 'Y'), ('B', 'Y')])
# Sub-graph without the outcome node Y. Pass a list: a bare string only works
# by accident, because iterating a one-character string yields that name.
gg = deepcopy(g)
gg.remove_nodes_from(['Y'])
# jax.random.categorical expects *logits*, not probabilities. Passing the
# tables directly would draw from softmax(p) rather than p, so take the log.
Z = random.categorical(simkeys[0], jnp.log(p['Z']), shape=(nsim,))
A = random.categorical(simkeys[1], jnp.log(p['A'][Z]), shape=(nsim,))
B = random.categorical(simkeys[2], jnp.log(p['B'][Z]), shape=(nsim,))
Y = random.categorical(simkeys[3], jnp.log(p['Y'][A, B]), shape=(nsim,))
# Symmetric Dirichlet(0.5) concentration for every row of every table.
priors = {
    'Z': jnp.array([0.5, 0.5]),
    'A': jnp.array([[0.5, 0.5], [0.5, 0.5]]),
    'B': jnp.array([[0.5, 0.5], [0.5, 0.5]]),
    'Y': jnp.array([[[0.5, 0.5],
                     [0.5, 0.5]],
                    [[0.5, 0.5],
                     [0.5, 0.5]]]),
}
I refactored the model after running into a few ‘Missing a plate statement for batch dimension’ errors; this is what I came up with:
def model(graph: nx.DiGraph,
          data: {str: jnp.array},
          hyperparameters: {str: jnp.array}):
    """Discrete Bayesian network with Dirichlet priors on every node's table.

    Nodes of ``graph`` are visited in topological order. For each node, a
    probability table ``p[node]`` is drawn from
    ``Dirichlet(hyperparameters[node])``. Nodes present in ``data`` are
    observed through a ``Categorical`` likelihood. The remaining nodes are
    sampled as discrete latents marked for parallel enumeration.

    Args:
        graph: DAG whose nodes carry a ``'states'`` attribute (the list of
            values the node can take). It is deep-copied and never modified.
        data: node name -> array of observed category indices. Only the first
            array is used to determine the number of data points ``n``;
            presumably all arrays share that length (TODO confirm). The
            caller's mapping is not modified.
        hyperparameters: node name -> Dirichlet concentration array for that
            node's table.
    """
    g = deepcopy(graph)
    # Latent nodes are added to this dict below so their children can look up
    # their values. Copy it first: writing into the caller's mapping would make
    # a latent sample (possibly a JAX tracer) from one evaluation of the model
    # look like observed data in every later evaluation (e.g. during MCMC).
    data = dict(data)
    n = list(data.values())[0].size
    p = dict()
    for node in list(nx.topological_sort(g)):
        predecessors = list(g.predecessors(node))
        if predecessors:
            # One plate dimension per latent ancestor, so the node's table is
            # batched over the ancestors' states.
            # NOTE(review): nx.ancestors returns a set, so the order of these
            # dims is not guaranteed when several latent ancestors exist.
            sizes = [len(g.nodes[nn]['states'])  # TODO: add recursive step for generic graphs
                     for nn in nx.ancestors(g, node)
                     if g.nodes[nn]['latent']]
            if sizes:
                with numpyro.plate_stack(f'plate[{node}]', sizes=sizes):
                    p[node] = sample(f'p[{node}]', dist.Dirichlet(hyperparameters[node]))
            else:
                p[node] = sample(f'p[{node}]', dist.Dirichlet(hyperparameters[node]))
            # For 1-D parent arrays, row i holds the parents' values for data
            # point i. Pass a list: stacking functions expect a sequence, not a
            # generator.
            arr = jnp.column_stack([data[predecessor]
                                    for predecessor in predecessors])
            cpt = p[node]
            # Pick, for every data point, the table row matching its parents.
            probs = jnp.apply_along_axis(lambda conditional: cpt[tuple(conditional)],
                                         axis=1,
                                         arr=arr)
        else:
            p[node] = sample(f'p[{node}]', dist.Dirichlet(hyperparameters[node]))
            # Root node: repeat its marginal once per data point.
            probs = jnp.ones((n, len(g.nodes[node]['states']))) * p[node]
        with numpyro.plate('data', size=n):
            if node in data:
                sample(f'obs[{node}]', dist.Categorical(probs=probs), obs=data[node])
                nx.set_node_attributes(g, {node: False}, name='latent')
            else:
                # Stored in the local copy only, for use by this node's children.
                data[node] = sample(f'latent[{node}]',
                                    dist.Categorical(probs=probs),
                                    infer={'enumerate': 'parallel'})
                nx.set_node_attributes(g, {node: True}, name='latent')
Tracing the model with funsor's enumeration-aware `trace` handler yields the following shapes (how should I read this output?):
from numpyro import handlers
# `import numpyro` does not load the contrib subpackage, so import it
# explicitly before reaching numpyro.contrib.funsor (needs `funsor` installed).
import numpyro.contrib.funsor  # noqa: F401
# Record a trace of one seeded run of the model using funsor's trace handler.
# NOTE(review): no `enum` handler is applied here, so `latent[Z]` is recorded
# with its plate shape (5000) rather than with an enumeration dimension.
trace = (numpyro
         .contrib
         .funsor
         .enum_messenger
         .trace(handlers.seed(model, simkeys[0]))
         .get_trace(graph=gg, data={'A': A, 'B': B, }, hyperparameters=priors, ))
print(numpyro.util.format_shapes(trace))
# Trace Shapes:
# Param Sites:
# Sample Sites:
# p[Z] dist | 2
# value | 2
# latent[Z] dist 5000 |
# value 5000 |
# p[A] dist 2 | 2
# value 2 | 2
# obs[A] dist 5000 |
# value 5000 |
# p[B] dist 2 | 2
# value 2 | 2
# obs[B] dist 5000 |
# value 5000 |
The model runs perfectly fine if all data is observed:
# Fully observed case: every node of `gg` has an entry in `data`.
kernel = NUTS(model)
mcmc = MCMC(kernel,
            num_warmup=200,
            num_samples=300)
mcmc.run(random.PRNGKey(1),
         graph=gg,
         data={'A': A, 'B': B, 'Z': Z},
         hyperparameters=priors)
But if `'Z'` is not in the dictionary passed to the `data` argument, the following error is raised:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Input In [20], in <cell line: 4>()
1 kernel = NUTS(model)
2 mcmc = MCMC(kernel, num_warmup=200, num_samples=300)
----> 4 mcmc.run(random.PRNGKey(1), graph=gg, data={'A': A, 'B': B}, hyperparameters=priors)
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/mcmc.py:593, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
591 map_args = (rng_key, init_state, init_params)
592 if self.num_chains == 1:
--> 593 states_flat, last_state = partial_map_fn(map_args)
594 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
595 else:
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
701 # vectorized
702 else:
703 rng_key, rng_key_init_model = jnp.swapaxes(
704 vmap(random.split)(rng_key), 0, 1
705 )
--> 706 init_params = self._init_state(
707 rng_key_init_model, model_args, model_kwargs, init_params
708 )
709 if self._potential_fn and init_params is None:
710 raise ValueError(
711 "Valid value of `init_params` must be provided with" " `potential_fn`."
712 )
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
--> 652 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
653 rng_key,
654 self._model,
655 dynamic_args=True,
656 init_strategy=self._init_strategy,
657 model_args=model_args,
658 model_kwargs=model_kwargs,
659 forward_mode_differentiation=self._forward_mode_differentiation,
660 )
661 if self._init_fn is None:
662 self._init_fn, self._sample_fn = hmc(
663 potential_fn_gen=potential_fn,
664 kinetic_fn=self._kinetic_fn,
665 algo=self._algo,
666 )
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
652 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
655 rng_key,
656 substitute(
657 model,
658 data={
659 k: site["value"]
660 for k, site in model_trace.items()
661 if site["type"] in ["plate"]
662 },
663 ),
664 init_strategy=init_strategy,
665 enum=has_enumerate_support,
666 model_args=model_args,
667 model_kwargs=model_kwargs,
668 prototype_params=prototype_params,
669 forward_mode_differentiation=forward_mode_differentiation,
670 validate_grad=validate_grad,
671 )
673 if not_jax_tracer(is_valid):
674 if device_get(~jnp.all(is_valid)):
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
393 # Handle possible vectorization
394 if rng_key.ndim == 1:
--> 395 (init_params, pe, z_grad), is_valid = _find_valid_params(
396 rng_key, exit_early=True
397 )
398 else:
399 (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
377 init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
378 if exit_early and not_jax_tracer(rng_key):
379 # Early return if valid params found. This is only helpful for single chain,
380 # where we can avoid compiling body_fn in while_loop.
--> 381 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
382 if not_jax_tracer(is_valid):
383 if device_get(is_valid):
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
364 z_grad = jacfwd(potential_fn)(params)
365 else:
--> 366 pe, z_grad = value_and_grad(potential_fn)(params)
367 z_grad_flat = ravel_pytree(z_grad)[0]
368 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
[... skipping hidden 8 frame]
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
244 substituted_model = substitute(
245 model, substitute_fn=partial(_unconstrain_reparam, params)
246 )
247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density_(
249 substituted_model, model_args, model_kwargs, {}
250 )
251 return -log_joint
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:270, in log_density(model, model_args, model_kwargs, params)
249 def log_density(model, model_args, model_kwargs, params):
250 """
251 Similar to :func:`numpyro.infer.util.log_density` but works for models
252 with discrete latent variables. Internally, this uses :mod:`funsor`
(...)
268 :return: log of joint density and a corresponding model trace
269 """
--> 270 result, model_trace, _ = _enum_log_density(
271 model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
272 )
273 return result.data, model_trace
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:159, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
157 model = substitute(model, data=params)
158 with plate_to_enum_plate():
--> 159 model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
160 log_factors = []
161 time_to_factors = defaultdict(list) # log prob factors
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 105 (4 times)]
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
Input In [4], in model(graph, data, hyperparameters)
34 p[node] = sample(f'p[{node}]', dist.Dirichlet(hyperparameters[node]))
35 probs = jnp.ones( (n, len(g.nodes[node]['states'])) ) * p[node]
---> 37 with numpyro.plate('data', size=n):
38 if node in data:
39 sample(f'obs[{node}]', dist.Categorical(probs=probs), obs=data[node])
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/enum_messenger.py:506, in plate.__enter__(self)
500 super().__enter__() # do this first to take care of globals recycling
501 name_to_dim = (
502 OrderedDict([(self.name, self.dim)])
503 if self.dim is not None
504 else OrderedDict()
505 )
--> 506 indices = to_data(
507 self._indices, name_to_dim=name_to_dim, dim_type=DimType.VISIBLE
508 )
509 # extract the dimension allocated by to_data to match plate's current behavior
510 self.dim, self.indices = -len(indices.shape), indices.squeeze()
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/enum_messenger.py:709, in to_data(x, name_to_dim, dim_type)
696 name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim
698 initial_msg = {
699 "type": "to_data",
700 "fn": lambda x, name_to_dim, dim_type: funsor.to_data(
(...)
706 "mask": None,
707 }
--> 709 msg = apply_stack(initial_msg)
710 return msg["value"]
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/primitives.py:59, in apply_stack(msg)
55 # A Messenger that sets msg["stop"] == True also prevents application
56 # of postprocess_message by Messengers above it on the stack
57 # via the pointer variable from the process_message loop
58 for handler in _PYRO_STACK[-pointer - 1 :]:
---> 59 handler.postprocess_message(msg)
60 return msg
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/enum_messenger.py:528, in plate.postprocess_message(self, msg)
526 def postprocess_message(self, msg):
527 if msg["type"] in ["to_funsor", "to_data"]:
--> 528 return super().postprocess_message(msg)
529 # NB: copied literally from original plate messenger, with self._indices is replaced
530 # by self.indices
531 if msg["type"] in ("subsample", "param") and self.dim is not None:
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/enum_messenger.py:417, in GlobalNamedMessenger.postprocess_message(self, msg)
415 self._pyro_post_to_funsor(msg)
416 elif msg["type"] == "to_data":
--> 417 self._pyro_post_to_data(msg)
File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/enum_messenger.py:430, in GlobalNamedMessenger._pyro_post_to_data(self, msg)
427 if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
428 for name in msg["args"][0].inputs:
429 self._saved_globals += (
--> 430 (name, _DIM_STACK.global_frame.name_to_dim[name]),
431 )
KeyError: 'data'