Interfacing with numpyro

Hello,

I have been working on a jax package for equinox-based distributions and normalising flows (GitHub - danielward27/flowjax). I am looking to be able to wrap the distributions to be compatible with numpyro to use them in MCMC and variational inference.

So far I have the following draft which seems to work for unconditional distributions:

``````python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

import flowjax
from flowjax.utils import _get_ufunc_signature  # used to vectorize bijection methods

import numpyro

from numpyro.distributions import constraints

class FlowjaxToNumpyroDistribution(numpyro.distributions.Distribution):
    """Wrap a flowjax distribution instance as a numpyro distribution.

    The wrapper is constructed with an empty batch shape and uses the flowjax
    distribution's ``shape`` as the numpyro event shape.
    """

    def __init__(
        self,
        distribution: flowjax.distributions.Distribution,
        support: constraints.Constraint = constraints.real,
    ):
        """Store the wrapped distribution and its support.

        Args:
            distribution: The flowjax distribution to wrap.
            support: Numpyro constraint used as the support. Defaults to
                ``constraints.real``.
        """
        self.distribution = distribution
        self.support = support
        super().__init__(batch_shape=(), event_shape=distribution.shape)

    def log_prob(self, value, intermediates=None):
        """Evaluate the log density, optionally reusing sampling intermediates.

        Args:
            value: Point(s) at which to evaluate the density. Only used when no
                intermediates are supplied.
            intermediates: Optional list as returned by
                ``sample_with_intermediates``; its first element holds the
                base distribution samples. ``None`` or an empty list means
                "no intermediates".

        Returns:
            The log probability. With intermediates, it is computed as the base
            distribution log probability of the base samples minus the forward
            log determinant of the bijection, avoiding the inverse transform.

        Raises:
            ValueError: If intermediates are supplied but the wrapped
                distribution is not a ``flowjax.distributions.Transformed``.
        """
        # ``sample_with_intermediates`` returns ``[]`` for non-transformed
        # distributions, so an empty list is treated the same as ``None``.
        if intermediates is None or len(intermediates) == 0:
            return self.distribution.log_prob(value)

        if not isinstance(self.distribution, flowjax.distributions.Transformed):
            raise ValueError(
                "Intermediates can only be used with transformed distributions."
            )

        base_samples = intermediates[0]
        log_prob_base = self.distribution.base_dist.log_prob(base_samples)

        # Vectorize the bijection so it can be applied over leading sample
        # dimensions of ``base_samples``.
        signature = _get_ufunc_signature(
            [self.distribution.shape], [self.distribution.shape, ()]
        )
        _, forward_log_det = jnp.vectorize(
            self.distribution.bijection.transform_and_log_det, signature=signature
        )(base_samples)
        return log_prob_base - forward_log_det

    def sample(self, key, sample_shape=()):
        """Draw samples from the wrapped distribution.

        Args:
            key: Jax random key.
            sample_shape: Leading shape of the samples. Defaults to ``()``.

        Returns:
            Whatever ``self.distribution.sample(key, sample_shape)`` returns.
        """
        return self.distribution.sample(key, sample_shape)

    def register_params(self, filter_spec=eqx.is_inexact_array):
        """Register the leaves selected by ``filter_spec`` as numpyro params.

        Jax key path strings (see ``jax.tree_util.keystr``) are used as the
        parameter names. This, like ``numpyro.param``, needs to be called from
        an inference context to have an effect.

        Note that this replaces ``self.distribution`` in place with the
        distribution rebuilt from the registered parameters.

        Args:
            filter_spec: Filter passed to ``eqx.partition`` to select the
                leaves to register. Defaults to ``eqx.is_inexact_array``.

        Returns:
            ``self``, to allow chaining.
        """
        params, static = eqx.partition(self.distribution, filter_spec)
        params = jax.tree_util.tree_map_with_path(
            lambda key_path, x: numpyro.param(jax.tree_util.keystr(key_path), x),
            params,
        )
        self.distribution = eqx.combine(params, static)
        return self

    def sample_with_intermediates(self, key, sample_shape=()):
        """Sample, also returning the base samples for transformed distributions.

        Args:
            key: Jax random key.
            sample_shape: Leading shape of the samples. Defaults to ``()``.

        Returns:
            A tuple ``(samples, intermediates)``. For a
            ``flowjax.distributions.Transformed`` distribution, ``intermediates``
            is ``[base_samples]``; otherwise it is an empty list.
        """
        if not isinstance(self.distribution, flowjax.distributions.Transformed):
            return self.sample(key, sample_shape), []

        base_samples = self.distribution.base_dist.sample(key, sample_shape)
        signature = _get_ufunc_signature(
            [self.distribution.shape], [self.distribution.shape]
        )
        transformed_samples = jnp.vectorize(
            self.distribution.bijection.transform, signature=signature
        )(base_samples)
        return transformed_samples, [base_samples]

