Use plate index to mask input data

I am trying to reproduce the causal inference results from Chapter 19 of the “Regression and Other Stories” R demos.

The simplest way (which is what the book does) is to loop over the grades and run a separate inference for each one to get the effect estimates; I did that below in simple_difference_estimate.

Instead of four separate inference runs, I want to use a plate over the grades, but as you can see in simple_difference_per_grade, I am stuck.

My question is: how can I use ind from the plate context to pick the correct inputs for each grade? Or, to put it differently, how can I mask z and y given ind?

import pandas as pd
from dataclasses import dataclass
import jax
import numpyro
import numpyro.distributions as dist


@dataclass
class params:
    data_path: str = "https://raw.githubusercontent.com/avehtari/ROS-Examples/master/ElectricCompany/data/electric.csv"
    seed: int = 12324
    num_warmup: int = 1000
    num_samples: int = 1000
    num_chains: int = 1


df = pd.read_csv(params.data_path, index_col="Unnamed: 0")


def simple_difference_estimate(
    z: jax.Array,
    y: jax.Array = None,
):
    alpha = numpyro.sample("alpha", dist.Normal(0, 1))
    theta = numpyro.sample("theta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfNormal())
    numpyro.sample("scores", dist.Normal(loc=alpha + theta * z, scale=sigma), obs=y)


def simple_difference_per_grade(
    z: jax.Array,
    y: jax.Array,
):
    with numpyro.plate("grade_plate", 4) as ind:
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        theta = numpyro.sample("theta", dist.Normal(0, 1))
        sigma = numpyro.sample("sigma", dist.HalfNormal())
        # loc = alpha + theta * z[ind]
        # numpyro.sample("scores", dist.Normal(loc, sigma), obs=y[ind])


# Inference with for loop for each grade separately:
for g in [1, 2, 3, 4]:
    rng_key = jax.random.PRNGKey(params.seed)
    kernel = numpyro.infer.NUTS(simple_difference_estimate)
    mcmc = numpyro.infer.MCMC(
        kernel,
        num_warmup=params.num_warmup,
        num_samples=params.num_samples,
        num_chains=params.num_chains,
    )
    mcmc.run(
        rng_key=rng_key,
        y=df[df["grade"] == g]["post_test"].values,
        z=df[df["grade"] == g]["treatment"].values,
    )
    mcmc.print_summary()

For those who are interested, I resolved my confusion by moving the observation out of the plate: the per-grade parameters are sampled inside the plate, and the likelihood then indexes them with the per-observation grade vector. Here is the inference code for the causal inference example of Chapter 19 of “Regression and Other Stories” in NumPyro.

import pandas as pd
from dataclasses import dataclass
import numpy as np
import jax
import numpyro
import numpyro.distributions as dist


@dataclass
class params:
    data_path: str = "https://raw.githubusercontent.com/avehtari/ROS-Examples/master/ElectricCompany/data/electric.csv"
    seed: int = 12324
    num_warmup: int = 1000
    num_samples: int = 5000
    num_chains: int = 1


def read_data(params) -> pd.DataFrame:
    df = pd.read_csv(params.data_path, index_col="Unnamed: 0")
    return df


def simple_difference_estimate(**kwargs):
    # Pooled model for a single grade; extra kwargs (e.g. grade) are ignored.
    y = kwargs["y"]
    z = kwargs["z"]
    alpha = numpyro.sample("alpha", dist.Normal(0, 1))
    theta = numpyro.sample("theta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfNormal())
    loc = numpyro.deterministic("loc", alpha + theta * z)
    numpyro.sample("scores", dist.Normal(loc=loc, scale=sigma), obs=y)


def simple_difference_per_grade(**kwargs):
    y = kwargs["y"]
    z = kwargs["z"]
    grade = kwargs["grade"]
    n_grade = len(np.unique(grade))
    with numpyro.plate("grade_plate", n_grade):
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        theta = numpyro.sample("theta", dist.Normal(0, 1))
        sigma = numpyro.sample("sigma", dist.HalfNormal())

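    # Gather each pupil's grade-specific parameters: grades are coded 1..4
    # in the data, hence the -1 offset into the length-4 parameter arrays.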
    loc = alpha[grade - 1] + theta[grade - 1] * z
    numpyro.sample("scores", dist.Normal(loc, sigma[grade - 1]), obs=y)


def inference(rng_key, params, model, y, z, grade):
    kernel = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(
        kernel,
        num_warmup=params.num_warmup,
        num_samples=params.num_samples,
        num_chains=params.num_chains,
    )

    mcmc.run(rng_key=rng_key, y=y, z=z, grade=grade)
    mcmc.print_summary()


def main(params):
    rng_key = jax.random.PRNGKey(params.seed)
    _, *sub_key = jax.random.split(rng_key, num=10)
    # Read the data
    df = read_data(params)
    print(df.head())

    sep = 80 * "-"
    # Inference Using for loop for each grade
    for g in df.grade.unique():
        print(f"\nFor grade {g}:")
        inference(
            params=params,
            model=simple_difference_estimate,
            rng_key=sub_key[g],
            z=df[df.grade == g]["treatment"].values,
            y=df[df.grade == g]["post_test"].values,
            grade=g,
        )
        print(sep)

    # Inference Using the plate for each grade
    print(sep)
    print("\nInference using plate")
    inference(
        params=params,
        model=simple_difference_per_grade,
        rng_key=sub_key[-1],
        z=df["treatment"].values,
        y=df["post_test"].values,
        grade=df["grade"].values,
    )


if __name__ == "__main__":
    main(params)
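
As a side note: if you prefer to keep the likelihood inside a plate as well (closer to my original attempt), you can add a second plate over the observations and index the per-grade parameters there. The following is a minimal sketch under the same assumptions as above (grades coded 1..4, one row per pupil); simple_difference_per_grade_plated is just a name I picked for illustration:

import numpy as np
import numpyro
import numpyro.distributions as dist


def simple_difference_per_grade_plated(**kwargs):
    y = kwargs["y"]
    z = kwargs["z"]
    grade = kwargs["grade"]
    n_grade = len(np.unique(grade))
    # Per-grade parameters, batched along the grade plate.
    with numpyro.plate("grade_plate", n_grade):
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        theta = numpyro.sample("theta", dist.Normal(0, 1))
        sigma = numpyro.sample("sigma", dist.HalfNormal())

    # One likelihood term per pupil; each pupil uses its own grade's
    # parameters via fancy indexing.
    with numpyro.plate("obs_plate", len(z)):
        loc = alpha[grade - 1] + theta[grade - 1] * z
        numpyro.sample("scores", dist.Normal(loc, sigma[grade - 1]), obs=y)

It runs through the same inference helper, e.g. by swapping model=simple_difference_per_grade_plated into the plate call in main.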