Jax weak_type error

I am writing the following program in numpyro. I am using jax condition functions where I am modifying certain parameters of my model.

import jax
from jax import random
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO, autoguide
numpyro.set_platform("cpu")
data = dict()
data['offers'] =np.array([15,16])

def model(data):
    param=numpyro.sample("var1", dist.Uniform(20,50))
    Recruiters=numpyro.sample("var2", dist.Poisson(param))
    percentile=numpyro.sample("var3", dist.Uniform(0,1))
    Interviews=0
    GPA=0.0
    operand_out=jax.lax.cond(percentile>0.95,true0,false0,operand={'data':data,'percentile':percentile,'param':param,'GPA':GPA,'Recruiters':Recruiters,'Interviews':Interviews})
    for key, value in operand_out.items(): 
        value = operand_out[key]    
    operand_out=jax.lax.cond(GPA == 4,true1,false1,operand={'data':data,'percentile':percentile,'param':param,'GPA':GPA,'Recruiters':Recruiters,'Interviews':Interviews})
    for key, value in operand_out.items(): 
        value = operand_out[key]    
    operand_out=jax.lax.cond(GPA<4,true2,false2,operand={'data':data,'percentile':percentile,'param':param,'GPA':GPA,'Recruiters':Recruiters,'Interviews':Interviews})
    for key, value in operand_out.items(): 
        value = operand_out[key]    
    for n in range(1,2):
        with numpyro.plate("size", np.size(data['offers'])):
            numpyro.sample("obs13", dist.Binomial(Interviews,0.4), obs=data['offers'][n])
def true0(inp): 
    inp['GPA']=4
    return inp 
    
def false0(inp): 
    inp['GPA']=numpyro.sample("var4", dist.Normal(2.75,0.5))
    return inp 
    
def true1(inp): 
    inp['Interviews']=dist.Binomial(inp['Recruiters'],0.9).sample(random.PRNGKey(np.random.randint(100)))
    return inp 
    
def false1(inp): 
    return inp
    
def true2(inp): 
    inp['Interviews']=dist.Binomial(inp['Recruiters'],0.6).sample(random.PRNGKey(np.random.randint(100)))
    return inp 
    
def false2(inp): 
    return inp
    

guide = autoguide.AutoDiagonalNormal(model)
optimizer = numpyro.optim.Adam(0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, data)
params = svi_result.params

I get the following error about a return type mismatch.

TypeError: true_fun and false_fun output must have identical types, got
{'GPA': ShapedArray(int32[], weak_type=True), 'Interviews': ShapedArray(int32[], weak_type=True), 'Recruiters': ShapedArray(int32[]), 'data': {'offers': ShapedArray(int32[2])}, 'param': ShapedArray(float32[]), 'percentile': ShapedArray(float32[])}
and
{'GPA': ShapedArray(float32[]), 'Interviews': ShapedArray(int32[], weak_type=True), 'Recruiters': ShapedArray(int32[]), 'data': {'offers': ShapedArray(int32[2])}, 'param': ShapedArray(float32[]), 'percentile': ShapedArray(float32[])}.
)

I am not sure how to fix this, my guess is it relates to when I modify a variable inside a true/false function.

Your branches return different dtype:

def true0(inp): 
    inp['GPA']=4
    return inp 
    
def false0(inp): 
    inp['GPA']=numpyro.sample("var4", dist.Normal(2.75,0.5))
    return inp

I guess you can change the true0 branch to inp['GPA']=4. to make both branches return a float dtype.

Anyway, primitives under cond is not supported yet. You can use mask (which is much simpler than using cond) to achieve the same thing:

var4 = numpyro.sample("var4", dist.Normal(2.75,0.5).mask(GPA != 4))
GPA = jnp.where(GPA == 4, GPA, var4)