Stochastic Bayesian Network with discrete latent variables

Hello, all!

I’m trying to implement a stochastic Bayesian Network, such that for any given node and set of ancestors’ values, a hyper-parameter (pseudo counts) defines a distribution over the parameters of the Categorical distribution P(node | ancestors’ values). Does that make sense?

The thing is, I’d like to increase the model’s expressiveness by including nodes that are not observed/measured (i.e., latent nodes). Intuitively, I reckoned that, given priors, I needed to sample the parameters of the conditional distribution of the latent node, and then sample the node’s value with those parameters (the same way I did with observed data, but without the obs argument).

The thing I’m struggling with is plates (I guess). I was able to run a model with fully observed data where the sampling of each node’s values was done within a plate whose size is the number of data points, and I guess that made sense because the same parameters govern the sampling of every data point. Is that rationale correct?

When adding latent variables, though, I suppose additional dimensions are required due to enumeration under the hood, but I couldn’t grasp that concept.

This is the first Pyro (and PPL) model I’m trying to code. I would greatly appreciate any feedback on the design I’m proposing, on how to correctly run this model, and on the best inference algorithms for this case (if any).

Thanks in advance!

Below is the code:

def graph(states, edges):
    """Build a directed graph whose nodes are annotated with their states.

    Args:
        states: mapping from node name to the value stored under that node's
            ``'states'`` attribute.
        edges: iterable of ``(source, target)`` pairs, forwarded to
            ``DiGraph.add_edges_from``.

    Returns:
        The populated ``nx.DiGraph``.
    """
    network = nx.DiGraph()
    # Register every node before the edges, in the order given by `states`.
    network.add_nodes_from(states)
    for name, values in states.items():
        network.nodes[name]['states'] = values
    network.add_edges_from(edges)
    return network

def query(cpds, node, conditional):
    """Return the entry of ``cpds[node]`` selected by ``conditional``.

    ``conditional`` is turned into a tuple before being used as the index, so
    any iterable of values can be passed.
    """
    table = cpds[node]
    index = tuple(conditional)
    return table[index]

def model(graph:           nx.DiGraph,
          data:            {str: jnp.array},  # {node: n shaped array}
          hyperparameters: {str: jnp.array}): # {node: multidimensional array of cpds pseudocounts}
    """Bayesian-network model with Dirichlet-distributed conditional tables.

    One ``p[<node>]`` site is sampled per entry of ``hyperparameters``. The
    graph is then walked in topological order and one Categorical site is
    declared per node: ``obs[<node>]`` when the node has an entry in ``data``,
    otherwise ``latent[<node>]`` with parallel enumeration requested. Values
    drawn for latent nodes are made available to the nodes visited afterwards.

    Args:
        graph: network structure; each node carries a ``'states'`` attribute.
        data: observed values per node.
        hyperparameters: Dirichlet concentration array per node.
    """
    # A shallow copy is enough: only new keys are added below, the arrays
    # themselves are never modified, and the caller's dict stays untouched.
    d = dict(data)
    n = list(d.values())[0].size
    p = {
        node: sample(f'p[{node}]', dist.Dirichlet(hp))
        for node, hp in hyperparameters.items()
    }

    for node in nx.topological_sort(graph):
        predecessors = list(graph.predecessors(node))
        if predecessors:
            # One column per parent; column_stack needs a real sequence, not a
            # generator.
            arr   = jnp.column_stack([d[predecessor] for predecessor in predecessors])
            # Pick, for every row of parent values, the matching row of the table.
            probs = jnp.apply_along_axis(lambda conditional: query(p, node, conditional),
                                         axis=1,
                                         arr=arr)
        else:
            # Root node: repeat the same probability vector for every data point.
            probs = jnp.ones( (n, len(graph.nodes[node]['states'])) ) * p[node]

        if node in d:
            sample(f'obs[{node}]', dist.Categorical(probs=probs), obs=d[node])
        else:
            d[node] = sample(f'latent[{node}]',
                             dist.Categorical(probs=probs),
                             infer={'enumerate': 'parallel'})

Following up with some data to try to set up a concrete example.

Dependencies:

from math import inf

import networkx as nx

import numpyro # 0.9.2

from jax import numpy as jnp, random

from numpyro import distributions as dist, sample

from numpyro.infer.mcmc import MCMC
from numpyro.infer.hmc  import NUTS

from copy import deepcopy

def graph(states, edges):
    """Build a directed graph whose nodes are annotated with their states.

    Args:
        states: mapping from node name to the value stored under that node's
            ``'states'`` attribute.
        edges: iterable of ``(source, target)`` pairs, forwarded to
            ``DiGraph.add_edges_from``.

    Returns:
        The populated ``nx.DiGraph``.
    """
    network = nx.DiGraph()
    # Register every node before the edges, in the order given by `states`.
    network.add_nodes_from(states)
    for name, values in states.items():
        network.nodes[name]['states'] = values
    network.add_edges_from(edges)
    return network

