I am trying to reproduce the causal-inference results from the R demos for Chapter 19 of *Regression and Other Stories*.
The simplest approach (the one used in the book) is to loop over the grades and fit a separate model to each one, which I did with `simple_difference_estimate`.
I would like to use a `plate` over grades instead of running four separate inferences, as attempted in `simple_difference_per_grade` below,
but I am stuck.
My question is: how can I use `ind` from the `plate` context to select the correct inputs for each grade? Put differently, how can I mask `z` and `y` given `ind`?
import pandas as pd
from dataclasses import dataclass
import jax
import numpyro
import numpyro.distributions as dist
@dataclass
class params:
    """Run configuration for the Electric Company analysis.

    The script reads these values straight off the class (``params.seed``,
    ``params.data_path``, ...). That works because every field carries a
    class-level default, so no instance is ever needed.
    """

    # Location of the CSV loaded with pandas.
    data_path: str = (
        "https://raw.githubusercontent.com/avehtari/ROS-Examples/master/"
        "ElectricCompany/data/electric.csv"
    )
    # Seed passed to jax.random.PRNGKey for the MCMC runs.
    seed: int = 12324
    # MCMC settings: warmup draws, kept draws, and number of chains.
    num_warmup: int = 1000
    num_samples: int = 1000
    num_chains: int = 1
# Load the data set. The CSV's unnamed leading column (which pandas labels
# "Unnamed: 0") is used as the row index rather than kept as a data column.
df = pd.read_csv(
    filepath_or_buffer=params.data_path,
    index_col="Unnamed: 0",
)
def simple_difference_estimate(
    z: jax.Array,
    y: jax.Array = None,
    *,
    prior_loc: float = None,
    prior_scale: float = None,
):
    """Regress outcomes on treatment: ``y ~ Normal(alpha + theta * z, sigma)``.

    ``theta`` is the coefficient on ``z``. If ``z`` is a 0/1 treatment
    indicator (presumably, for the ``treatment`` column; confirm), ``alpha``
    is the control-group mean and ``theta`` the treated-minus-control
    difference.

    Unit-scale priors (``Normal(0, 1)`` / ``HalfNormal(1)``) are far too tight
    for outcomes that are not on a unit scale, such as raw test scores. They
    pull ``alpha`` and ``sigma`` strongly toward zero and bias every estimate.
    So when ``y`` is observed, the priors are scaled to the data, in the
    spirit of rstanarm's weakly informative defaults used by ``stan_glm``.

    Args:
        z: Predictor value (treatment) for each observation.
        y: Observed outcome for each observation, or None to sample
            ``scores`` from the prior predictive.
        prior_loc: Prior mean of ``alpha``. Defaults to ``mean(y)`` when
            ``y`` is given, else 0.
        prior_scale: Prior scale shared by ``alpha``, ``theta`` (both
            Normal) and ``sigma`` (HalfNormal). Defaults to ``2.5 * std(y)``
            when ``y`` is given, else 1, which reproduces the unit-scale
            priors.
    """
    if prior_loc is None:
        prior_loc = 0.0 if y is None else y.mean()
    if prior_scale is None:
        prior_scale = 1.0 if y is None else 2.5 * y.std()

    alpha = numpyro.sample("alpha", dist.Normal(prior_loc, prior_scale))
    theta = numpyro.sample("theta", dist.Normal(0.0, prior_scale))
    sigma = numpyro.sample("sigma", dist.HalfNormal(prior_scale))
    numpyro.sample("scores", dist.Normal(loc=alpha + theta * z, scale=sigma), obs=y)
