Is there a way to use a pre-trained PyTorch model within a NumPyro model, and what is the right way to handle the type conversions? The parameters of the PyTorch model are fixed and do not need to be inferred; they only define a deterministic transformation.
(To answer the obvious question of why I am not going down the PyTorch + Pyro or JAX + NumPyro route: HMC with PyTorch + Pyro is incredibly slow, I found JAX hard to work with, and I need access to the PyTorch-based ecosystem.)
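For context on the type-conversion part of the question: the conversion chain used in the minimal example below (`np.asarray` followed by `torch.from_numpy`) works on concrete JAX arrays. The sketch below is illustrative only (arbitrary values; `uses_numpy` is a hypothetical helper, not from the original post). It shows the round trip JAX → NumPy → PyTorch → NumPy → JAX succeeding on a concrete array, and the same `np.asarray` call raising `TracerArrayConversionError` once the function is differentiated with `jax.grad`, which NUTS does internally when it computes gradients of the potential energy.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch

# Concrete (non-traced) JAX array: the conversion chain works.
z = jnp.array([0.5, -1.0, 2.0, 0.0], dtype=jnp.float32)
# np.array makes a writable copy; torch.from_numpy warns on read-only arrays.
z_torch = torch.from_numpy(np.array(z))
back = jnp.asarray(z_torch.numpy())  # PyTorch -> NumPy -> JAX
print(z_torch.dtype, back.dtype)


def uses_numpy(x):
    # Hypothetical helper: under jax.grad, x is a tracer with no concrete
    # value, so np.asarray(x) cannot produce a NumPy array.
    return jnp.sum(jnp.asarray(np.asarray(x)))


try:
    jax.grad(uses_numpy)(z)
    conversion_failed = False
except jax.errors.TracerArrayConversionError:
    conversion_failed = True
print(conversion_failed)
```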
The minimal example:
```python
import torch
import torch.nn as nn
import jax.numpy as jnp
from jax import random
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# PyTorch model
model = nn.Linear(4, 4)

# Numpyro model
def numpyro_model(z_dim=4):
    z = numpyro.sample("z", dist.Normal(jnp.zeros(z_dim), jnp.ones(z_dim)))
    z_np = np.asarray(z)
    z_torch = torch.from_numpy(z_np)
    # deterministic transformation, defined by the PyTorch model, applied to a random variable, defined within the Numpyro model
    with torch.no_grad():
        f = numpyro.deterministic("f", model(z_torch))
    return f

# Inference
kernel = NUTS(numpyro_model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(rng_key=random.PRNGKey(2), z_dim=4)
```
The resulting error:
```
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-32-adbd57b6da87> in <module>
30 mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
---> 31 mcmc.run(rng_key=random.PRNGKey(2), z_dim=4)
43 frames
UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([-0.71174574 -1.1693654 0.19383144 -0.91461945], dtype=float32)>with<JVPTrace(level=2/0)> with
primal = DeviceArray([-0.71174574, -1.1693654 , 0.19383144, -0.91461945], dtype=float32)
tangent = Traced<ShapedArray(float32[4])>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[4]), None)
recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TracerArrayConversionError Traceback (most recent call last)
<ipython-input-32-adbd57b6da87> in numpyro_model(z_dim)
18
19 z = numpyro.sample("z", dist.Normal(jnp.zeros(z_dim), jnp.ones(z_dim)))
---> 20 z_np = np.asarray(z)
21 z_torch = torch.from_numpy(z_np)
22
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([-0.71174574 -1.1693654 0.19383144 -0.91461945], dtype=float32)>with<JVPTrace(level=2/0)> with
primal = DeviceArray([-0.71174574, -1.1693654 , 0.19383144, -0.91461945], dtype=float32)
tangent = Traced<ShapedArray(float32[4])>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[4]), None)
recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```