How to use a pre-trained PyTorch model within NumPyro?

Is there a way to use a pre-trained PyTorch model within a NumPyro model, and what is the right way to handle the type conversions? The parameters of the PyTorch model are fixed and do not need to be inferred; they only define a deterministic transformation.

(To answer the question of why I am not going down the route of PyTorch + Pyro or pure JAX + NumPyro: HMC with PyTorch + Pyro is incredibly slow; I find JAX hard to work with; and I need access to the PyTorch-based ecosystem.)

A minimal example:

import torch
import torch.nn as nn

import jax.numpy as jnp
from jax import random

import numpy as np

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# PyTorch model
model = nn.Linear(4, 4)

# Numpyro model
def numpyro_model(z_dim=4):

    z = numpyro.sample("z", dist.Normal(jnp.zeros(z_dim), jnp.ones(z_dim)))
    z_np = np.asarray(z)
    z_torch = torch.from_numpy(z_np)

    # deterministic transformation defined by the PyTorch model, applied to a random variable defined within the NumPyro model
    with torch.no_grad():
        f = numpyro.deterministic("f", model(z_torch))

    return f

# Inference
kernel = NUTS(numpyro_model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(rng_key=random.PRNGKey(2), z_dim=4)

The resulting error:

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-32-adbd57b6da87> in <module>
     30 mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
---> 31 mcmc.run(rng_key=random.PRNGKey(2), z_dim=4)

(43 frames omitted)
UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([-0.71174574 -1.1693654   0.19383144 -0.91461945], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray([-0.71174574, -1.1693654 ,  0.19383144, -0.91461945], dtype=float32)
  tangent = Traced<ShapedArray(float32[4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[4]), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TracerArrayConversionError                Traceback (most recent call last)
<ipython-input-32-adbd57b6da87> in numpyro_model(z_dim)
     18 
     19     z = numpyro.sample("z", dist.Normal(jnp.zeros(z_dim), jnp.ones(z_dim)))
---> 20     z_np = np.asarray(z)
     21     z_torch = torch.from_numpy(z_np)
     22 

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([-0.71174574 -1.1693654   0.19383144 -0.91461945], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray([-0.71174574, -1.1693654 ,  0.19383144, -0.91461945], dtype=float32)
  tangent = Traced<ShapedArray(float32[4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[4]), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

If you don’t need grad then you can use host_callback.

But I feel that you can just simply do inference without pytorch stuff, collect samples, then transform samples using your pytorch model.
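
For concreteness, a minimal sketch of that idea, reusing model = nn.Linear(4, 4) from the snippet above (prior_only_model is a hypothetical name):

def prior_only_model(z_dim=4):
    numpyro.sample("z", dist.Normal(jnp.zeros(z_dim), jnp.ones(z_dim)))

mcmc = MCMC(NUTS(prior_only_model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(2), z_dim=4)
z_samples = mcmc.get_samples()["z"]  # shape (num_samples, z_dim)

# apply the fixed PyTorch transformation to the collected samples, outside of JAX tracing
with torch.no_grad():
    f_samples = model(torch.from_numpy(np.asarray(z_samples)))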

@fehiepsi Thanks for looking into this. I followed your idea but have not managed to resolve the issue yet - see below.

But I feel that you can just simply do inference without pytorch stuff, collect samples, then transform samples using your pytorch model.

Please see the amended code below - I do need a likelihood in the model, and it relies on the transformed random variable. The transformation is performed by the PyTorch model:

# Numpyro model
def numpyro_model(y, z_dim=4):
    z = numpyro.sample("z", dist.Normal(jnp.zeros(z_dim), jnp.ones(z_dim)))
    z_np = np.asarray(z)
    z_torch = torch.from_numpy(z_np)
    with torch.no_grad():
        f = numpyro.deterministic("f", model(z_torch))    

    # likelihood
    y = numpyro.sample("y", dist.Normal(f, 0.1), obs=y)

# Inference
kernel = NUTS(numpyro_model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(rng_key=random.PRNGKey(2), z_dim=4, y=jnp.array([5., 4., 3., 2.]))

If you don’t need grad then you can use host_callback.

I have tried this option, but I am getting the error TypeError: 'module' object is not callable. The full code is below. Any hints?

import torch
import torch.nn as nn

import jax
from jax.experimental import host_callback
import jax.numpy as jnp
from jax import random

import numpy as np

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def z_to_f(z):
    z_np = np.asarray(z)
    z_torch = torch.from_numpy(z_np)
    with torch.no_grad():
        f = model(z_torch)
    return jnp.asarray(f)

# Numpyro model
def numpyro_model(y, z_dim=4):
    z = numpyro.sample("z", dist.Normal(jnp.zeros(z_dim), jnp.ones(z_dim)))
    f = jax.experimental.host_callback(z_to_f, z)   
    y = numpyro.sample("y", dist.Normal(f, 0.1), obs=y)

# Inference
kernel = NUTS(numpyro_model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(rng_key=random.PRNGKey(2), z_dim=4, y=jnp.array([5., 4., 3., 2.]))

I’m not sure what causes the error, but I think you need grad if you use NUTS. We have some samplers that do not require grad, like the SA sampler.
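
For example, switching kernels is a one-line change (a minimal sketch; whether SA mixes well here is a separate question):

from numpyro.infer import SA

kernel = SA(numpyro_model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)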

Btw, for likelihood, you might want to use sample(..., obs=y)

Btw, for likelihood, you might want to use sample(..., obs=y)

Sure, thanks, my mistake - I have updated the code above.

I’m not sure what causes the error, but I think you need grad if you use NUTS. We have some samplers that do not require grad, like the SA sampler.

I have tried SA and am getting the same error :frowning:

Did you test whether

f = jax.experimental.host_callback(z_to_f, z)

works under jit, like running

f = jax.jit(lambda x: jax.experimental.host_callback(z_to_f, x))(z)

outside of the NumPyro stuff?

Both versions produce TypeError: 'module' object is not callable.

Here’s the code which I tested:

import jax
from jax.experimental import host_callback
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(4, 4)

def z_to_f(z):
    z_np = np.asarray(z)
    z_torch = torch.from_numpy(z_np)
    with torch.no_grad():
        f = model(z_torch)
    return jnp.asarray(f)

z = jnp.array([5., 4., 3., 2.])
f = jax.experimental.host_callback(z_to_f, z)  # this line raises TypeError: 'module' object is not callable

I think you need host_callback.call or something - host_callback is a module, which is why you get that error. Maybe jax.pure_callback is easier to use for you?
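
For example, a sketch of routing z_to_f through jax.pure_callback (assuming z is a length-4 float32 vector):

result_shape = jax.ShapeDtypeStruct((4,), jnp.float32)
f = jax.pure_callback(z_to_f, result_shape, z)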


You were right - I was missing call, as well as the result_shape parameter. Below is the working version of the host_callback example. However, it still does not work within NumPyro.

Working host_callback example:

import jax
from jax.experimental import host_callback
import jax.numpy as jnp
import torch
import torch.nn as nn
import numpy as np

def z_to_f(z):
    # NB: constructing the model here re-initializes its weights on every call;
    # a fixed pre-trained model should be created once, outside the callback
    model = nn.Linear(4, 4)
    z_np = np.asarray(z)
    z_torch = torch.from_numpy(z_np)
    with torch.no_grad():
        f = model(z_torch)
    return jnp.asarray(f)

z = jnp.array([5.,4.,3.,2.])
result_shape = jax.ShapeDtypeStruct(z.shape, jnp.result_type(float))
f = jax.experimental.host_callback.call(z_to_f, z, result_shape=result_shape)

The NumPyro call, however, still breaks, with the error
NotImplementedError: JVP rule is implemented only for id_tap, not for call:

def numpyro_model(y, z_dim=4):
    z = numpyro.sample("z", dist.Normal(jnp.zeros(z_dim), jnp.ones(z_dim)))
    # note: torch.no_grad() has no effect here - z_to_f runs on the host later and has its own no_grad
    with torch.no_grad():
        result_shape = jax.ShapeDtypeStruct(y.shape, jnp.result_type(float))
        f = numpyro.deterministic("f", jax.experimental.host_callback.call(z_to_f, z, result_shape=result_shape))
    y = numpyro.sample("y", dist.Normal(f, 0.1), obs=y)

# Inference
kernel = NUTS(numpyro_model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(rng_key=random.PRNGKey(2), z_dim=4, y=jnp.array([5., 4., 3., 2.]))
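
The error arises because host_callback.call has no JVP rule, while NUTS needs to differentiate through the model. One possible direction - an untested sketch, with hypothetical helper names torch_forward, torch_tangent and torch_f, reusing the imports and the fixed model from the snippets above - is to wrap the callback in jax.custom_jvp and let PyTorch compute the tangent as well, via torch.autograd.functional.jvp:

def torch_forward(z_np):
    # primal evaluation of the fixed PyTorch model on the host
    with torch.no_grad():
        return np.asarray(model(torch.from_numpy(np.asarray(z_np))))

def torch_tangent(z_np, dz_np):
    # tangent (JVP) of the fixed PyTorch model, computed by torch.autograd
    z_t = torch.from_numpy(np.asarray(z_np))
    dz_t = torch.from_numpy(np.asarray(dz_np))
    _, tangent = torch.autograd.functional.jvp(model, (z_t,), (dz_t,))
    return np.asarray(tangent)

@jax.custom_jvp
def torch_f(z):
    shape = jax.ShapeDtypeStruct(z.shape, z.dtype)
    return jax.pure_callback(torch_forward, shape, z)

@torch_f.defjvp
def torch_f_jvp(primals, tangents):
    z, = primals
    dz, = tangents
    shape = jax.ShapeDtypeStruct(z.shape, z.dtype)
    # primal output from the forward callback, tangent from PyTorch's jvp
    return torch_f(z), jax.pure_callback(torch_tangent, shape, z, dz)

With that, f = numpyro.deterministic("f", torch_f(z)) inside the model should be differentiable end to end.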

@fehiepsi

1). Regarding jax.pure_callback: I am unable to make it work for the scipy example. What am I missing?

Working example:

x = jnp.arange(5.0)
result_shape = jax.core.ShapedArray(x.shape, x.dtype)
jax.pure_callback(np.sin, result_shape, x)

Failing example (functions and arguments as in my previous reply; a possible fix is sketched after this post):

result_shape = jax.core.ShapedArray(u.shape, u.dtype)
jax.pure_callback(scipy_truncated_poisson_icdf, result_shape, args)

2). Based on my reply above concerning jax.experimental.host_callback.call, is it worth a pull request?
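
(On question 1, a hedged guess: jax.pure_callback takes the callback's arguments unpacked rather than as a single tuple, so if args above is a tuple of arrays, the call would need to splat it:)

jax.pure_callback(scipy_truncated_poisson_icdf, result_shape, *args)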