Odd ConcretizationTypeError in Poisson binomial model

I’m fairly new to NumPyro and am trying to get a model to work. The model is:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
# pb is my own module that defines the PoissonBinomial distribution

def precincts(racepop, count, pVecsLength, max_len_pVecs):
    poibin = pb.PoissonBinomial(count, pVecsLength)
    pW = numpyro.sample('pW', dist.Beta(1, 1))
    pH = numpyro.sample('pH', dist.Beta(1, 1))
    # Adding zero to end so 0's can be repeated and pad to equal length:
    pVals = jnp.stack((pW, pH, 0))
    pVecs = createPVecs(racepop, pVals, max_len_pVecs)
    numpyro.sample('obs', poibin, obs=pVecs)
```

It uses this utility function:

```python
def createPVecs(racepop, pVals, max_len_pVecs):
    out = []
    for i in range(racepop.shape[0]):
        out.append(jnp.repeat(pVals, racepop[i], total_repeat_length=max_len_pVecs))
    return out
```
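For reference, this is how the `total_repeat_length` padding that the comment in the model relies on behaves. According to the `jnp.repeat` documentation, when the repeats sum to less than `total_repeat_length`, the final value is repeated to fill the output. The values below are made up for illustration:

```python
import jax.numpy as jnp

pVals = jnp.array([0.3, 0.6, 0.0])  # hypothetical [pW, pH, 0]

# The repeats sum to 3, so the output is padded up to length 5
# by repeating the last value of pVals (0.0).
row = jnp.repeat(pVals, jnp.array([2, 1, 0]), total_repeat_length=5)
print(row)  # values: 0.3, 0.3, 0.6, 0.0, 0.0
```

This is why the trailing zero in `pVals` works as padding: any row whose counts fall short of `max_len_pVecs` ends in zeros.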

The model, which uses my self-written distribution, works fine up to the point where the progress bar appears: it tries out a pair of pW and pH values and computes a log-likelihood. But once the progress bar appears, I get the following error:

```
Exception has occurred: ConcretizationTypeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=2/1)>
The problem arose with the `bool` function. 
While tracing the function _body_fn at /home/pb/Software/anaconda3/lib/python3.7/site-packages/numpyro/infer/hmc_util.py:1002 for while_loop, this value became a tracer due to JAX operations on these lines:

  operation a:i64[] = zeros_like b
    from line /home/pb/Software/anaconda3/lib/python3.7/site-packages/numpyro/distributions/transforms.py:832 (log_abs_det_jacobian)
```

The error is raised at line 4361 of `lax_numpy.py` (in the `repeat` function), which is `if total_repeat_length == 0:`. According to my debugger, `total_repeat_length` holds `DeviceArray(6183, dtype=int64)` when the error occurs, so it doesn't seem to be a traced array. I've tried declaring the value that ends up in `total_repeat_length` as static, but this has no effect.
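A plausible reading of this, hedged because the details vary across JAX versions: since JAX's "omnistaging" change, operations executed while a function is being traced (here, inside the `while_loop` that NumPyro's sampler runs) are staged into the trace even when their inputs are concrete. So `total_repeat_length` itself can be a concrete `DeviceArray` while `total_repeat_length == 0` is a traced `bool[]`, which matches the `Traced<ShapedArray(bool[])>` in the message. With a plain Python int, the comparison is ordinary Python and yields a real `bool`. The sketch below isolates that pattern with `jax.jit` standing in for the sampler's tracing; all names in it are hypothetical:

```python
import jax
import jax.numpy as jnp

length_jax = jnp.array(6183)  # concrete JAX array, like max_len_pVecs in the question
length_py = 6183              # plain Python int


@jax.jit
def uses_python_int(x):
    # int == int is evaluated by Python at trace time, so `if` gets a real bool.
    if length_py == 0:
        return jnp.zeros_like(x)
    return x


@jax.jit
def uses_jax_array(x):
    # Under jit, this comparison is staged into the trace and yields a traced
    # bool[], even though length_jax itself is concrete, so `if` cannot use it.
    if length_jax == 0:
        return jnp.zeros_like(x)
    return x


def raises_concretization_error(fn, x):
    """Return True if calling fn(x) raises JAX's ConcretizationTypeError."""
    try:
        fn(x)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


print(uses_python_int(1.0))
print(raises_concretization_error(uses_jax_array, 1.0))
```

Under this reading, declaring the argument static does not help, because the comparison still runs on a JAX array inside the trace; what matters is that the value reaching the `if` is a Python scalar.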

The source of the problem in my code is the `out.append(…)` line (the 4th line of the `createPVecs` function). Looking at the variables in that line at the point of the error, the debugger shows that `racepop` is a concrete jnp array of data, as I'd expect, and `max_len_pVecs` is a single-valued concrete jnp array, but `pVals` is a traced array. `pVals` is made up of `pW` and `pH`, which are also traced arrays returned by `numpyro.sample`. I've tried disabling jit, but this has no effect. Is there perhaps no way to write a function like `createPVecs` that keeps JAX happy?

Why do you want to use JAX arrays in those places (`racepop[i]`, `maxlength`)? Shouldn't a Python scalar be enough?

Thanks for the idea, fehiepsi. I was getting errors earlier with NumPy arrays, so I had changed them to jnp arrays. However, I went back and turned `racepop` into a NumPy array again, and since the `np.max` operation that creates `maxlength` returns a single-value array, I converted that to an integer.
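A minimal sketch of that fix, assuming `racepop` holds one row of repeat counts per precinct for `[pW, pH, 0]` (the data below is hypothetical) and that `pW`/`pH` arrive as traced scalars, as they do inside NumPyro's sampler; `jax.jit` on the hypothetical `build_pvecs` stands in for that tracing. Stacking the rows into one array at the end is a choice of this sketch, not something from the thread:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical data: one row of repeat counts per precinct for [pW, pH, 0].
racepop = np.array([[2, 1, 0], [1, 1, 0]])
# A plain Python int (not a 0-d array), as in the fix described above.
max_len_pVecs = int(np.max(racepop.sum(axis=1)))


def createPVecs(racepop, pVals, max_len_pVecs):
    out = []
    for i in range(racepop.shape[0]):
        # racepop[i] is a concrete NumPy row and max_len_pVecs a Python int,
        # so only pVals is traced; short rows are padded with pVals[-1] (0).
        out.append(jnp.repeat(pVals, racepop[i], total_repeat_length=max_len_pVecs))
    # Sketch-only choice: return one (n_precincts, max_len_pVecs) array.
    return jnp.stack(out)


@jax.jit  # stands in for the tracing NumPyro does while sampling
def build_pvecs(pW, pH):
    pVals = jnp.stack([pW, pH, jnp.zeros_like(pW)])
    return createPVecs(racepop, pVals, max_len_pVecs)


pVecs = build_pvecs(0.3, 0.6)
print(pVecs)  # rows: [0.3, 0.3, 0.6] and [0.3, 0.6, 0.0]
```

Keeping the shape-determining values (`racepop[i]`, `max_len_pVecs`) as NumPy or Python values means every `if` inside `jnp.repeat` sees concrete numbers, while the sampled probabilities can stay traced.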

Color me amazed, but the error I was getting disappeared. I'm not sure how that does the trick, but it does. Many thanks!