I am writing the following program in numpyro. I am using jax condition functions where I am modifying certain parameters of my model.
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO, autoguide
numpyro.set_platform("cpu")
data = dict()
data['offers'] =np.array([15,16])
def model(data):
param=numpyro.sample("var1", dist.Uniform(20,50))
Recruiters=numpyro.sample("var2", dist.Poisson(param))
percentile=numpyro.sample("var3", dist.Uniform(0,1))
Interviews=0
GPA=0.0
operand_out=jax.lax.cond(percentile>0.95,true0,false0,operand={'data':data,'percentile':percentile,'param':param,'GPA':GPA,'Recruiters':Recruiters,'Interviews':Interviews})
for key, value in operand_out.items():
value = operand_out[key]
operand_out=jax.lax.cond(GPA == 4,true1,false1,operand={'data':data,'percentile':percentile,'param':param,'GPA':GPA,'Recruiters':Recruiters,'Interviews':Interviews})
for key, value in operand_out.items():
value = operand_out[key]
operand_out=jax.lax.cond(GPA<4,true2,false2,operand={'data':data,'percentile':percentile,'param':param,'GPA':GPA,'Recruiters':Recruiters,'Interviews':Interviews})
for key, value in operand_out.items():
value = operand_out[key]
for n in range(1,2):
with numpyro.plate("size", np.size(data['offers'])):
numpyro.sample("obs13", dist.Binomial(Interviews,0.4), obs=data['offers'][n])
def true0(inp):
inp['GPA']=4
return inp
def false0(inp):
inp['GPA']=numpyro.sample("var4", dist.Normal(2.75,0.5))
return inp
def true1(inp):
inp['Interviews']=dist.Binomial(inp['Recruiters'],0.9).sample(random.PRNGKey(np.random.randint(100)))
return inp
def false1(inp):
return inp
def true2(inp):
inp['Interviews']=dist.Binomial(inp['Recruiters'],0.6).sample(random.PRNGKey(np.random.randint(100)))
return inp
def false2(inp):
return inp
guide = autoguide.AutoDiagonalNormal(model)
optimizer = numpyro.optim.Adam(0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, data)
params = svi_result.params
I get the following error about a return type mismatch.
TypeError: true_fun and false_fun output must have identical types, got
{'GPA': ShapedArray(int32[], weak_type=True), 'Interviews': ShapedArray(int32[], weak_type=True), 'Recruiters': ShapedArray(int32[]), 'data': {'offers': ShapedArray(int32[2])}, 'param': ShapedArray(float32[]), 'percentile': ShapedArray(float32[])}
and
{'GPA': ShapedArray(float32[]), 'Interviews': ShapedArray(int32[], weak_type=True), 'Recruiters': ShapedArray(int32[]), 'data': {'offers': ShapedArray(int32[2])}, 'param': ShapedArray(float32[]), 'percentile': ShapedArray(float32[])}.
)
I am not sure how to fix this, my guess is it relates to when I modify a variable inside a true/false function.