Nuts Model works beautifully but SVI log-lik/obs is insane

Hi guys,

I’m a beginner to all ML/NN/ai. I have a rough understanding from college but nothing super in depth.
So I’ve been working on a model for predicting nba 3pt%. A dataset that has high dispersion and high variance. Anyways Nuts r_hat looks find but SVI doesn’t produce similar params or converge properly unless I have super super tight priors based off of nuts. Also like my title sayss the log-lik/obs is obscenely high even with all of this. Like 2-3x nuts log-lik/obs. My nuts uses non-centered params and i’ve tried both centered and non centered for svi. I’ve tried training on the full train dataset for svi I’ve tried on the subsample. It was proposed that maybe some of my features were too large in numerical size compared to others and so to scale all features down. and I also tried that.

Here is the code in question.

I’ve tried asking every ai what could be wrong with my code and they all gave alot of solutions that helped on the margins but non were great.

def hierarchical_pct_model_svi_centered(X,*,makes,attempts,player_idx,team_season_idx,opp_team_season_idx,n_players: int,n_teams: int,n_opps: int,total_obs:int,coef_scale: float = 1.0,):if numpyro is None or dist is None or jnp is None:raise ImportError(“NumPyro + JAX required”)
X_arr = jnp.asarray(X) batch_size, n_features = X_arr.shape  attempts_arr = jnp.asarray(attempts) makes_arr = jnp.asarray(makes) player_idx_arr = jnp.asarray(player_idx, dtype=jnp.int32) team_idx_arr = jnp.asarray(team_season_idx, dtype=jnp.int32) opp_idx_arr = jnp.asarray(opp_team_season_idx, dtype=jnp.int32)  
# You can also infer these instead of passing if you want: 
# n_players = int(player_idx_arr.max()) + 1 # n_teams   =int(team_idx_arr.max()) + 1 
# n_opps    = int(opp_idx_arr.max()) + 1  
alpha_global = numpyro.sample("alpha_global", dist.Normal(logit(0.36), 0.2))   #Tighter priors, in the ballpark of your NUTS posterior 
sigma_player = numpyro.sample("sigma_player", dist.HalfNormal(0.03)) 
sigma_team   = numpyro.sample("sigma_team",   dist.HalfNormal(0.03)) 
sigma_opp    = numpyro.sample("sigma_opp",    dist.HalfNormal(0.03))  
with numpyro.plate("players", n_players):     alpha_player_raw = numpyro.sample(     "alpha_player_raw",     dist.Normal(0.0, 1.0)     )     alpha_player = numpyro.deterministic(     "alpha_player",     sigma_player * alpha_player_raw     ) with numpyro.plate("team_seasons", n_teams):     beta_team_raw = numpyro.sample("beta_team_raw", dist.Normal(0.0, 1.0))     beta_team = numpyro.deterministic("beta_team", sigma_team * beta_team_raw) with numpyro.plate("opp_seasons", n_opps):     gamma_opp_raw = numpyro.sample("gamma_opp_raw", dist.Normal(0.0, 1.0))     gamma_opp = numpyro.deterministic("gamma_opp", sigma_opp * gamma_opp_raw)       beta = numpyro.sample(     "beta",     dist.Normal(0.0, coef_scale).expand([n_features]), )  eta_raw = ( alpha_global + jnp.dot(X_arr, beta) + alpha_player[player_idx_arr] + beta_team[team_idx_arr] + gamma_opp[opp_idx_arr] )  numpyro.deterministic("eta_raw", eta_raw) 
# (optional but sometimes helpful) mild clipping to prevent insane logits    
with numpyro.plate("obs", batch_size):     numpyro.sample(         "y",         dist.Binomial(total_count=attempts_arr, logits=eta_raw),         obs=makes_arr,     )
def run_svi_minibatch(
    arrays: HierarchicalPctArrays,
    batch_iter: Iterable[Mapping[str, object]],
    *,
    num_steps: int = 2000,
    learning_rate: float = 1e-3,
    rng_key=None,
    coef_scale: float = 1.0,
    init_loc_values: dict[str, object] | None = None,  # NEW

) -> SVI:
    """Fit the model with SVI using a minibatch iterator.

    Uses raw minibatches without additional ELBO scaling.
    """

    if SVI is None or autoguide is None or Adam is None or Trace_ELBO is None:  # pragma: no cover - import guard
        raise ImportError("NumPyro SVI dependencies are missing")

    n_players = len(arrays.player_id_lookup)
    n_teams = len(arrays.team_season_lookup)
    n_opps = len(arrays.opp_team_season_lookup)
    total_obs = len(arrays.makes)

    def model_fn(features, makes, attempts, player_idx, team_season_idx, opp_team_season_idx):
        return hierarchical_pct_model_svi_centered(
            features,
            makes=makes,
            attempts=attempts,
            player_idx=player_idx,
            team_season_idx=team_season_idx,
            opp_team_season_idx=opp_team_season_idx,
            n_players=n_players,
            n_teams=n_teams,
            n_opps=n_opps,
            total_obs=None,  # disable subsample scaling for minibatch SVI
            coef_scale=coef_scale,
        )

    if init_loc_values is not None:
        # Convert numpy arrays to jnp arrays
        values_jnp = {k: jnp.asarray(v) for k, v in init_loc_values.items()}
        init_loc_fn = init_to_value(values=values_jnp)
        guide = autoguide.AutoDelta(model_fn, init_loc_fn=init_loc_fn)
    else:
        guide = autoguide.AutoDelta(model_fn)
    optimizer = Adam(learning_rate)
    loss = Trace_ELBO(num_particles=10)         

    svi = SVI(
        model_fn,
        guide,
        optimizer,
        loss=loss,
    )

    if rng_key is None:
        rng_key = random.PRNGKey(0)

    init_batch_size = min(1024, total_obs)
    svi_state = svi.init(
        rng_key,
        arrays.features[:init_batch_size],
        makes=arrays.makes[:init_batch_size],
        attempts=arrays.attempts[:init_batch_size],
        player_idx=arrays.player_idx[:init_batch_size],
        team_season_idx=arrays.team_season_idx[:init_batch_size],
        opp_team_season_idx=arrays.opp_team_season_idx[:init_batch_size],
    )

    elbo_history: list[float] = []

    batch_iter = iter(batch_iter)

    for step in range(num_steps):
        try:
            batch = next(batch_iter)
        except StopIteration:
            print("WARNING: batch_iter exhausted before num_steps; stopping early.")
            break

        svi_state, loss = svi.update(
            svi_state,
            batch["features"],
            makes=batch["makes"],
            attempts=batch["attempts"],
            player_idx=batch["player_idx"],
            team_season_idx=batch["team_season_idx"],
            opp_team_season_idx=batch["opp_team_season_idx"],
        )
        loss_value = float(loss)
        batch_n = batch["features"].shape[0]
        loss_per_obs = loss_value / batch_n  # ELBO per obs for this minibatch

        elbo_history.append(loss_value)

        if (step + 1) % 100 == 0:
            print(
                f"step {step+1}, "
                f"loss={loss_value:.3f}, "
                f"ELBO={-loss_value:.3f}, "
                f"loss_per_obs={loss_per_obs:.6f}"
            )


    params = svi.get_params(svi_state)
    svi._last_state = svi_state  # type: ignore[attr-defined]
    svi._last_params = params  # type: ignore[attr-defined]
    svi._elbo_history = elbo_history  # optional: inspect in tests
    return svi


def make_minibatch_iterator(rng_key, arrays: HierarchicalPctArrays, batch_size: int):
    """Yield infinite shuffled minibatches for SVI training."""

    if random is None:  # pragma: no cover - import guard
        raise ImportError("JAX random is required for minibatch iterator")

    n_obs = len(arrays.makes)
    while True:
        rng_key, subkey = random.split(rng_key)
        perm = random.permutation(subkey, n_obs)
        for start in range(0, n_obs, batch_size):
            batch_idx = perm[start : start + batch_size]
            yield {
                "features": arrays.features[batch_idx],
                "makes": arrays.makes[batch_idx],
                "attempts": arrays.attempts[batch_idx],
                "player_idx": arrays.player_idx[batch_idx],
                "team_season_idx": arrays.team_season_idx[batch_idx],
                "opp_team_season_idx": arrays.opp_team_season_idx[batch_idx],
            }