TraceConversionError, unsure how to see where issue happens

Hi all,

Am trying to set up a model in numpyro I keep getting trace conversion errors. I thought it was due to NA/Inf/-Inf values but it seems to persist. Tried changing all my numpy calls to jax.numpy.array() but that doesn’t fix it either.

I suspect I’m doing something rather dumb but have no idea where to start.

For context, X is a dataframe where all of my inputs are dummy variables except for the y-variable below

import numpy as np
import pandas as pd
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random
import jax


df = pd.read_csv('modelling.data1b.csv')

  # growth + b1*phase1 + b2*phase2 + b3*abx1 + b4*abx2 + b5 + (1 | prophylactic) + (1 | patient) + b6*anaerobes.  
def hierarchical_regression(X, ylab= 'log.delta.anaerobes'):
    X = X.dropna()
    prophylactic = X['prophylaxis'].tolist()
    # patient = X['timeline_id'].tolist()
    patient = X['modelling.id'].tolist()

    y = X[ylab]

    num_prophylactics = np.unique(prophylactic).shape[0]
    num_patients = np.unique(patient).shape[0]
    # Priors
    # a = numpyro.sample("a", dist.Normal(0.0, 0.2))
    # mu = numpyro.sample('mu', dist.Normal(0, 5))

    b1 = numpyro.sample('b1', dist.Normal(0,100)) # phase 1
    b2 = numpyro.sample('b2', dist.Normal(0,100)) # phase 2
    b3 = numpyro.sample('b3', dist.Normal(0,100)) # pip
    b4 = numpyro.sample('b4', dist.Normal(0,100)) # mero
    b5 = numpyro.sample('b5', dist.Normal(0,100)) # metro 
    b6 = numpyro.sample('b6', dist.Normal(0,100)) # ceph 1-3
    b7 = numpyro.sample('b7', dist.Normal(0,100)) # vanco po
    b8 = numpyro.sample('b8', dist.Normal(0,100)) # cefepime 
    b9 = numpyro.sample('b9', dist.Normal(0,100)) # linez

    b10 = numpyro.sample('b10', dist.Normal(0,100)) # fluor
    b11 = numpyro.sample('b11', dist.Normal(0,100)) # vanco iv
    b12 = numpyro.sample('b12', dist.Normal(0,100)) # sulfa
    b13 = numpyro.sample('b13', dist.Normal(0,100)) # atova

    b15 = numpyro.sample('b15', dist.Normal(0, 0.1)) 
    b16 = numpyro.sample('b16', dist.Normal(0, 0.1))
    sigma = numpyro.sample('sigma', dist.Normal(0,.1))
    mu1 = jax.numpy.zeros_like(jax.numpy.array(y.tolist()))

    with numpyro.plate('prophylactic_plate', num_prophylactics):
        b_prophylactic = numpyro.sample('b_prophylactic', dist.Normal(0, 0.1))
        intercept_prophylactic = numpyro.sample('intercept_prophylactic', dist.Normal(0, 0.1))
        # mu1 = mu1 + b_prophylactic[np.array(prophylactic)] * numpyro.one_hot(prophylactic, num_prophylactics) + intercept_prophylactic[np.array(prophylactic)]
        mu1 = mu1 + intercept_prophylactic[jax.numpy.array(prophylactic)]

    with numpyro.plate('patient_plate', num_patients):
        b_patient = numpyro.sample('b_patient', dist.Normal(0, 0.1))
        intercept_patient = numpyro.sample('intercept_patient', dist.Normal(0, 0.1))
        # mu1 = mu1 + b_patient[patient] * numpyro.one_hot(patient, num_patients) + intercept_patient[patient]
        mu1 = mu1 + intercept_patient[jax.numpy.array(patient)]


    # Linear predictor
    mu1 =  intercept_prophylactic[jax.numpy.array(prophylactic)] + intercept_patient[jax.numpy.array(patient)] + jax.numpy.array(X['intrinsic.growth.rate'].values) + b2*jax.numpy.array(X['phase1'].values) + b3*jax.numpy.array(X['phase2'].values) + b4*jax.numpy.array(X['meropenem'].values) + b5*jax.numpy.array(X['metronidazole'].values) + b6*jax.numpy.array(X['cef13'].values) + b7*jax.numpy.array(X['vanco_po'].values) + b8*jax.numpy.array(X['cefepime'].values) + b9*jax.numpy.array(X['linezolid'].values) + b10*jax.numpy.array(X['fluoroquinolones'].values) + b11*jax.numpy.array(X['vanco_iv'].values) + b12*jax.numpy.array(X['sulfa_trim'].values) + b13*jax.numpy.array(X['atovaquone'].values) + b15 + b16*jax.numpy.array(X['previous.anaerobes'].values) #X[:, 4]

    # Likelihood
    numpyro.sample('y', dist.Normal(mu1, sigma), obs=y)
    
    return mu1

