Assign more weight to recent years during training

Hello, I am trying to perform linear regression in a probabilistic framework. Since the data is high-dimensional, I am using SVI rather than full Bayesian inference. I am looking for a way to give more weight to recent years. Will the code below do that?

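import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

years = jnp.array(dat['Year'])

# year_keys / year_values are assumed to be arrays defined earlier: the years to reweight and their weights
def get_weight(year):
    # argmax returns 0 when nothing matches, so the match is re-checked below
    index = jnp.argmax(year_keys == year)
    return jnp.where(year_keys[index] == year, year_values[index], 1.0)  # unlisted years get weight 1.0

weights = jax.vmap(get_weight)(years)

assert len(weights) == len(dat['Index']), "Weights and data size mismatch."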

# Incorporate weights into the likelihood
with numpyro.plate("data", len(jnp.array(dat['Index']))):
    weighted_likelihood = weights * dist.Normal(enr_est, observed_sigma).log_prob(enr_obs)
    numpyro.factor("obs", weighted_likelihood.sum())

@fehiepsi @martinjankowiak I was wondering if either of you could take a look at this post when you get a chance. Thanks!

You should probably be doing sum(-1), since a scalar factor inside of a plate will be broadcast up to the plate's shape.


Or no sum at all. It is not clear what your tensor shapes are. The point is that plates do automatic broadcasting, so be careful.
