In Pyro, I am able to scale the log probability like this:
```python
with plate('observe_data'), pyro.poutine.scale(scale=scale_factors):
    pyro.sample("obs", dist.Poisson(lam), obs=values)
```
Here `scale_factors` is a tensor of the same size as `values`. It’s unclear to me how to implement something similar in NumPyro. When I use `numpyro.handlers.scale` in place of `poutine.scale` in a NumPyro model, I get the following error:
```
    523 def __init__(self, fn=None, scale=1.):
    524     if not_jax_tracer(scale):
--> 525         if scale <= 0:
    526             raise ValueError("'scale' argument should be a positive number.")
    527         self.scale = scale

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
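The traceback points at the positivity check in the handler's `__init__`: `scale <= 0` on a multi-element array yields an array of booleans, and a plain `if` cannot reduce that to a single truth value. A minimal NumPy reproduction follows; the `np.any` variant is only an illustration of a vector-safe check, not NumPyro's own code.

```python
import numpy as np

scale = np.array([0.5, 2.0, 1.0])

# Mirrors the failing line: `scale <= 0` is an elementwise boolean array,
# and `if` needs a single True/False, so NumPy raises ValueError.
try:
    if scale <= 0:
        pass
    raised = False
except ValueError:
    raised = True

# A vector-safe form of the same positivity check reduces the array first.
has_nonpositive = bool(np.any(scale <= 0))
print(raised, has_nonpositive)  # True False
```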
Is it possible to use a vector for scaling in NumPyro?