nuts_kernel = NUTS(hierarchical_regression, target_accept_prob=0.97)

mcmc = MCMC(nuts_kernel, num_samples=3000, num_warmup=5000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, df, ylab = 'log.delta.anaerobes') 

posterior_samples = mcmc.get_samples()

My error is as follows:

The above exception was the direct cause of the following exception:

TracerArrayConversionError                Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/pandas/core/construction.py in sanitize_array(data, index, dtype, copy, raise_cast_failure, allow_2d)
    562         if hasattr(data, "__array__"):
    563             # e.g. dask array GH#38645
--> 564             data = np.asarray(data)
    565         else:
    566             data = list(data)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([ -9.5061455  -10.536966    -8.111055   ...   0.15791154  -0.2573204
  -0.2852156 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([ -9.5061455 , -10.536966  ,  -8.111055  , ...,   0.15791154,
        -0.2573204 ,  -0.2852156 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[9221])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[9221]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f8110cfd610>, in_tracers=(Traced<ShapedArray(float32[9221]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f811370f590; to 'JaxprTracer' at 0x7f811370fe50>], out_avals=[ShapedArray(float32[9221])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[9221]. let b:f32[9221] = neg a in (b,) }, 'in_shardings': (<jax._src.interpreters.pxla.UnspecifiedValue object at 0x7f8138cc20a0>,), 'out_shardings': (<jax._src.interpreters.pxla.UnspecifiedValue object at 0x7f8138cc20a0>,), 'resource_env': None, 'donated_invars': (False,), 'name': '<lambda>', 'in_positional_semantics': (<_PositionalSemantics.GLOBAL: 1>,), 'out_positional_semantics': <_PositionalSemantics.GLOBAL: 1>, 'keep_unused': False, 'inline': True}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f81116954f0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Any suggestions much appreciated!

don’t know what’s going on but i suggest moving all pandas manipulation outside of the model and passing the model pure numpy data.

I changed it so that I’m passing jnp arrays but am still getting the same error - I’m not even sure where pandas is being called in the code since I’ve removed all references to it in the function

# import numpy as np
import jax.numpy as jnp
import pandas as pd
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random
import jax


df = pd.read_csv('modelling.data1b.csv')
df = df.dropna()
prophylactic = df['prophylaxis'].tolist()
# patient = X['timeline_id'].tolist()
patient = df['modelling.id'].tolist()

y = df['log.delta.anaerobes']

num_prophylactics = 2
num_patients = 1094 # unique vals
prophylactic = df['prophylaxis'].tolist()
patient = df['modelling.id'].tolist()
    # mu1 =  intercept_prophylactic[jax.numpy.array(prophylactic)] + intercept_patient[jax.numpy.array(patient)] + jax.numpy.array(X['intrinsic.growth.rate'].values) + b2*jax.numpy.array(X['phase1'].values) + b3*jax.numpy.array(X['phase2'].values) + b4*jax.numpy.array(X['meropenem'].values) + b5*jax.numpy.array(X['metronidazole'].values) + b6*jax.numpy.array(X['cef13'].values) + b7*jax.numpy.array(X['vanco_po'].values) + b8*jax.numpy.array(X['cefepime'].values) + b9*jax.numpy.array(X['linezolid'].values) + b10*jax.numpy.array(X['fluoroquinolones'].values) + b11*jax.numpy.array(X['vanco_iv'].values) + b12*jax.numpy.array(X['sulfa_trim'].values) + b13*jax.numpy.array(X['atovaquone'].values) + b15 + b16*jax.numpy.array(X['previous.anaerobes'].values) #X[:, 4]
intrinsic_growth_rate = jnp.array(df['intrinsic.growth.rate'].values)
phase1 = jnp.array(df['phase1'].values)
phase2 = jnp.array(df['phase2'].values)
meropenem = jnp.array(df['meropenem'].values)
metronidazole = jnp.array(df['metronidazole'].values)
cef13 = jnp.array(df['cef13'].values)
vanco_po = jnp.array(df['vanco_po'].values)
cefepime = jnp.array(df['cefepime'].values)
linezolid = jnp.array(df['linezolid'].values)

fluoroquinolones = jnp.array(df['fluoroquinolones'].values)
vanco_iv = jnp.array(df['vanco_iv'].values)
sulfa_trim = jnp.array(df['sulfa_trim'].values)
atovaquone = jnp.array(df['atovaquone'].values)
previous_anaerobes = jnp.array(df['previous.anaerobes'].values)


  # growth + b1*phase1 + b2*phase2 + b3*abx1 + b4*abx2 + b5 + (1 | prophylactic) + (1 | patient) + b6*anaerobes.  
def hierarchical_regression(y, intrinsic_growth_rate,
                            phase1, phase2, meropenem, metronidazole,
                            cef13, vanco_po, cefepime, linezolid, fluoroquinolones,
                            vanco_iv, sulfa_trim, atovaquone, previous_anaerobes,
                            prophylactic, patient,  num_prophylactics, num_patients): #ylab= 'log.delta.anaerobes'):
    # print(X.head(2))
    # X = X.dropna()


    # y = X[ylab]

    # num_prophylactics = np.unique(prophylactic).shape[0]
    # num_patients = np.unique(patient).shape[0]
    # Priors
    # a = numpyro.sample("a", dist.Normal(0.0, 0.2))
    # mu = numpyro.sample('mu', dist.Normal(0, 5))

    b1 = numpyro.sample('b1', dist.Normal(0,100)) # phase 1
    b2 = numpyro.sample('b2', dist.Normal(0,100)) # phase 2
    b3 = numpyro.sample('b3', dist.Normal(0,100)) # pip
    b4 = numpyro.sample('b4', dist.Normal(0,100)) # mero
    b5 = numpyro.sample('b5', dist.Normal(0,100)) # metro 
    b6 = numpyro.sample('b6', dist.Normal(0,100)) # ceph 1-3
    b7 = numpyro.sample('b7', dist.Normal(0,100)) # vanco po
    b8 = numpyro.sample('b8', dist.Normal(0,100)) # cefepime 
    b9 = numpyro.sample('b9', dist.Normal(0,100)) # linez

    b10 = numpyro.sample('b10', dist.Normal(0,100)) # fluor
    b11 = numpyro.sample('b11', dist.Normal(0,100)) # vanco iv
    b12 = numpyro.sample('b12', dist.Normal(0,100)) # sulfa
    b13 = numpyro.sample('b13', dist.Normal(0,100)) # atova

    b15 = numpyro.sample('b15', dist.Normal(0, 0.1)) 
    b16 = numpyro.sample('b16', dist.Normal(0, 0.1))
    sigma = numpyro.sample('sigma', dist.Normal(0,.1))
    mu1 = jax.numpy.zeros_like(jax.numpy.array(y.tolist()))

    with numpyro.plate('prophylactic_plate', num_prophylactics):
        b_prophylactic = numpyro.sample('b_prophylactic', dist.Normal(0, 0.1))
        intercept_prophylactic = numpyro.sample('intercept_prophylactic', dist.Normal(0, 0.1))

        mu1 = mu1 + intercept_prophylactic[jax.numpy.array(prophylactic)]

    with numpyro.plate('patient_plate', num_patients):
        b_patient = numpyro.sample('b_patient', dist.Normal(0, 0.1))
        intercept_patient = numpyro.sample('intercept_patient', dist.Normal(0, 0.1))

        mu1 = mu1 + intercept_patient[jax.numpy.array(patient)]


    # Linear predictor
   
    mu1 =  intercept_prophylactic[jax.numpy.array(prophylactic)] + intercept_patient[jax.numpy.array(patient)] + jax.numpy.array(intrinsic_growth_rate) + b2*jax.numpy.array(phase1) + b3*jax.numpy.array(phase2) + b4*jax.numpy.array(meropenem) + b5*jax.numpy.array(metronidazole) + b6*jax.numpy.array(cef13) + b7*jax.numpy.array(vanco_po) + b8*jax.numpy.array(cefepime) + b9*jax.numpy.array(linezolid) + b10*jax.numpy.array(fluoroquinolones) + b11*jax.numpy.array(vanco_iv) + b12*jax.numpy.array(sulfa_trim) + b13*jax.numpy.array(atovaquone) + b15 + b16*jax.numpy.array(previous_anaerobes) #X[:, 4]

    # Likelihood
    numpyro.sample('y', dist.Normal(mu1, sigma), obs=y)
    
    return mu1

nuts_kernel = NUTS(hierarchical_regression, target_accept_prob=0.97)

mcmc = MCMC(nuts_kernel, num_samples=3000, num_warmup=5000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, y, intrinsic_growth_rate,
                            phase1, phase2, meropenem, metronidazole,
                            cef13, vanco_po, cefepime, linezolid, fluoroquinolones,
                            vanco_iv, sulfa_trim, atovaquone, previous_anaerobes, prophylactic, patient, num_prophylactics, num_patients) 

posterior_samples = mcmc.get_samples()

Just want to comment - I was wrong! Y was not converted properly to jax array!

Thanks!!