 # Enumerating over lots of variables for Bayesian network

Hi there! I’ve been working on using numpyro to implement a Bayesian network (i.e., a directed graphical model) while inferring a posterior over the CPT parameters, where the training data has some missingness (so that some variables in each data sample are not observed).

I have been able to get a working implementation for small models, but I run out of dimensions when applying it to larger ones. The issue seems to be that numpyro doesn’t know enough about the conditional independencies in my model to safely recycle dimensions. Is there a way I can mark these conditional independencies more explicitly, so that dimensions can be recycled and the error avoided?

I’ve put an example below showing my current approach. It constructs a network with a chain structure `node_1 -> node_2 -> ... -> node_n` as a toy model, because that makes the size easy to parametrise. However, I’d like to apply this to more general models (specified in the `structure` dict below) without the 1-parent/1-child chain structure, so a numpyro equivalent of pyro’s `markov` (as discussed here) probably wouldn’t apply.

First we set up the network structure:

```python
import numpyro
from numpyro import distributions as dist, sample
from numpyro.infer.mcmc import MCMC
from numpyro.infer import NUTS
from jax import numpy as jnp, random

import numpy as onp

# Parametrise the total size of the model (i.e., the length of the chain)
n_nodes = 10
n_data = 10000

# Define a toy chain structure because it's easy to parametrise the total size
structure = {
    f'node_{i+1}': {'n_states': 2, 'parents': [f'node_{i}'], 'manual_cpt': None}
    for i in range(n_nodes)
}
structure['node_1']['parents'] = []

# Shapes of the arrays storing the conditional probabilities for each node:
# parent state counts first, the node's own state count last (the event
# dimension for both the Dirichlet prior and the Categorical likelihood)
cpt_shapes = {
    name: tuple(
        [structure[p]['n_states'] for p in node['parents']] + [node['n_states']]
    )
    for name, node in structure.items()
}

# Generate some random data (NaN marks a missing entry) + record what's observed
data_dict = {
    f'node_{i+1}': jnp.array(onp.random.choice(
        [0., 1., jnp.nan], size=n_data, p=onp.random.dirichlet([2., 2., 1.])
    ))
    for i in range(n_nodes)
}
obs_mask = {node: jnp.isfinite(data_dict[node]) for node in data_dict}
```
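As a quick sanity check of the shape bookkeeping (a hypothetical standalone snippet, not part of the model), for a small binary chain `node_1` gets a bare `(2,)` prior over its own states and every other node gets a `(2, 2)` table with one row per parent state:

```python
# Hypothetical standalone check of the CPT shapes for a 3-node binary chain
n_nodes = 3
structure = {
    f'node_{i+1}': {'n_states': 2, 'parents': [f'node_{i}']}
    for i in range(n_nodes)
}
structure['node_1']['parents'] = []

cpt_shapes = {
    name: tuple(
        [structure[p]['n_states'] for p in node['parents']] + [node['n_states']]
    )
    for name, node in structure.items()
}

print(cpt_shapes)
# {'node_1': (2,), 'node_2': (2, 2), 'node_3': (2, 2)}
# node_1 has no parents, so its "CPT" is just a prior over its own 2 states;
# node_2 and node_3 each condition on one binary parent
```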

Then define the model:

```python
def model(data, obs_mask, n_rows):
    probs = {}
    # Priors: one Dirichlet over the node's states per parent configuration
    for node in data:
        with numpyro.plate_stack(f'plate_{node}', cpt_shapes[node][:-1]):
            probs[node] = sample(f'p_{node}', dist.Dirichlet(jnp.ones(cpt_shapes[node])))

    # Likelihoods: slice each node's CPT by the (possibly imputed) parent values
    completed = {}
    with numpyro.plate('obs', n_rows):
        for node in data:
            parent_slice = tuple(completed[p].astype(int) for p in structure[node]['parents'])
            node_probs = probs[node][parent_slice]
            completed[node] = sample(
                node,
                dist.Categorical(probs=node_probs),
                obs=data[node],
                obs_mask=obs_mask[node],
            )
```
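To make the `parent_slice` indexing in the likelihood loop concrete, here is a tiny standalone NumPy sketch (made-up numbers, not from the model): each data row's parent value picks out one row of the CPT, leaving a per-row categorical distribution over the child's states.

```python
import numpy as np

# Hypothetical CPT for a binary child with one binary parent:
# rows are indexed by the parent's state, columns by the child's state
cpt = np.array([[0.9, 0.1],   # P(child | parent = 0)
                [0.3, 0.7]])  # P(child | parent = 1)

# Per-row parent values, as produced by `completed[parent].astype(int)`
parent_vals = np.array([0, 1, 1, 0])

# Advanced indexing broadcasts over the data dimension
node_probs = cpt[parent_vals]
print(node_probs.shape)  # (4, 2): one distribution over child states per data row
```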

If I do MCMC on the above model with `n_nodes = 10`, everything works as expected. But when I increase the size of the model (e.g., `n_nodes = 100`), I get the following error:

```
ValueError: Ran out of free dims during allocation for node_25_unobserved
```

I don’t know the exact solution to your problem, but you may find the discussion in “Dealing with large numbers of variables: (re-)introducing `pyro.markov`” informative. The `markov` handler from the Funsor-based NumPyro backend may also be useful, though I haven’t used it personally, so I could be wrong.