Hello devs. I am reading the "Bayesian Regression Using NumPyro" tutorial.
In the In [15] code cell, it defines:
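```python
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
from numpyro import handlers

def log_likelihood(rng_key, params, model, *args, **kwargs):
    # Fix the sample sites named in `params` to those values, then record one run of the model
    model = handlers.condition(model, params)
    model_trace = handlers.trace(model).get_trace(*args, **kwargs)
    obs_node = model_trace["obs"]
    return obs_node["fn"].log_prob(obs_node["value"])

def log_pred_density(rng_key, params, model, *args, **kwargs):
    # Number of posterior samples (leading axis of every entry in `params`)
    n = list(params.values())[0].shape[0]
    log_lk_fn = vmap(
        lambda rng_key, params: log_likelihood(rng_key, params, model, *args, **kwargs)
    )
    log_lk_vals = log_lk_fn(random.split(rng_key, n), params)
    return (logsumexp(log_lk_vals, 0) - jnp.log(n)).sum()
```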
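Reading the code, and assuming `params` holds $n$ posterior draws $\theta^{(1)}, \dots, \theta^{(n)}$ stacked along axis 0 and the `obs` site holds $N$ observations $y_1, \dots, y_N$, `log_lk_vals[s, i]` is $\log p(y_i \mid \theta^{(s)})$, so `log_pred_density` appears to compute the log pointwise predictive density:

$$
\sum_{i=1}^{N} \log\left(\frac{1}{n}\sum_{s=1}^{n} p(y_i \mid \theta^{(s)})\right)
= \sum_{i=1}^{N}\left[\operatorname{logsumexp}_{s}\big(\log p(y_i \mid \theta^{(s)})\big) - \log n\right]
$$

Using `logsumexp` over axis 0 instead of exponentiating directly keeps the average numerically stable when the log-likelihoods are large and negative.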
I noticed that the `log_likelihood` function never uses its `rng_key` argument. So how is the random number generator determined here?