def update_parameters_from_key_path_string_dict(
    dist, param_dict, filter_spec=eqx.is_inexact_array
):
    """Return ``dist`` with its parameters replaced by values from ``param_dict``.

    Args:
        dist: The distribution (pytree) whose parameters should be updated.
        param_dict: Mapping from jax key path strings (see
            ``jax.tree_util.keystr``) to the new parameter values, e.g. the
            parameter dictionary produced by numpyro.
        filter_spec: Filter passed to ``eqx.partition`` to select which leaves
            are treated as parameters. Defaults to ``eqx.is_inexact_array``.

    Returns:
        The distribution rebuilt from the looked-up parameters and the
        untouched remaining leaves.

    Raises:
        ValueError: If a looked-up parameter's shape differs from the shape of
            the leaf it replaces.
    """

    def _lookup(key_path, x):
        # The leaf's key path string is the name it is stored under.
        p = param_dict[jax.tree_util.keystr(key_path)]
        if p.shape == x.shape:
            return p
        raise ValueError(f"Expected parameter shape {x.shape}, got {p.shape}")

    trainable, frozen = eqx.partition(dist, filter_spec)
    replaced = jax.tree_util.tree_map_with_path(_lookup, trainable)
    return eqx.combine(replaced, frozen)

if __name__ == "__main__":
    # Smoke tests: recover a 2d N(1, 2) target with MCMC and with SVI.
    from functools import partial

    import pytest
    from numpyro import sample
    from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
    from numpyro.optim import Adam

    true_mean = jnp.ones(2)
    true_std = 2 * jnp.ones(2)

    def numpyro_model():
        """Model with a single site "x" drawn from the wrapped flowjax Normal."""
        sample(
            "x",
            FlowjaxToNumpyroDistribution(
                flowjax.distributions.Normal(true_mean, true_std)
            ),
        )

    # MCMC test
    key = jr.PRNGKey(0)
    mcmc = MCMC(NUTS(numpyro_model), num_warmup=100, num_samples=2000)  # 2d N(1, 2)
    key, subkey = jr.split(key)
    mcmc.run(subkey)

    samps = mcmc.get_samples()["x"]
    assert pytest.approx(samps.mean(axis=0), abs=0.1) == true_mean
    assert pytest.approx(samps.std(axis=0), abs=0.1) == true_std
    print("MCMC test passed")

    # VI test
    def guide(dist):
        """Guide that registers the flowjax parameters and samples site "x"."""
        dist = FlowjaxToNumpyroDistribution(dist)
        dist.register_params(eqx.is_inexact_array)
        sample("x", dist)

    guide_dist = flowjax.distributions.Normal(jnp.zeros(2), 1)  # 2d N(0, 1)

    # SVI requires an optimizer instance; it was previously referenced without
    # being defined. NOTE(review): the step size is an assumed, untuned value.
    optimizer = Adam(step_size=1e-2)

    svi = SVI(numpyro_model, partial(guide, guide_dist), optimizer, loss=Trace_ELBO())
    svi_result = svi.run(jr.PRNGKey(0), num_steps=10000)

    flowjax_dist = update_parameters_from_key_path_string_dict(
        guide_dist, svi_result.params
    )
    assert pytest.approx(flowjax_dist.loc, abs=0.1) == true_mean
    assert pytest.approx(flowjax_dist.scale, abs=0.1) == true_std
    print("VI test passed")
