Interfacing with numpyro

Hello,

I have been working on a jax package for equinox-based distributions and normalising flows (GitHub - danielward27/flowjax). I am looking to be able to wrap the distributions to be compatible with numpyro to use them in MCMC and variational inference.

So far I have the following draft which seems to work for unconditional distributions:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

import flowjax
from flowjax.utils import _get_ufunc_signature  # used to vectorize bijection methods

import numpyro

from numpyro.distributions import constraints


class FlowjaxToNumpyroDistribution(numpyro.distributions.Distribution):
    """Wraps a flowjax distribution instance into a numpyro distribution.

    The wrapped distribution is treated as having an empty batch shape and an
    event shape equal to ``distribution.shape``.
    """

    def __init__(
        self,
        distribution: flowjax.distributions.Distribution,
        support: constraints.Constraint = constraints.real,
    ):
        """
        Args:
            distribution (flowjax.distributions.Distribution): The flowjax
                distribution to wrap.
            support (constraints.Constraint, optional): Numpyro constraint.
                Defaults to constraints.real.
        """
        self.distribution = distribution
        self.support = support
        super().__init__(batch_shape=(), event_shape=distribution.shape)

    def log_prob(self, value, intermediates=None):
        """Evaluate the log density at ``value``.

        Args:
            value: Point(s) at which to evaluate the log density.
            intermediates: Optional list as returned by
                ``sample_with_intermediates``. When non-empty, its first
                element is taken to be the base distribution samples and the
                density is computed from the forward transformation, so the
                bijection inverse is not needed.

        Raises:
            ValueError: If non-empty intermediates are given but the wrapped
                distribution is not a ``flowjax.distributions.Transformed``.
        """
        # ``sample_with_intermediates`` returns an empty list for distributions
        # that are not transformed, so an empty list must be handled like None.
        if intermediates is None or len(intermediates) == 0:
            return self.distribution.log_prob(value)

        if not isinstance(self.distribution, flowjax.distributions.Transformed):
            raise ValueError(
                "Intermediates can only be used with transformed distributions."
            )

        base_samples = intermediates[0]
        log_prob_base = self.distribution.base_dist.log_prob(base_samples)

        # Vectorize the bijection over any leading sample dimensions.
        signature = _get_ufunc_signature(
            [self.distribution.shape], [self.distribution.shape, ()]
        )
        _, forward_log_det = jnp.vectorize(
            self.distribution.bijection.transform_and_log_det, signature=signature
        )(base_samples)
        # Change of variables: log p(y) = log p_base(x) - log|det J(x)|.
        return log_prob_base - forward_log_det

    def sample(self, key, sample_shape=()):
        """Draw samples from the wrapped distribution.

        ``sample_shape`` defaults to ``()`` (a single draw), matching the
        numpyro ``Distribution.sample`` convention.
        """
        return self.distribution.sample(key, sample_shape)

    def register_params(self, filter_spec=eqx.is_inexact_array):
        """Register numpyro params specified by filter_spec using jax key path strings as
        names, see jax.tree_util.keystr. This, like ``numpyro.param``, needs to be
        called from an inference context to have an effect.

        Note that this replaces ``self.distribution`` in place and returns
        ``self`` so that calls can be chained.
        """
        params, static = eqx.partition(self.distribution, filter_spec)
        params = jax.tree_util.tree_map_with_path(
            lambda key_path, x: numpyro.param(jax.tree_util.keystr(key_path), x),
            params,
        )
        self.distribution = eqx.combine(params, static)
        return self

    def sample_with_intermediates(self, key, sample_shape=()):
        """Sample and also return the values needed to score the sample cheaply.

        Returns:
            A tuple ``(samples, intermediates)``. For a transformed distribution
            ``intermediates`` is ``[base_samples]``; otherwise it is an empty
            list.
        """
        if not isinstance(self.distribution, flowjax.distributions.Transformed):
            return self.sample(key, sample_shape), []

        base_samples = self.distribution.base_dist.sample(key, sample_shape)
        signature = _get_ufunc_signature(
            [self.distribution.shape], [self.distribution.shape]
        )
        transformed_samples = jnp.vectorize(
            self.distribution.bijection.transform, signature=signature
        )(base_samples)
        return transformed_samples, [base_samples]


