SVI error with a conditional model and a discrete variable

I am trying to run the following code and get the error below. I am attempting to run SVI on a conditional model that uses a discrete distribution.

from jax import random
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO, autoguide
numpyro.set_platform("cpu")
rng_key = random.PRNGKey(0)
data = dict()
data['x'] = np.array([62.1984, 54.4534, 80.2666, 43.9811, 9.5492, 59.7737, 47.7615, 48.2980, 11.5235, 62.8889])
data['weight'] = 10.0
data['bias'] = 1.0
data['y'] = np.array([622.9842, 545.5341, 803.6663, 440.8107, 96.4920, 598.7365, 478.6153, 483.9800, 116.2349, 629.8893])

def model(data):
    w = numpyro.sample("var1", dist.Normal(1.0, 10.0))
    b = numpyro.sample("var2", dist.Normal(1.0, 10.0))
    # cond = numpyro.sample("var3", dist.Bernoulli(0.5))
    cond = dist.Bernoulli(0.5).sample(rng_key)
    print(cond)
    if cond <= 0:
        with numpyro.plate("size", np.size(data['y'])):
            numpyro.sample("obs", dist.Normal(w * data['x'] + b, 1.0 * np.ones([10])), obs=data['y'])
    else:
        with numpyro.plate("size", np.size(data['y'])):
            numpyro.sample("obs", dist.Normal(w * data['x'] + 10.0, 1.0 * np.ones([10])), obs=data['y'])

guide = autoguide.AutoDiagonalNormal(model)
optimizer = numpyro.optim.Adam(0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 2000, data)
params = svi_result.params

I get the following error:

jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The problem arose with the `bool` function. 

While tracing the function body_fn at /Users/ankithadamisetty/opt/anaconda3/lib/python3.7/site-packages/numpyro/infer/svi.py:171, this value became a tracer due to JAX operations on these lines:

After some further trace information, I see the following:

Encountered tracer value: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

The error originates from this line:

 svi_result = svi.run(rng_key, 2000, data)

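You need to replace your if/else statement with jax.lax.cond; please read the JAX FAQ. svi.run traces the SVI update step with JAX (that is the body_fn mentioned in the error), so inside the model `cond` is an abstract tracer rather than a concrete number. A Python `if` has to call `bool()` on it, which is exactly the failure the error message reports.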
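A minimal sketch of the difference in plain JAX, outside of NumPyro, using the `x` values from the question (the function names `mean_with_if` and `mean_with_lax_cond` are hypothetical). The Python `if` fails under `jax.jit`, while `jax.lax.cond` traces both branches and selects one at run time. In the model, one reasonable way to apply this is to compute only the mean with `lax.cond` and keep the single `numpyro.sample("obs", ...)` call outside it; for a simple choice between two values like this, `jnp.where(cond <= 0, b, 10.0)` would do the same job.

```python
import jax
import jax.numpy as jnp

# x values from the question
x = jnp.array([62.1984, 54.4534, 80.2666, 43.9811, 9.5492,
               59.7737, 47.7615, 48.2980, 11.5235, 62.8889])


def mean_with_if(cond, w, b):
    # A Python `if` calls bool() on `cond`. Under jax.jit, `cond` is an
    # abstract tracer, so this raises a ConcretizationTypeError.
    if cond <= 0:
        return w * x + b
    return w * x + 10.0


def mean_with_lax_cond(cond, w, b):
    # jax.lax.cond traces both branches and selects one at run time,
    # so the predicate may be a traced value. Both branches must return
    # arrays with the same shape and dtype.
    return jax.lax.cond(
        cond <= 0,
        lambda b_: w * x + b_,    # used when cond <= 0
        lambda b_: w * x + 10.0,  # used otherwise
        b,
    )


mean_jit = jax.jit(mean_with_lax_cond)
print(mean_jit(0, 10.0, 1.0))  # branch that uses b
print(mean_jit(1, 10.0, 1.0))  # branch that uses the constant 10.0
```

Calling `jax.jit(mean_with_if)` with the same arguments reproduces the kind of error shown in the question.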

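You also can't introduce randomness the way you're doing with cond=dist.Bernoulli(0.5).sample(rng_key). I believe this will just evaluate to the same value of cond every time the model runs, because rng_key is a fixed global key. And since it is not a numpyro.sample site, SVI never sees cond as a random variable at all.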