Hi all,
Am trying to set up a model in numpyro I keep getting trace conversion errors. I thought it was due to NA/Inf/-Inf values but it seems to persist. Tried changing all my numpy calls to jax.numpy.array() but that doesn’t fix it either.
I suspect I’m doing something rather dumb but have no idea where to start.
For context, X is a dataframe where all of my inputs are dummy variables except for the y-variable below
import numpy as np
import pandas as pd
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random
import jax
df = pd.read_csv('modelling.data1b.csv')
# growth + b1*phase1 + b2*phase2 + b3*abx1 + b4*abx2 + b5 + (1 | prophylactic) + (1 | patient) + b6*anaerobes.
def hierarchical_regression(X, ylab= 'log.delta.anaerobes'):
X = X.dropna()
prophylactic = X['prophylaxis'].tolist()
# patient = X['timeline_id'].tolist()
patient = X['modelling.id'].tolist()
y = X[ylab]
num_prophylactics = np.unique(prophylactic).shape[0]
num_patients = np.unique(patient).shape[0]
# Priors
# a = numpyro.sample("a", dist.Normal(0.0, 0.2))
# mu = numpyro.sample('mu', dist.Normal(0, 5))
b1 = numpyro.sample('b1', dist.Normal(0,100)) # phase 1
b2 = numpyro.sample('b2', dist.Normal(0,100)) # phase 2
b3 = numpyro.sample('b3', dist.Normal(0,100)) # pip
b4 = numpyro.sample('b4', dist.Normal(0,100)) # mero
b5 = numpyro.sample('b5', dist.Normal(0,100)) # metro
b6 = numpyro.sample('b6', dist.Normal(0,100)) # ceph 1-3
b7 = numpyro.sample('b7', dist.Normal(0,100)) # vanco po
b8 = numpyro.sample('b8', dist.Normal(0,100)) # cefepime
b9 = numpyro.sample('b9', dist.Normal(0,100)) # linez
b10 = numpyro.sample('b10', dist.Normal(0,100)) # fluor
b11 = numpyro.sample('b11', dist.Normal(0,100)) # vanco iv
b12 = numpyro.sample('b12', dist.Normal(0,100)) # sulfa
b13 = numpyro.sample('b13', dist.Normal(0,100)) # atova
b15 = numpyro.sample('b15', dist.Normal(0, 0.1))
b16 = numpyro.sample('b16', dist.Normal(0, 0.1))
sigma = numpyro.sample('sigma', dist.Normal(0,.1))
mu1 = jax.numpy.zeros_like(jax.numpy.array(y.tolist()))
with numpyro.plate('prophylactic_plate', num_prophylactics):
b_prophylactic = numpyro.sample('b_prophylactic', dist.Normal(0, 0.1))
intercept_prophylactic = numpyro.sample('intercept_prophylactic', dist.Normal(0, 0.1))
# mu1 = mu1 + b_prophylactic[np.array(prophylactic)] * numpyro.one_hot(prophylactic, num_prophylactics) + intercept_prophylactic[np.array(prophylactic)]
mu1 = mu1 + intercept_prophylactic[jax.numpy.array(prophylactic)]
with numpyro.plate('patient_plate', num_patients):
b_patient = numpyro.sample('b_patient', dist.Normal(0, 0.1))
intercept_patient = numpyro.sample('intercept_patient', dist.Normal(0, 0.1))
# mu1 = mu1 + b_patient[patient] * numpyro.one_hot(patient, num_patients) + intercept_patient[patient]
mu1 = mu1 + intercept_patient[jax.numpy.array(patient)]
# Linear predictor
mu1 = intercept_prophylactic[jax.numpy.array(prophylactic)] + intercept_patient[jax.numpy.array(patient)] + jax.numpy.array(X['intrinsic.growth.rate'].values) + b2*jax.numpy.array(X['phase1'].values) + b3*jax.numpy.array(X['phase2'].values) + b4*jax.numpy.array(X['meropenem'].values) + b5*jax.numpy.array(X['metronidazole'].values) + b6*jax.numpy.array(X['cef13'].values) + b7*jax.numpy.array(X['vanco_po'].values) + b8*jax.numpy.array(X['cefepime'].values) + b9*jax.numpy.array(X['linezolid'].values) + b10*jax.numpy.array(X['fluoroquinolones'].values) + b11*jax.numpy.array(X['vanco_iv'].values) + b12*jax.numpy.array(X['sulfa_trim'].values) + b13*jax.numpy.array(X['atovaquone'].values) + b15 + b16*jax.numpy.array(X['previous.anaerobes'].values) #X[:, 4]
# Likelihood
numpyro.sample('y', dist.Normal(mu1, sigma), obs=y)
return mu1
nuts_kernel = NUTS(hierarchical_regression, target_accept_prob=0.97)
mcmc = MCMC(nuts_kernel, num_samples=3000, num_warmup=5000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, df, ylab = 'log.delta.anaerobes')
posterior_samples = mcmc.get_samples()
My error is as follows:
The above exception was the direct cause of the following exception:
TracerArrayConversionError Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/pandas/core/construction.py in sanitize_array(data, index, dtype, copy, raise_cast_failure, allow_2d)
562 if hasattr(data, "__array__"):
563 # e.g. dask array GH#38645
--> 564 data = np.asarray(data)
565 else:
566 data = list(data)
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([ -9.5061455 -10.536966 -8.111055 ... 0.15791154 -0.2573204
-0.2852156 ], dtype=float32)>with<JVPTrace(level=2/0)> with
primal = Array([ -9.5061455 , -10.536966 , -8.111055 , ..., 0.15791154,
-0.2573204 , -0.2852156 ], dtype=float32)
tangent = Traced<ShapedArray(float32[9221])>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[9221]), None)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f8110cfd610>, in_tracers=(Traced<ShapedArray(float32[9221]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f811370f590; to 'JaxprTracer' at 0x7f811370fe50>], out_avals=[ShapedArray(float32[9221])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[9221]. let b:f32[9221] = neg a in (b,) }, 'in_shardings': (<jax._src.interpreters.pxla.UnspecifiedValue object at 0x7f8138cc20a0>,), 'out_shardings': (<jax._src.interpreters.pxla.UnspecifiedValue object at 0x7f8138cc20a0>,), 'resource_env': None, 'donated_invars': (False,), 'name': '<lambda>', 'in_positional_semantics': (<_PositionalSemantics.GLOBAL: 1>,), 'out_positional_semantics': <_PositionalSemantics.GLOBAL: 1>, 'keep_unused': False, 'inline': True}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f81116954f0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Any suggestions much appreciated!