def update_parameters_from_key_path_string_dict(
    dist, param_dict, filter_spec=eqx.is_inexact_array
):
    """Return a copy of ``dist`` with its parameters taken from ``param_dict``.

    The leaves of ``dist`` selected by ``filter_spec`` are replaced with the
    entries of ``param_dict`` keyed by their jax key path string (see
    jax.tree_util.keystr), e.g. a parameter dictionary produced by numpyro.

    Raises:
        ValueError: If a replacement's shape differs from the existing leaf's.
    """

    def _lookup(path, current):
        replacement = param_dict[jax.tree_util.keystr(path)]
        if replacement.shape == current.shape:
            return replacement
        raise ValueError(
            f"Expected parameter shape {current.shape}, got {replacement.shape}"
        )

    trainable, frozen = eqx.partition(dist, filter_spec)
    replaced = jax.tree_util.tree_map_with_path(_lookup, trainable)
    return eqx.combine(replaced, frozen)


if __name__ == "__main__":
    # End-to-end smoke test: recover a 2d N(1, 2) target with NUTS, then fit a
    # flowjax Normal guide to the same target with SVI.
    from functools import partial
    import pytest
    from numpyro import sample
    from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO

    true_mean = jnp.ones(2)
    true_std = 2 * jnp.ones(2)

    def numpyro_model():
        target = FlowjaxToNumpyroDistribution(
            flowjax.distributions.Normal(true_mean, true_std)
        )
        sample("x", target)

    # MCMC test
    key, mcmc_key = jr.split(jr.PRNGKey(0))
    sampler = MCMC(NUTS(numpyro_model), num_warmup=100, num_samples=2000)
    sampler.run(mcmc_key)

    draws = sampler.get_samples()["x"]
    assert pytest.approx(draws.mean(axis=0), abs=0.1) == true_mean
    assert pytest.approx(draws.std(axis=0), abs=0.1) == true_std
    print("MCMC test passed")

    # VI test
    def guide(dist):
        wrapped = FlowjaxToNumpyroDistribution(dist)
        wrapped.register_params(eqx.is_inexact_array)
        sample("x", wrapped)

    # Initial guide: 2d N(0, 1)
    guide_dist = flowjax.distributions.Normal(jnp.zeros(2), 1)

    svi = SVI(
        numpyro_model,
        partial(guide, guide_dist),
        numpyro.optim.Adam(step_size=0.001),
        loss=Trace_ELBO(),
    )
    svi_result = svi.run(jr.PRNGKey(0), num_steps=10000)

    # Write the optimised parameters back into the flowjax distribution.
    fitted = update_parameters_from_key_path_string_dict(
        guide_dist, svi_result.params
    )
    assert pytest.approx(fitted.loc, abs=0.1) == true_mean
    assert pytest.approx(fitted.scale, abs=0.1) == true_std
    print("VI test passed")

