Hello,
I have been working on a JAX package for Equinox-based distributions and normalising flows ([danielward27/flowjax on GitHub](https://github.com/danielward27/flowjax)). I would like to wrap these distributions so that they are compatible with numpyro and can be used for MCMC and variational inference.
So far I have the following draft, which seems to work for unconditional distributions:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import flowjax
from flowjax.utils import _get_ufunc_signature # used to vectorize bijection methods
import numpyro
from numpyro.distributions import constraints
def _sample_with_intermediates(dist, key, sample_shape):
    """Sample from a flowjax distribution, recording nested ``Transformed`` base samples.

    Args:
        dist: A flowjax distribution.
        key: JAX PRNG key.
        sample_shape: Leading (batch) shape of the samples to draw.

    Returns:
        A tuple ``(samples, intermediates)``. ``intermediates`` holds one entry per
        nested ``flowjax.distributions.Transformed`` layer, outermost first:
        ``intermediates[i]`` is the input passed to the bijection of layer ``i``.
        It is an empty list when ``dist`` is not a ``Transformed`` distribution.
    """
    if not isinstance(dist, flowjax.distributions.Transformed):
        return dist.sample(key, sample_shape), []
    # Recurse so that transformed-transformed distributions record every layer.
    base_samples, base_intermediates = _sample_with_intermediates(
        dist.base_dist, key, sample_shape
    )
    # The bijection acts on a single (unbatched) sample; vectorize it over the
    # leading sample dimensions.
    signature = _get_ufunc_signature([dist.shape], [dist.shape])
    samples = jnp.vectorize(dist.bijection.transform, signature=signature)(
        base_samples
    )
    return samples, [base_samples, *base_intermediates]


def _log_prob_from_intermediates(dist, value, intermediates):
    """Evaluate ``dist.log_prob(value)`` using recorded base samples where possible.

    Using the recorded base sample of each ``Transformed`` layer avoids computing
    the inverse of its bijection. Any layer without a recorded base sample falls
    back to ``dist.log_prob(value)``. This keeps lists that only cover the
    outermost layer valid.

    Args:
        dist: A flowjax distribution.
        value: The sample of ``dist`` to evaluate.
        intermediates: Base samples as returned by ``_sample_with_intermediates``
            (outermost layer first).
    """
    if not isinstance(dist, flowjax.distributions.Transformed) or not intermediates:
        return dist.log_prob(value)
    base_value, *base_intermediates = intermediates
    log_prob_base = _log_prob_from_intermediates(
        dist.base_dist, base_value, base_intermediates
    )
    # Vectorize the bijection (output: transformed value and a scalar log det).
    signature = _get_ufunc_signature([dist.shape], [dist.shape, ()])
    _, forward_log_det = jnp.vectorize(
        dist.bijection.transform_and_log_det, signature=signature
    )(base_value)
    # Change of variables: log p(x) = log p_base(z) - log|det dx/dz|.
    return log_prob_base - forward_log_det


class FlowjaxToNumpyroDistribution(numpyro.distributions.Distribution):
    """Wraps a flowjax distribution instance into a numpyro distribution.

    Only unconditional use is supported: no ``condition`` is passed to flowjax.

    Args:
        distribution: The flowjax distribution to wrap. Its ``shape`` becomes the
            numpyro ``event_shape``; the wrapper has an empty ``batch_shape``.
        support: numpyro constraint for the support. If its ``event_dim`` is
            smaller than the number of event dimensions, it is reinterpreted with
            ``constraints.independent`` so it applies to the whole event. This
            matches numpyro's convention (e.g. ``real_vector`` for 1-d events).
    """

    def __init__(
        self,
        distribution: flowjax.distributions.Distribution,
        support: constraints.Constraint = constraints.real,
    ):
        event_shape = distribution.shape
        missing_event_dims = len(event_shape) - support.event_dim
        if missing_event_dims > 0:
            support = constraints.independent(support, missing_event_dims)
        self.distribution = distribution
        self.support = support
        super().__init__(batch_shape=(), event_shape=event_shape)

    def log_prob(self, value, intermediates=None):
        """Log density of ``value``.

        Args:
            value: Sample(s) to evaluate.
            intermediates: Optional list returned by ``sample_with_intermediates``.
                If non-empty, base samples are used instead of inverting the
                bijection(s). ``None`` or an empty list evaluates directly.

        Raises:
            ValueError: If non-empty intermediates are given but the wrapped
                distribution is not a ``flowjax.distributions.Transformed``.
        """
        if not intermediates:
            return self.distribution.log_prob(value)
        if not isinstance(self.distribution, flowjax.distributions.Transformed):
            raise ValueError(
                "Intermediates can only be used with transformed distributions."
            )
        return _log_prob_from_intermediates(self.distribution, value, intermediates)

    def sample(self, key, sample_shape=()):
        """Draw samples with leading shape ``sample_shape`` from the wrapped distribution."""
        return self.distribution.sample(key, sample_shape)

    def register_params(self, filter_spec=eqx.is_inexact_array):
        """Register numpyro params selected by ``filter_spec``.

        Each selected leaf is named by its JAX key path string (see
        ``jax.tree_util.keystr``), so parameters of complex flows need no manual
        naming. Like ``numpyro.param``, this only has an effect when called inside
        an inference context.

        This rebinds ``self.distribution`` to a new module built with
        ``eqx.combine``. The flowjax object originally passed in is not modified.
        Construct the wrapper inside the guide (as a local object) to keep the
        mutation contained.

        Returns:
            ``self``, to allow chaining.
        """
        params, static = eqx.partition(self.distribution, filter_spec)
        params = jax.tree_util.tree_map_with_path(
            lambda key_path, x: numpyro.param(jax.tree_util.keystr(key_path), x),
            params,
        )
        self.distribution = eqx.combine(params, static)
        return self

    def sample_with_intermediates(self, key, sample_shape=()):
        """Sample and return ``(samples, intermediates)``.

        ``intermediates`` contains one base sample per nested
        ``flowjax.distributions.Transformed`` layer, outermost first. It is empty
        for non-transformed distributions. Passing it back to ``log_prob``
        avoids inverting the bijections.
        """
        return _sample_with_intermediates(self.distribution, key, sample_shape)
def update_parameters_from_key_path_string_dict(
    dist, param_dict, filter_spec=eqx.is_inexact_array
):
    """Return ``dist`` with its parameters replaced by values from a numpyro dict.

    Leaves selected by ``filter_spec`` are looked up in ``param_dict`` by their
    JAX key path string (``jax.tree_util.keystr``). This is the naming scheme
    used by ``FlowjaxToNumpyroDistribution.register_params``. All other leaves
    are kept unchanged.

    Args:
        dist: The flowjax distribution (an Equinox module) to update.
        param_dict: Mapping from key path strings to new values, e.g.
            ``svi_result.params``.
        filter_spec: Equinox filter selecting which leaves are parameters.

    Returns:
        A new module combining the looked-up parameters with the static leaves.

    Raises:
        KeyError: If a selected leaf has no entry in ``param_dict``.
        ValueError: If a new value's shape differs from the leaf it replaces.
    """
    params, static = eqx.partition(dist, filter_spec)

    def replace_leaf(key_path, leaf):
        name = jax.tree_util.keystr(key_path)
        try:
            new_value = param_dict[name]
        except KeyError:
            raise KeyError(
                f"No parameter named {name!r} in param_dict; "
                f"available names: {sorted(param_dict)}"
            ) from None
        # jnp.shape also handles Python scalars, which have no ``.shape``.
        expected_shape, actual_shape = jnp.shape(leaf), jnp.shape(new_value)
        if actual_shape != expected_shape:
            raise ValueError(
                f"Expected parameter shape {expected_shape}, got {actual_shape}"
            )
        return new_value

    params = jax.tree_util.tree_map_with_path(replace_leaf, params)
    return eqx.combine(params, static)
if __name__ == "__main__":
    # Smoke tests: NUTS and SVI against a known 2d normal target, N(1, 2).
    from functools import partial

    import pytest
    from numpyro import sample
    from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO

    true_mean = jnp.ones(2)
    true_std = 2 * jnp.ones(2)

    def numpyro_model():
        """Model with a single site ``x`` drawn from N(true_mean, true_std)."""
        target = flowjax.distributions.Normal(true_mean, true_std)
        sample("x", FlowjaxToNumpyroDistribution(target))

    # MCMC: the posterior samples should recover the target's mean and std.
    _, mcmc_key = jr.split(jr.PRNGKey(0))
    mcmc = MCMC(NUTS(numpyro_model), num_warmup=100, num_samples=2000)
    mcmc.run(mcmc_key)
    x_draws = mcmc.get_samples()["x"]
    assert pytest.approx(x_draws.mean(axis=0), abs=0.1) == true_mean
    assert pytest.approx(x_draws.std(axis=0), abs=0.1) == true_std
    print("MCMC test passed")

    # VI: fit a N(0, 1) guide whose parameters are registered with numpyro
    # by their key paths, then read the fitted values back into flowjax.
    def guide(flowjax_dist):
        wrapped = FlowjaxToNumpyroDistribution(flowjax_dist)
        wrapped.register_params(eqx.is_inexact_array)
        sample("x", wrapped)

    guide_dist = flowjax.distributions.Normal(jnp.zeros(2), 1)
    svi = SVI(
        numpyro_model,
        partial(guide, guide_dist),
        numpyro.optim.Adam(step_size=0.001),
        loss=Trace_ELBO(),
    )
    svi_result = svi.run(jr.PRNGKey(0), num_steps=10000)
    fitted = update_parameters_from_key_path_string_dict(guide_dist, svi_result.params)
    assert pytest.approx(fitted.loc, abs=0.1) == true_mean
    assert pytest.approx(fitted.scale, abs=0.1) == true_std
    print("VI test passed")
This seems to work (the assert statements pass). However, I would appreciate some advice on the following:
- Does the overall approach look reasonable? The `register_params` method seems very un-JAX-like, since it mutates the distribution; is there a better way? Because of the complexity of the distributions (flows), I do not want to name the parameters manually.
- The `sample_with_intermediates` method, and `log_prob` with intermediates, that I have currently implemented seem problematic. In flowjax, I have a `sample_and_log_prob` method for all distributions, and I override it for transformed distributions with a generally more efficient approach that avoids computing the inverse (https://github.com/danielward27/flowjax/blob/ed600296bfcb29993c8a2e15d1360be93f17a442/flowjax/distributions.py#L68). This approach naturally propagates through nested transformed distributions (i.e. transformed transformed distributions), because each call to `sample_and_log_prob` internally calls `sample_and_log_prob` on the (possibly transformed) base distribution. What would the equivalent approach be here? Should `sample_with_intermediates` return a list whose length equals the number of nested `Transformed` distributions, which can then be indexed and used in `log_prob` with intermediates?
Thanks for any input!