Data:

simkeys = random.split(random.PRNGKey(0), 10)
nsim    = 5_000

# Ground-truth conditional probability tables used to simulate the data.
# Leading axes index the parents' values and the last axis indexes the node's
# own state, so every row along the last axis must sum to 1.
p = {
    'Z': jnp.array([2/3, 1/3]),
    'A': jnp.array([[1 - 0.87, 0.87], [1 - 0.43, 0.43]]),
    'B': jnp.array([[0.78, 1 - 0.78], [0.46, 1 - 0.46]]),
    'Y': jnp.array([[[0.78, 1 - 0.78],
                     [0.46, 1 - 0.46]],
                    [[0.49, 1 - 0.49],
                     [0.61, 1 - 0.61]]]),
}

# Full network: Z -> A, Z -> B, A -> Y, B -> Y, with two states per node.
g = graph(states={n: [0, 1] for n in ['Z', 'A', 'B', 'Y']},
          edges =[('Z', 'A'), ('Z', 'B'), ('A', 'Y'), ('B', 'Y')])

# Copy of the network without Y. `remove_nodes_from` iterates over its argument,
# so passing the bare string 'Y' only worked because the name is a single
# character; `remove_node` takes the node itself.
gg = deepcopy(g)
gg.remove_node('Y')

# jax.random.categorical interprets its second argument as (unnormalised)
# log-probabilities, so the probability tables are log-transformed first.
Z = random.categorical(simkeys[0], jnp.log(p['Z']),       shape=(nsim, ))
A = random.categorical(simkeys[1], jnp.log(p['A'][Z]),    shape=(nsim, ))
B = random.categorical(simkeys[2], jnp.log(p['B'][Z]),    shape=(nsim, ))
Y = random.categorical(simkeys[3], jnp.log(p['Y'][A, B]), shape=(nsim, ))

# Dirichlet pseudo-counts: 0.5 for every state of every node, one array per
# node shaped like that node's conditional table.
priors = {
    node: 0.5 * jnp.ones(shape)
    for node, shape in (('Z', (2,)),
                        ('A', (2, 2)),
                        ('B', (2, 2)),
                        ('Y', (2, 2, 2)))
}

I refactored the model after stumbling upon a few ‘Missing a plate statement for batch dimension’ errors. This is what I came up with:

def model(graph:           nx.DiGraph,
          data:            {str: jnp.array},
          hyperparameters: {str: jnp.array}):
    """Bayesian-network model with a ``'data'`` plate around per-point sites.

    Nodes are visited in topological order. For each node a Dirichlet site
    ``p[<node>]`` is sampled (inside a ``plate_stack`` sized by the node's
    latent ancestors, when there are any), then a Categorical site is declared
    inside the ``'data'`` plate: ``obs[<node>]`` when the node has an entry in
    ``data``, otherwise ``latent[<node>]`` with parallel enumeration requested.

    Args:
        graph: network structure; each node carries a ``'states'`` attribute.
        data: observed values per node. It is not modified.
        hyperparameters: Dirichlet concentration array per node.
    """
    # The graph copy receives a per-node 'latent' flag while the model runs.
    g = deepcopy(graph)
    # Latent draws are stored next to the observations so children can index
    # with them. They go into a private copy: the model is invoked repeatedly
    # with the same `data` object, and writing into it would make a latent
    # node look observed on the next invocation.
    values = dict(data)
    n = list(values.values())[0].size
    p = dict()

    for node in list(nx.topological_sort(g)):
        predecessors = list(g.predecessors(node))

        if predecessors:
            sizes = [len(g.nodes[nn]['states']) # TODO: add recursive step for generic graphs
                     for nn in nx.ancestors(g, node)
                     if g.nodes[nn]['latent']]
            if sizes:
                with numpyro.plate_stack(f'plate[{node}]', sizes=sizes):
                    p[node] = sample(f'p[{node}]', dist.Dirichlet(hyperparameters[node]))
            else:
                p[node] = sample(f'p[{node}]', dist.Dirichlet(hyperparameters[node]))
            # One column per parent; column_stack needs a real sequence, not a
            # generator.
            arr   = jnp.column_stack([values[predecessor] for predecessor in predecessors])
            probs = jnp.apply_along_axis(lambda conditional: p[node][tuple(conditional)],
                                         axis=1,
                                         arr=arr)
        else:
            p[node] = sample(f'p[{node}]', dist.Dirichlet(hyperparameters[node]))
            # Root node: repeat the same probability vector for every data point.
            probs   = jnp.ones( (n, len(g.nodes[node]['states'])) ) * p[node]

        with numpyro.plate('data', size=n):
            if node in values:
                sample(f'obs[{node}]', dist.Categorical(probs=probs), obs=values[node])
                nx.set_node_attributes(g, {node: False}, name='latent')
            else:
                values[node] = sample(f'latent[{node}]',
                                      dist.Categorical(probs=probs),
                                      infer={'enumerate': 'parallel'})
                nx.set_node_attributes(g, {node: True}, name='latent')

