Error in getting posterior predictive samples

Hello,

I have the following code. The model seems to run fine, but the last line throws an error that I don't understand (also below). This follows the Bayesian regression example. Any ideas as to what might be happening? Thank you in advance. (Side note: I'm on the current versions of NumPyro and JAX.)


Never mind! I just had to call:

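# Forecast: draw posterior predictive samples of "y" at the test inputs
from jax import random
from numpyro.infer import Predictive

predictive = Predictive(BQR, posterior_samples=ps, return_sites=["y"])
y_pred_samples = predictive(random.PRNGKey(0), tau=0.5, X=x_test.to_numpy())["y"]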