I’m fairly new to NumPyro and am trying to get a model to work. The model is:
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
# pb is my own module providing the PoissonBinomial distribution

def precincts(racepop, count, pVecsLength, max_len_pVecs):
    poibin = pb.PoissonBinomial(count, pVecsLength)
    pW = numpyro.sample('pW', dist.Beta(1, 1))
    pH = numpyro.sample('pH', dist.Beta(1, 1))
    # Append a zero so 0's can be repeated to pad every vector to equal length:
    pVals = jnp.stack((pW, pH, 0))
    pVecs = createPVecs(racepop, pVals, max_len_pVecs)
    numpyro.sample('obs', poibin, obs=pVecs)
It uses this utility function:
def createPVecs(racepop, pVals, max_len_pVecs):
    out = []
    for i in range(racepop.shape[0]):
        out.append(
            jnp.repeat(pVals, racepop[i], total_repeat_length=max_len_pVecs)
        )
    return out
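For reference, this is what a single `jnp.repeat` call inside createPVecs produces on concrete values, outside any tracing. The numbers are made up, and the row layout [nW, nH, nPad] for racepop is an assumption. According to the JAX documentation for `total_repeat_length`, when the repeat counts sum to less than the target length, the last value of the input fills the remaining slots, which is what lets the appended zero act as padding.

```python
import jax.numpy as jnp

# Made-up values for one precinct: [pW, pH, padding zero].
pVals = jnp.array([0.3, 0.6, 0.0])
# Assumed layout of one racepop row: [nW, nH, nPad]; here nPad = 0.
counts = jnp.array([2, 1, 0])
max_len = 5  # a plain Python int, not a jnp array

vec = jnp.repeat(pVals, counts, total_repeat_length=max_len)
# sum(counts) = 3 < 5, so the last value (0.0) fills the two remaining slots:
# the result is pW, pW, pH, 0, 0.
```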
The model, which uses my self-written distribution, works fine up to the point where the progress bar appears: it tries out a pair of pW and pH values and gets a log-likelihood. Once the progress bar appears, however, I get the following error:
Exception has occurred: ConcretizationTypeError
(note: full exception trace is shown but execution is paused at: _run_module_as_main)
Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool)>with<DynamicJaxprTrace(level=2/1)>
The problem arose with the `bool` function.
While tracing the function _body_fn at /home/pb/Software/anaconda3/lib/python3.7/site-packages/numpyro/infer/hmc_util.py:1002 for while_loop, this value became a tracer due to JAX operations on these lines:

  operation a:i64 = zeros_like b
    from line /home/pb/Software/anaconda3/lib/python3.7/site-packages/numpyro/distributions/transforms.py:832 (log_abs_det_jacobian)
The error is raised at line 4361 of lax_numpy.py, inside the repeat function, on the line `if total_repeat_length == 0:`. According to my debugger, total_repeat_length holds DeviceArray(6183, dtype=int64) when the error occurs, so it does not appear to be a traced array. I’ve tried declaring the value that ends up in total_repeat_length as static, but this has no effect.
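A possible explanation, not confirmed against this exact JAX version: while JAX traces a function (under `jax.jit`, or the body of a `lax.while_loop` such as the one in hmc_util.py), a comparison like `total_repeat_length == 0` on a jnp array is itself a JAX operation, so it is staged out and returns a traced bool even though the array holds a concrete value. A plain Python int is compared in Python and never becomes a tracer. The sketch below reproduces that pattern, using `jax.jit` as a stand-in for the while_loop; `is_empty` and `traced_with` are hypothetical names.

```python
import jax
import jax.numpy as jnp

max_len_jnp = jnp.array(6183)  # concrete jax array, like the debugger shows
max_len_int = 6183             # the same value as a plain Python int

def is_empty(total_repeat_length):
    # Mirrors the `if total_repeat_length == 0:` check in jnp.repeat:
    # Python's `if` needs a concrete True/False.
    if total_repeat_length == 0:
        return 1.0
    return 0.0

def traced_with(n):
    # jax.jit traces the lambda; `n` is captured from the enclosing scope,
    # not passed in as a traced argument.
    return jax.jit(lambda x: x + is_empty(n))(1.0)

try:
    traced_with(max_len_jnp)
    jnp_failed = False
except jax.errors.ConcretizationTypeError:
    # `n == 0` on a jax array is a JAX operation, so during tracing it
    # yields a traced bool, and `if` cannot convert that to True/False.
    jnp_failed = True

# With a Python int, `n == 0` is evaluated by Python, so no tracer appears.
int_result = traced_with(max_len_int)
```

If this is what happens here, passing `int(max_len_pVecs)` and a NumPy version of racepop would keep those comparisons in Python.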
In my code, the problem originates in the out.append(…) line (the 4th line of createPVecs). Inspecting the variables in that line at the point of the error, the debugger shows that racepop is a concrete jnp array of data, as I’d expect, and max_len_pVecs is a single-valued concrete jnp array, but pVals is a traced array. pVals is built from pW and pH, which are also traced arrays returned by numpyro.sample. I’ve tried disabling jit, but this has no effect. Is there perhaps no way to write a function like createPVecs that keeps JAX happy?
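One way to structure createPVecs so that only indexing touches traced values is sketched here. It is not a tested fix for this model. The idea is to keep everything that determines shapes in plain NumPy/Python and let the traced pVals pass through a single gather. The hypothetical helper `make_index_matrix` precomputes, once and outside the model, which entry of pVals belongs at each position, so the vectors become `pVals[idx]` with no `jnp.repeat` call. It assumes racepop has one row per precinct with columns [nW, nH, nPad]. It also returns one 2-D array instead of a list, and whether pb.PoissonBinomial accepts that shape is an assumption.

```python
import numpy as np
import jax
import jax.numpy as jnp

def make_index_matrix(racepop, max_len):
    # Hypothetical helper, run once in plain NumPy before sampling.
    # idx[i, k] says which entry of pVals (0 = pW, 1 = pH, 2 = padding zero)
    # belongs at position k of precinct i's vector.
    racepop = np.asarray(racepop)
    idx = np.full((racepop.shape[0], max_len), 2, dtype=np.int32)
    for i, row in enumerate(racepop):
        reps = np.repeat(np.arange(row.shape[0]), row)[:max_len]
        idx[i, : reps.shape[0]] = reps
    return idx

def create_pvecs_indexed(pVals, idx):
    # A single gather: the traced pVals is only indexed, and every shape
    # comes from the static NumPy array idx, so no Python `if` sees a tracer.
    return pVals[idx]

# Made-up data in the assumed layout: one row per precinct, [nW, nH, nPad].
racepop = np.array([[2, 1, 0], [1, 3, 0]])
idx = make_index_matrix(racepop, max_len=5)

@jax.jit
def demo(pW, pH):
    # Stand-in for the traced values numpyro.sample returns inside the model.
    pVals = jnp.stack((pW, pH, 0.0))
    return create_pvecs_indexed(pVals, idx)

out = demo(0.3, 0.6)
```

In the model, `idx` would be computed once from the data and passed in (or closed over) in place of racepop and max_len_pVecs.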