which yields (how do I read that message?):

from numpyro import handlers

# Wrap the model in a seed handler keyed by simkeys[0], record one execution
# with the trace handler from numpyro.contrib.funsor.enum_messenger (Z is left
# out of `data`), and print the shapes of the recorded sites.
seeded_model = handlers.seed(model, simkeys[0])
enum_trace   = numpyro.contrib.funsor.enum_messenger.trace
trace        = enum_trace(seeded_model).get_trace(graph=gg,
                                                  data={'A': A, 'B': B},
                                                  hyperparameters=priors)

print(numpyro.util.format_shapes(trace))

#  Trace Shapes:         
#   Param Sites:         
#  Sample Sites:         
#      p[Z] dist      | 2
#          value      | 2
# latent[Z] dist 5000 |  
#          value 5000 |  
#      p[A] dist    2 | 2
#          value    2 | 2
#    obs[A] dist 5000 |  
#          value 5000 |  
#      p[B] dist    2 | 2
#          value    2 | 2
#    obs[B] dist 5000 |  
#          value 5000 |  

The model runs perfectly fine if all data is observed:

# Run NUTS with every node observed: 200 warm-up iterations, then 300 samples.
kernel = NUTS(model)
mcmc   = MCMC(kernel, num_warmup=200, num_samples=300)

fully_observed = {'A': A, 'B': B, 'Z': Z}
mcmc.run(random.PRNGKey(1), graph=gg, hyperparameters=priors, data=fully_observed)

But if 'Z' is not in the dictionary passed to the data argument, the following error is raised:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [20], in <cell line: 4>()
      1 kernel = NUTS(model)
      2 mcmc   = MCMC(kernel, num_warmup=200, num_samples=300)
----> 4 mcmc.run(random.PRNGKey(1), graph=gg, data={'A': A, 'B': B}, hyperparameters=priors)

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/mcmc.py:593, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    591 map_args = (rng_key, init_state, init_params)
    592 if self.num_chains == 1:
--> 593     states_flat, last_state = partial_map_fn(map_args)
    594     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    595 else:

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    379 rng_key, init_state, init_params = init
    380 if init_state is None:
--> 381     init_state = self.sampler.init(
    382         rng_key,
    383         self.num_warmup,
    384         init_params,
    385         model_args=args,
    386         model_kwargs=kwargs,
    387     )
    388 sample_fn, postprocess_fn = self._get_cached_fns()
    389 diagnostics = (
    390     lambda x: self.sampler.get_diagnostics_str(x[0])
    391     if rng_key.ndim == 1
    392     else ""
    393 )  # noqa: E731

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    701 # vectorized
    702 else:
    703     rng_key, rng_key_init_model = jnp.swapaxes(
    704         vmap(random.split)(rng_key), 0, 1
    705     )
--> 706 init_params = self._init_state(
    707     rng_key_init_model, model_args, model_kwargs, init_params
    708 )
    709 if self._potential_fn and init_params is None:
    710     raise ValueError(
    711         "Valid value of `init_params` must be provided with" " `potential_fn`."
    712     )

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    651     if self._model is not None:
--> 652         init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
    653             rng_key,
    654             self._model,
    655             dynamic_args=True,
    656             init_strategy=self._init_strategy,
    657             model_args=model_args,
    658             model_kwargs=model_kwargs,
    659             forward_mode_differentiation=self._forward_mode_differentiation,
    660         )
    661         if self._init_fn is None:
    662             self._init_fn, self._sample_fn = hmc(
    663                 potential_fn_gen=potential_fn,
    664                 kinetic_fn=self._kinetic_fn,
    665                 algo=self._algo,
    666             )

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    652     init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
    655     rng_key,
    656     substitute(
    657         model,
    658         data={
    659             k: site["value"]
    660             for k, site in model_trace.items()
    661             if site["type"] in ["plate"]
    662         },
    663     ),
    664     init_strategy=init_strategy,
    665     enum=has_enumerate_support,
    666     model_args=model_args,
    667     model_kwargs=model_kwargs,
    668     prototype_params=prototype_params,
    669     forward_mode_differentiation=forward_mode_differentiation,
    670     validate_grad=validate_grad,
    671 )
    673 if not_jax_tracer(is_valid):
    674     if device_get(~jnp.all(is_valid)):

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    393 # Handle possible vectorization
    394 if rng_key.ndim == 1:
