Hello, I am trying to perform linear regression in a probabilistic framework (NumPyro). Since the data is high-dimensional, I am using stochastic variational inference (SVI) rather than sampling-based Bayesian inference such as MCMC. I am looking for a way to give more weight to recent years. Will the code below do that?
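```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# dat, year_keys, year_values, enr_est, enr_obs, observed_sigma come from the rest of the model
years = jnp.array(dat['Year'])
def get_weight(year):
    index = jnp.argmax(year_keys == year)  # first matching key (0 if there is no match)
    return jnp.where(year_keys[index] == year, year_values[index], 1.0)  # default weight 1.0
weights = jax.vmap(get_weight)(years)
assert len(weights) == len(dat['Index']), "Weights and data size mismatch."
# Incorporate weights into the likelihood
with numpyro.plate("data", len(dat['Index'])):
    weighted_likelihood = weights * dist.Normal(enr_est, observed_sigma).log_prob(enr_obs)
    # Pass the per-point vector: a scalar .sum() inside the plate would be broadcast over it
    numpyro.factor("obs", weighted_likelihood)
```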
@fehiepsi @martinjankowiak I was wondering if either of you could take a look at this post when you get a chance. Thanks!