``````

This seems to work (the assert statements pass). However, I would appreciate some advice on the following:

1. Does the overall approach look reasonable? The `register_params` method seems very unjax-like, mutating the distribution, maybe there is a better way? Because of the complexity of the distributions (flows), I do not want to manually name the parameters.
2. The `sample_with_intermediates` and `log_prob` (with intermediates) methods that I have implemented currently seem problematic. In flowjax, I have a `sample_and_log_prob` method for all distributions, and override it with the generally more efficient approach for transformed distributions, avoiding the inverse computation (https://github.com/danielward27/flowjax/blob/ed600296bfcb29993c8a2e15d1360be93f17a442/flowjax/distributions.py#L68). This approach naturally propagates through nested transformed distributions (i.e. transformed transformed distributions), as each call to `sample_and_log_prob` internally calls `sample_and_log_prob` on the (possibly transformed) base distribution. What would be the equivalent approach here? Should `sample_with_intermediates` return a list with length equal to the number of nested `Transformed` distributions, which can then be indexed and used in `log_prob` with intermediates?

Thanks for any input!

Ah I see, in numpyro nested applications of transformed just concatenates the bijections so this issue wouldn’t arise: https://github.com/pyro-ppl/numpyro/blob/bcf38d7e2ca97fded876cf034866eaa0f00385ff/numpyro/distributions/distribution.py#L974C8-L976C72

I can consider doing something similar.

edit: after a bit more thought, even with the above approach of enforcing no nested transforms, this is a bit tricky. e.g. what if a numpyro transform is applied to my custom distribution? It won’t be recognised as a transformed distribution, and I think the same issue arises?

So because of my concern about how reparameterisations / `sample_with_intermediates` for flowjax transformed distributions would function if transformed by a numpyro transform, I instead made a wrapper for the flowjax bijections to convert them into numpyro bijections so I can use them directly within numpyro's TransformedDistribution, so reparameterisation should be handled better. I think I am slowly getting there; here is the updated draft:

``````python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

import flowjax

import numpyro
from numpyro.distributions import constraints

class BijectionToNumpyro(numpyro.distributions.transforms.Transform):
    """Expose a flowjax bijection through the numpyro ``Transform`` interface."""

    def __init__(
        self,
        bijection: flowjax.bijections.Bijection,
        domain: constraints.Constraint = constraints.real,
    ):
        """Store the wrapped bijection and its domain.

        Args:
            bijection: The flowjax bijection to wrap.
            domain: Numpyro constraint used as the transform's domain.
                Defaults to ``constraints.real``.
        """
        self.bijection = bijection
        self.domain = domain

    def __call__(self, x):
        """Apply the wrapped bijection's forward ``transform`` to ``x``."""
        forward = self.bijection.transform
        return forward(x)

    def _inverse(self, y):
        """Apply the wrapped bijection's ``inverse`` to ``y``."""
        backward = self.bijection.inverse
        return backward(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """Return the log determinant of the forward transformation at ``x``.

        ``y`` and ``intermediates`` are accepted for interface compatibility
        but are not used; the value is recomputed from ``x`` via
        ``transform_and_log_det``.
        """
        _, log_det = self.bijection.transform_and_log_det(x)
        return log_det

    def tree_flatten(self):
        """Not supported for this wrapper; always raises."""
        raise NotImplementedError()

def distribution_to_numpyro(dist: flowjax.distributions.Distribution):
    """Wrap a flowjax distribution in the matching numpyro wrapper.

    Args:
        dist: A flowjax distribution, transformed or not.

    Returns:
        A ``_TransformedDistributionToNumpyro`` if ``dist`` is a
        ``flowjax.distributions.Transformed``, otherwise a
        ``_DistributionToNumpyro``.

    Raises:
        ValueError: If ``dist`` is not a flowjax distribution.
    """
    # The transformed check comes first so that transformed distributions get
    # the dedicated wrapper rather than the generic one.
    if isinstance(dist, flowjax.distributions.Transformed):
        return _TransformedDistributionToNumpyro(dist)
    if isinstance(dist, flowjax.distributions.Distribution):
        return _DistributionToNumpyro(dist)
    raise ValueError(f"Expected flowjax distribution, got {type(dist)}")

class _DistributionToNumpyro(numpyro.distributions.Distribution):
    """Wraps a flowjax distribution instance into a numpyro distribution.

    The wrapper is constructed with an empty batch shape and uses the flowjax
    distribution's ``shape`` as the numpyro event shape.
    """

    def __init__(
        self,
        distribution: flowjax.distributions.Distribution,
        support: constraints.Constraint = constraints.real,
    ):
        """
        Args:
            distribution (flowjax.distributions.Distribution): A flowjax distribution
                that is not transformed.
            support (constraints.Constraint, optional): Numpyro constraint.
                Defaults to constraints.real.

        Raises:
            ValueError: If ``distribution`` is a
                ``flowjax.distributions.Transformed``.
        """
        if isinstance(distribution, flowjax.distributions.Transformed):
            # Point at the wrapper that actually exists in this module.
            raise ValueError(
                "Use _TransformedDistributionToNumpyro (or distribution_to_numpyro) "
                "for transformed distributions."
            )
        self.distribution = distribution
        self.support = support
        super().__init__(batch_shape=(), event_shape=distribution.shape)

    def log_prob(self, value):
        """Return the wrapped distribution's log probability of ``value``."""
        return self.distribution.log_prob(value)

    def sample(self, key, sample_shape=()):
        """Draw samples from the wrapped distribution.

        Args:
            key: Jax random key.
            sample_shape: Leading shape of the samples. Defaults to ``()``.

        Returns:
            Whatever ``self.distribution.sample(key, sample_shape)`` returns.
        """
        return self.distribution.sample(key, sample_shape)

class _TransformedDistributionToNumpyro(numpyro.distributions.TransformedDistribution):
    """Numpyro counterpart of a flowjax ``Transformed`` distribution.

    The flowjax base distribution and bijection are extracted, wrapped into
    their numpyro equivalents, and used to build the transformed distribution.
    """

    def __init__(self, dist: flowjax.distributions.Transformed):
        """
        Args:
            dist: The flowjax transformed distribution to convert.
        """
        # TODO allow control of support/domains?
        # Ensure base distribution is not transformed
        flattened = dist.unnest_transforms()
        super().__init__(
            base_distribution=_DistributionToNumpyro(flattened.base_dist),
            transforms=[BijectionToNumpyro(flattened.bijection)],
        )

def register_params(model: eqx.Module, filter_spec=eqx.is_inexact_array):
    """Register the leaves of ``model`` selected by ``filter_spec`` as numpyro params.

    Jax key path strings (see ``jax.tree_util.keystr``) are used as the
    parameter names. This, like ``numpyro.param``, needs to be called from an
    inference context to have an effect.

    Args:
        model: The equinox module whose parameters should be registered.
        filter_spec: Filter passed to ``eqx.partition`` to select the leaves to
            register. Defaults to ``eqx.is_inexact_array``.

    Returns:
        The module rebuilt from the values returned by ``numpyro.param`` and
        the remaining, unregistered leaves.
    """

    def _as_numpyro_param(key_path, leaf):
        # Name each parameter after its location in the pytree.
        return numpyro.param(jax.tree_util.keystr(key_path), leaf)

    trainable, frozen = eqx.partition(model, filter_spec)
    registered = jax.tree_util.tree_map_with_path(_as_numpyro_param, trainable)
    return eqx.combine(registered, frozen)

def update_parameters_from_key_path_string_dict(
    model, param_dict, filter_spec=eqx.is_inexact_array
):
    """Return ``model`` with its parameters replaced by values from ``param_dict``.

    Args:
        model: The module (pytree) whose parameters should be updated.
        param_dict: Mapping from jax key path strings (see
            ``jax.tree_util.keystr``) to the new parameter values, e.g. the
            parameter dictionary produced by numpyro.
        filter_spec: Filter passed to ``eqx.partition`` to select which leaves
            are treated as parameters. Defaults to ``eqx.is_inexact_array``.

    Returns:
        The module rebuilt from the looked-up parameters and the untouched
        remaining leaves.

    Raises:
        ValueError: If a looked-up parameter's shape differs from the shape of
            the leaf it replaces.
    """

    def _replacement_for(key_path, x):
        name = jax.tree_util.keystr(key_path)
        p = param_dict[name]
        if p.shape == x.shape:
            return p
        raise ValueError(f"Expected parameter shape {x.shape}, got {p.shape}")

    trainable, frozen = eqx.partition(model, filter_spec)
    return eqx.combine(
        jax.tree_util.tree_map_with_path(_replacement_for, trainable), frozen
    )

if __name__ == "__main__":
    # Run e.g. with python -m flowjax.numpyro from project root directory
    # Smoke tests: recover a 2d N(1, 2) target with MCMC and with SVI.
    from functools import partial

    import pytest
    from numpyro import sample
    from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
    from numpyro.optim import Adam

    true_mean = jnp.ones(2)
    true_std = 2 * jnp.ones(2)

    def numpyro_model():
        """Model with a single site "x" drawn from the wrapped flowjax Normal."""
        sample(
            "x",
            distribution_to_numpyro(flowjax.distributions.Normal(true_mean, true_std)),
        )

    # MCMC test
    key = jr.PRNGKey(0)
    mcmc = MCMC(NUTS(numpyro_model), num_warmup=100, num_samples=2000)  # 2d N(1, 2)
    key, subkey = jr.split(key)
    mcmc.run(subkey)

    samps = mcmc.get_samples()["x"]
    assert pytest.approx(samps.mean(axis=0), abs=0.1) == true_mean
    assert pytest.approx(samps.std(axis=0), abs=0.1) == true_std
    print("MCMC test passed")

    # VI test
    def guide(dist):
        """Guide that registers the flowjax parameters and samples site "x"."""
        dist = register_params(dist)
        dist = distribution_to_numpyro(dist)
        sample("x", dist)

    guide_dist = flowjax.distributions.Normal(jnp.zeros(2), 1)  # 2d N(0, 1)

    # SVI requires an optimizer instance; it was previously referenced without
    # being defined. NOTE(review): the step size is an assumed, untuned value.
    optimizer = Adam(step_size=1e-2)

    svi = SVI(numpyro_model, partial(guide, guide_dist), optimizer, loss=Trace_ELBO())
    svi_result = svi.run(jr.PRNGKey(0), num_steps=10000)

    flowjax_dist = update_parameters_from_key_path_string_dict(
        guide_dist, svi_result.params
    )
    assert pytest.approx(flowjax_dist.loc, abs=0.1) == true_mean
    assert pytest.approx(flowjax_dist.scale, abs=0.1) == true_std
    print("VI test passed")
``````

This is allowing me to use flows without an analytic inverse as the variational distribution, so I presume this must mean internally numpyro is correctly utilising the intermediates for computing the base distribution log probability

With regards to registering parameters, it seems I can fortunately just do

``````python
params, static = eqx.partition(model, filter_spec)
params = numpyro.param(name, params)
model = eqx.combine(params, static)
``````

The documentation for `numpyro.param` says `init_value` should be a `jnp.ndarray` or a callable, but a PyTree with `jnp.ndarray` leaves seems to work too.