def simple_difference_per_grade(
    z: jax.Array,
    y: jax.Array = None,
    grade: jax.Array = None,
    *,
    num_grades: int = 4,
    prior_loc: float = None,
    prior_scale: float = None,
):
    """Fit ``y ~ Normal(alpha[g] + theta[g] * z, sigma[g])`` for all grades at once.

    This is the vectorized equivalent of fitting ``alpha + theta * z``
    separately to each grade. Every grade gets its own ``alpha``, ``theta``
    and ``sigma``; each site has shape ``(num_grades,)``.

    Why the plate's ``ind`` cannot mask the data: without subsampling,
    ``numpyro.plate("grade_plate", num_grades)`` yields
    ``ind == arange(num_grades)``. Those are indices into the *grade*
    dimension of the parameters, not into the observations, so
    ``z[ind]`` / ``y[ind]`` simply take the first ``num_grades`` rows.
    Instead, each observation carries its 0-based grade index, and the
    per-grade parameters are gathered with ``alpha[grade]`` etc. inside a
    second plate over the observations.

    Example::

        grade = df["grade"].values - 1  # grades 1..4 -> indices 0..3
        mcmc = numpyro.infer.MCMC(
            numpyro.infer.NUTS(simple_difference_per_grade),
            num_warmup=1000, num_samples=1000,
        )
        mcmc.run(rng_key, z=df["treatment"].values,
                 y=df["post_test"].values, grade=grade)

    Args:
        z: Predictor value (treatment) for each observation.
        y: Observed outcome for each observation, or None to sample
            ``scores`` from the prior predictive.
        grade: Integer grade index for each observation, with values in
            ``[0, num_grades)``.
        num_grades: Number of grades, i.e. the size of the parameter plate.
        prior_loc: Prior mean of every ``alpha[g]``. Defaults to the pooled
            ``mean(y)`` when ``y`` is given, else 0.
        prior_scale: Prior scale shared by ``alpha``, ``theta`` (Normal) and
            ``sigma`` (HalfNormal). Defaults to ``2.5 * std(y)`` over the
            pooled data when ``y`` is given, else 1. Unit-scale priors would
            be far too tight for raw test scores.

    Raises:
        ValueError: If ``y`` is given without ``grade``. The observations
            could not be attributed to a grade and would otherwise be
            silently ignored.
    """
    if y is not None and grade is None:
        raise ValueError("`grade` is required when `y` is observed")
    if prior_loc is None:
        prior_loc = 0.0 if y is None else y.mean()
    if prior_scale is None:
        prior_scale = 1.0 if y is None else 2.5 * y.std()

    # One independent (alpha, theta, sigma) triple per grade.
    with numpyro.plate("grade_plate", num_grades):
        alpha = numpyro.sample("alpha", dist.Normal(prior_loc, prior_scale))
        theta = numpyro.sample("theta", dist.Normal(0.0, prior_scale))
        sigma = numpyro.sample("sigma", dist.HalfNormal(prior_scale))

    if grade is None:
        # No observations to attribute: only the per-grade priors are sampled.
        return

    # Pick out each observation's own grade parameters, then condition on
    # every observation in one vectorized likelihood.
    with numpyro.plate("obs", len(z)):
        loc = alpha[grade] + theta[grade] * z
        numpyro.sample("scores", dist.Normal(loc, sigma[grade]), obs=y)
# Inference with a for loop: fit `simple_difference_estimate` separately to
# each grade, as done in the book.
# The same key is reused for every grade, so each fit is reproducible on its own.
rng_key = jax.random.PRNGKey(params.seed)
# groupby visits the grades actually present in the data (sorted). Each
# grade's rows are filtered once, instead of once per column.
for g, grade_df in df.groupby("grade"):
    kernel = numpyro.infer.NUTS(simple_difference_estimate)
    mcmc = numpyro.infer.MCMC(
        kernel,
        num_warmup=params.num_warmup,
        num_samples=params.num_samples,
        num_chains=params.num_chains,
    )
    mcmc.run(
        rng_key=rng_key,
        y=grade_df["post_test"].values,
        z=grade_df["treatment"].values,
    )
    # Label each summary; otherwise the per-grade tables are indistinguishable.
    print(f"Grade {g}:")
    mcmc.print_summary()