I am trying to run the following code, and receive this error. I am attempting to run SVI on a conditional model using a discrete distribution.
from jax import random
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO, autoguide
numpyro.set_platform("cpu")
rng_key = random.PRNGKey(0)
data = dict()
data['x'] =np.array([62.1984,54.4534,80.2666,43.9811,9.5492,59.7737,47.7615,48.2980,11.5235,62.8889])
data['weight'] =10.0
data['bias'] =1.0
data['y'] =np.array([622.9842,545.5341,803.6663,440.8107,96.4920,598.7365,478.6153,483.9800,116.2349,629.8893])
def model(data):
w=numpyro.sample("var1", dist.Normal(1.0,10.0))
b=numpyro.sample("var2", dist.Normal(1.0,10.0))
#cond = numpyro.sample("var3", dist.Bernoulli(0.5))
cond=dist.Bernoulli(0.5).sample(rng_key)
print(cond)
if (cond <= 0):
with numpyro.plate("size", np.size(data['y'])):
numpyro.sample("obs", dist.Normal(w*data['x']+b,1.0*np.ones([10])), obs=data['y'])
else:
with numpyro.plate("size", np.size(data['y'])):
numpyro.sample("obs", dist.Normal(w*data['x']+10.0,1.0*np.ones([10])), obs=data['y'])
guide = autoguide.AutoDiagonalNormal(model)
optimizer = numpyro.optim.Adam(0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 2000, data)
params = svi_result.params
I get the following error
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
The problem arose with the `bool` function.
While tracing the function body_fn at /Users/ankithadamisetty/opt/anaconda3/lib/python3.7/site-packages/numpyro/infer/svi.py:171, this value became a tracer due to JAX operations on these lines:
and after some trace information I see the following
Encountered tracer value: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
The error originates from this line:
svi_result = svi.run(rng_key, 2000, data)