--> 395     (init_params, pe, z_grad), is_valid = _find_valid_params(
    396         rng_key, exit_early=True
    397     )
    398 else:
    399     (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
    377 init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
    378 if exit_early and not_jax_tracer(rng_key):
    379     # Early return if valid params found. This is only helpful for single chain,
    380     # where we can avoid compiling body_fn in while_loop.
--> 381     _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    382     if not_jax_tracer(is_valid):
    383         if device_get(is_valid):

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
    364     z_grad = jacfwd(potential_fn)(params)
    365 else:
--> 366     pe, z_grad = value_and_grad(potential_fn)(params)
    367 z_grad_flat = ravel_pytree(z_grad)[0]
    368 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

    [... skipping hidden 8 frame]

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
    244 substituted_model = substitute(
    245     model, substitute_fn=partial(_unconstrain_reparam, params)
    246 )
    247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density_(
    249     substituted_model, model_args, model_kwargs, {}
    250 )
    251 return -log_joint

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:270, in log_density(model, model_args, model_kwargs, params)
    249 def log_density(model, model_args, model_kwargs, params):
    250     """
    251     Similar to :func:`numpyro.infer.util.log_density` but works for models
    252     with discrete latent variables. Internally, this uses :mod:`funsor`
   (...)
    268     :return: log of joint density and a corresponding model trace
    269     """
--> 270     result, model_trace, _ = _enum_log_density(
    271         model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
    272     )
    273     return result.data, model_trace

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:159, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    157 model = substitute(model, data=params)
    158 with plate_to_enum_plate():
--> 159     model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    160 log_factors = []
    161 time_to_factors = defaultdict(list)  # log prob factors

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (4 times)]

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Input In [4], in model(graph, data, hyperparameters)
     34     p[node] = sample(f'p[{node}]', dist.Dirichlet(hyperparameters[node]))
     35     probs   = jnp.ones( (n, len(g.nodes[node]['states'])) ) * p[node]
---> 37 with numpyro.plate('data', size=n):
     38     if node in data:
     39         sample(f'obs[{node}]', dist.Categorical(probs=probs), obs=data[node])

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/enum_messenger.py:506, in plate.__enter__(self)
    500 super().__enter__()  # do this first to take care of globals recycling
    501 name_to_dim = (
    502     OrderedDict([(self.name, self.dim)])
    503     if self.dim is not None
    504     else OrderedDict()
    505 )
--> 506 indices = to_data(
    507     self._indices, name_to_dim=name_to_dim, dim_type=DimType.VISIBLE
    508 )
    509 # extract the dimension allocated by to_data to match plate's current behavior
    510 self.dim, self.indices = -len(indices.shape), indices.squeeze()

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/enum_messenger.py:709, in to_data(x, name_to_dim, dim_type)
    696 name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim
    698 initial_msg = {
    699     "type": "to_data",
    700     "fn": lambda x, name_to_dim, dim_type: funsor.to_data(
   (...)
    706     "mask": None,
    707 }
--> 709 msg = apply_stack(initial_msg)
    710 return msg["value"]

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/primitives.py:59, in apply_stack(msg)
     55 # A Messenger that sets msg["stop"] == True also prevents application
     56 # of postprocess_message by Messengers above it on the stack
     57 # via the pointer variable from the process_message loop
     58 for handler in _PYRO_STACK[-pointer - 1 :]:
---> 59     handler.postprocess_message(msg)
     60 return msg

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/enum_messenger.py:528, in plate.postprocess_message(self, msg)
    526 def postprocess_message(self, msg):
    527     if msg["type"] in ["to_funsor", "to_data"]:
--> 528         return super().postprocess_message(msg)
    529     # NB: copied literally from original plate messenger, with self._indices is replaced
    530     # by self.indices
    531     if msg["type"] in ("subsample", "param") and self.dim is not None:

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/enum_messenger.py:417, in GlobalNamedMessenger.postprocess_message(self, msg)
    415     self._pyro_post_to_funsor(msg)
    416 elif msg["type"] == "to_data":
--> 417     self._pyro_post_to_data(msg)

File ~/environments/fusion/lib/python3.8/site-packages/numpyro/contrib/funsor/enum_messenger.py:430, in GlobalNamedMessenger._pyro_post_to_data(self, msg)
    427 if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
    428     for name in msg["args"][0].inputs:
    429         self._saved_globals += (
--> 430             (name, _DIM_STACK.global_frame.name_to_dim[name]),
    431         )

KeyError: 'data'