This seems to work (the assert statements pass). However, I would appreciate some advice on the following:

  1. Does the overall approach look reasonable? The register_params method seems very unjax-like, as it mutates the distribution; maybe there is a better way? Because of the complexity of the distributions (flows), I do not want to manually name the parameters.
  2. The sample_with_intermediates and log_prob with intermediates that I have implemented currently seem problematic. In flowjax, I have a sample_and_log_prob method for all distributions, and override it with the generally more efficient approach for transformed distributions, avoiding the inverse computation (https://github.com/danielward27/flowjax/blob/ed600296bfcb29993c8a2e15d1360be93f17a442/flowjax/distributions.py#L68). This approach naturally propagates through nested transformed distributions (i.e. transformed transformed distributions), as each call to sample_and_log_prob internally calls sample_and_log_prob on the (possibly transformed) base distribution. What would be the equivalent approach here? Have sample_with_intermediates return a list whose length is the number of nested Transformed distributions, which can then be indexed and used in log_prob with intermediates?

Thanks for any input!

Ah I see, in numpyro nested applications of transformed distributions just concatenate the bijections, so this issue wouldn’t arise: https://github.com/pyro-ppl/numpyro/blob/bcf38d7e2ca97fded876cf034866eaa0f00385ff/numpyro/distributions/distribution.py#L974C8-L976C72

I can consider doing something similar.

edit: after a bit more thought, even with the above approach of enforcing no nested transforms, this is a bit tricky. e.g. what if a numpyro transform is applied to my custom distribution? It won’t be recognised as a transformed distribution, and I think the same issue arises?

So because of my concern about how reparameterisations / sample_with_intermediates for flowjax transformed distributions would function if transformed by a numpyro transform, I instead made a wrapper for the flowjax bijections to convert them into numpyro transforms, so that I can use them directly within numpyro's TransformedDistribution and reparameterisation should be handled better. I think I am slowly getting there; here is the updated draft:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

import flowjax

import numpyro
from numpyro.distributions import constraints


class BijectionToNumpyro(numpyro.distributions.transforms.Transform):
    """Wraps a flowjax bijection instance into a numpyro transform.

    The forward pass, inverse and log-determinant are delegated to the wrapped
    bijection's ``transform``, ``inverse`` and ``transform_and_log_det``.
    """

    def __init__(
        self,
        bijection: flowjax.bijections.Bijection,
        domain: constraints.Constraint = constraints.real,
    ):
        """
        Args:
            bijection (flowjax.bijections.Bijection): The flowjax bijection to
                wrap.
            domain (constraints.Constraint, optional): Numpyro constraint for
                the transform's input. Defaults to constraints.real.
        """
        # NOTE(review): only ``domain`` is set here; the codomain is whatever
        # the numpyro base class provides - confirm this is intended.
        self.bijection = bijection
        self.domain = domain

    def __call__(self, x):
        return self.bijection.transform(x)

    def _inverse(self, y):
        return self.bijection.inverse(y)

    def call_with_intermediates(self, x):
        """Return ``(y, log_det)`` from a single forward pass.

        The log-determinant is handed back as the intermediates so that
        ``log_abs_det_jacobian`` can reuse it instead of recomputing the
        transformation.
        """
        y, log_det = self.bijection.transform_and_log_det(x)
        return y, log_det

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """Log absolute determinant of the Jacobian of the forward transform.

        Args:
            x: Input to the forward transformation.
            y: Output of the forward transformation (unused).
            intermediates: The log-determinant produced by
                ``call_with_intermediates``, if available.
        """
        if intermediates is not None:
            # Already computed alongside the forward pass.
            return intermediates
        return self.bijection.transform_and_log_det(x)[1]

    def tree_flatten(self):
        raise NotImplementedError()


def distribution_to_numpyro(dist: flowjax.distributions.Distribution):
    """Wrap a flowjax distribution in the matching numpyro wrapper class.

    Raises:
        ValueError: If ``dist`` is not a flowjax distribution.
    """
    # Order matters: the transformed case is checked before the general one.
    wrappers = (
        (flowjax.distributions.Transformed, _TransformedDistributionToNumpyro),
        (flowjax.distributions.Distribution, _DistributionToNumpyro),
    )
    for flowjax_type, wrapper in wrappers:
        if isinstance(dist, flowjax_type):
            return wrapper(dist)
    raise ValueError(f"Expected flowjax distribution, got {type(dist)}")


class _DistributionToNumpyro(numpyro.distributions.Distribution):
    """Wraps a flowjax distribution instance into a numpyro distribution."""

    def __init__(
        self,
        distribution: flowjax.distributions.Distribution,
        support: constraints.Constraint = constraints.real,
    ):
        """
        Args:
            distribution (flowjax.distributions.Distribution): A flowjax distribution
                that is not transformed.
            support (constraints.Constraint, optional): Numpyro constraint.
                Defaults to constraints.real.

        Raises:
            ValueError: If ``distribution`` is a flowjax ``Transformed``
                distribution.
        """
        if isinstance(distribution, flowjax.distributions.Transformed):
            # Point at the wrapper that actually exists in this module.
            raise ValueError(
                "Use _TransformedDistributionToNumpyro (or distribution_to_numpyro) "
                "for transformed distributions."
            )
        self.distribution = distribution
        self.support = support
        super().__init__(batch_shape=(), event_shape=distribution.shape)

    def log_prob(self, value):
        """Evaluate the wrapped distribution's log density at ``value``."""
        return self.distribution.log_prob(value)

    def sample(self, key, sample_shape=()):
        """Draw samples from the wrapped distribution.

        ``sample_shape`` defaults to ``()`` (a single draw), matching the
        numpyro ``Distribution.sample`` convention.
        """
        return self.distribution.sample(key, sample_shape)


class _TransformedDistributionToNumpyro(numpyro.distributions.TransformedDistribution):
    """Numpyro transformed distribution built from a flowjax ``Transformed``.

    The flowjax base distribution and bijection are extracted, each is wrapped
    in its numpyro equivalent, and the pair is passed to numpyro's
    ``TransformedDistribution``."""

    def __init__(self, dist: flowjax.distributions.Transformed):
        # TODO allow control of support/domains?
        # Ensure base distribution is not transformed
        flattened = dist.unnest_transforms()
        super().__init__(
            base_distribution=_DistributionToNumpyro(flattened.base_dist),
            transforms=[BijectionToNumpyro(flattened.bijection)],
        )


def register_params(model: eqx.Module, filter_spec=eqx.is_inexact_array):
    """Return ``model`` with the leaves chosen by filter_spec registered as numpyro
    params. Each param is named by its jax key path string, see
    jax.tree_util.keystr. This, like ``numpyro.param``, needs to be called from
    an inference context to have an effect."""

    def _as_numpyro_param(path, leaf):
        return numpyro.param(jax.tree_util.keystr(path), leaf)

    trainable, frozen = eqx.partition(model, filter_spec)
    registered = jax.tree_util.tree_map_with_path(_as_numpyro_param, trainable)
    return eqx.combine(registered, frozen)


def update_parameters_from_key_path_string_dict(
    model, param_dict, filter_spec=eqx.is_inexact_array
):
    """Return a copy of ``model`` with its parameters taken from ``param_dict``.

    The leaves of ``model`` selected by ``filter_spec`` are replaced with the
    entries of ``param_dict`` keyed by their jax key path string (see
    jax.tree_util.keystr), e.g. a parameter dictionary produced by numpyro.

    Raises:
        ValueError: If a replacement's shape differs from the existing leaf's.
    """

    def _lookup(path, current):
        replacement = param_dict[jax.tree_util.keystr(path)]
        if replacement.shape == current.shape:
            return replacement
        raise ValueError(
            f"Expected parameter shape {current.shape}, got {replacement.shape}"
        )

    trainable, frozen = eqx.partition(model, filter_spec)
    replaced = jax.tree_util.tree_map_with_path(_lookup, trainable)
    return eqx.combine(replaced, frozen)


if __name__ == "__main__":
    # Run e.g. with python -m flowjax.numpyro from project root directory
    #
    # End-to-end smoke test: recover a 2d N(1, 2) target with NUTS, then fit a
    # flowjax Normal guide to the same target with SVI.
    from functools import partial

    import pytest
    from numpyro import sample
    from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO

    true_mean = jnp.ones(2)
    true_std = 2 * jnp.ones(2)

    def numpyro_model():
        target = distribution_to_numpyro(
            flowjax.distributions.Normal(true_mean, true_std)
        )
        sample("x", target)

    # MCMC test
    key, mcmc_key = jr.split(jr.PRNGKey(0))
    sampler = MCMC(NUTS(numpyro_model), num_warmup=100, num_samples=2000)
    sampler.run(mcmc_key)

    draws = sampler.get_samples()["x"]
    assert pytest.approx(draws.mean(axis=0), abs=0.1) == true_mean
    assert pytest.approx(draws.std(axis=0), abs=0.1) == true_std
    print("MCMC test passed")

    # VI test
    def guide(dist):
        # Register the trainable leaves first, then wrap for numpyro.
        sample("x", distribution_to_numpyro(register_params(dist)))

    # Initial guide: 2d N(0, 1)
    guide_dist = flowjax.distributions.Normal(jnp.zeros(2), 1)

    svi = SVI(
        numpyro_model,
        partial(guide, guide_dist),
        numpyro.optim.Adam(step_size=0.001),
        loss=Trace_ELBO(),
    )
    svi_result = svi.run(jr.PRNGKey(0), num_steps=10000)

    # Write the optimised parameters back into the flowjax distribution.
    fitted = update_parameters_from_key_path_string_dict(
        guide_dist, svi_result.params
    )
    assert pytest.approx(fitted.loc, abs=0.1) == true_mean
    assert pytest.approx(fitted.scale, abs=0.1) == true_std
    print("VI test passed")

This is allowing me to use flows without an analytic inverse as the variational distribution, so I presume this must mean that, internally, numpyro is correctly utilising the intermediates for computing the base distribution log probability.

With regards to registering parameters, it seems I can fortunately just do

# Split the model into the leaves selected by filter_spec and everything else.
params, static = eqx.partition(model, filter_spec)
# Register the whole parameter pytree under a single numpyro param name,
# rather than one param per leaf.
params = numpyro.param(name, params)
# Rebuild the model from the registered parameters and the untouched remainder.
model = eqx.combine(params, static)

The documentation for param says init_value should be a jnp.ndarray or a callable, but a PyTree with jnp.ndarray leaves seems to work too.