# Use plate index to mask input data

I am trying to reproduce the causal-inference results from Chapter 19 of the “Regression and Other Stories” R demos.

The simplest way (which is what the book does) is to run a for loop over each grade and collect the effect estimates; I did that here in `simple_difference_estimate`.

I want to use `plate` over grades instead of four separate inference runs, and as you can see in `simple_difference_per_grade`, that is where I am stuck.

My question is: how can I use `ind` from the `plate` context to select the correct inputs for each grade? Or, to put it differently, how can I mask `z` and `y` given `ind`?

``````import pandas as pd
from dataclasses import dataclass
import jax
import numpyro
import numpyro.distributions as dist

@dataclass
class params:
    """Configuration: data location and MCMC sampler settings."""

    # CSV of the Electric Company experiment from the ROS examples repository.
    data_path: str = "https://raw.githubusercontent.com/avehtari/ROS-Examples/master/ElectricCompany/data/electric.csv"
    seed: int = 12324        # PRNG seed for reproducibility
    num_warmup: int = 1000   # NUTS warmup iterations
    num_samples: int = 1000  # posterior draws per chain
    num_chains: int = 1      # number of MCMC chains

def simple_difference_estimate(
    z: jax.Array,
    y: jax.Array | None = None,
):
    """NumPyro model for a single grade: scores ~ Normal(alpha + theta * z, sigma).

    z: treatment indicator (0/1) per classroom.
    y: observed post-test scores; None samples from the prior predictive.
    """
    alpha = numpyro.sample("alpha", dist.Normal(0, 1))  # control-group mean
    theta = numpyro.sample("theta", dist.Normal(0, 1))  # treatment effect
    sigma = numpyro.sample("sigma", dist.HalfNormal())  # residual scale
    numpyro.sample("scores", dist.Normal(loc=alpha + theta * z, scale=sigma), obs=y)

# NOTE(review): the opening `def simple_difference_per_grade(` line (and any
# `numpyro.plate` setup providing `ind`) appears to have been lost from this
# snippet — only the signature tail and body remain, so the fragment is not
# valid Python as shown. TODO: confirm against the original post.
z: jax.Array,
y: jax.Array,
):
alpha = numpyro.sample("alpha", dist.Normal(0, 1))
theta = numpyro.sample("theta", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.HalfNormal())
# The asker's sketch of the step they are stuck on: indexing the data by the
# plate index `ind`.
# loc = alpha + theta * z[ind]
# numpyro.sample("scores", dist.Normal(loc, sigma), obs=y[ind])

# Inference with for loop for each grade separately:
for g in [1, 2, 3, 4]:
    # Same seed every iteration, so each grade's run is independently
    # reproducible.
    rng_key = jax.random.PRNGKey(params.seed)
    kernel = numpyro.infer.NUTS(simple_difference_estimate)
    mcmc = numpyro.infer.MCMC(
        kernel,
        num_warmup=params.num_warmup,
        num_samples=params.num_samples,
        num_chains=params.num_chains,
    )
    # NOTE(review): the per-grade data is not passed here — presumably
    # `mcmc.run(rng_key=rng_key, z=..., y=...)` with the grade-g subset was
    # truncated from the snippet. TODO: confirm against the original post.
    mcmc.run(
        rng_key=rng_key,
    )
    mcmc.print_summary()
``````

For those who are interested, I resolved my confusion by removing the observation from the `plate`. Here is the inference code for the causal-inference example of Chapter 19 of “Regression and Other Stories” in `numpyro`.

``````import pandas as pd
from dataclasses import dataclass
import numpy as np
import jax
import numpyro
import numpyro.distributions as dist

@dataclass
class params:
    """Configuration: data location and MCMC sampler settings."""

    # CSV of the Electric Company experiment from the ROS examples repository.
    data_path: str = "https://raw.githubusercontent.com/avehtari/ROS-Examples/master/ElectricCompany/data/electric.csv"
    seed: int = 12324        # PRNG seed for reproducibility
    num_warmup: int = 1000   # NUTS warmup iterations
    num_samples: int = 5000  # posterior draws per chain
    num_chains: int = 1      # number of MCMC chains

# NOTE(review): orphaned `return df` — the enclosing data-loading function
# (presumably a `def load_data(...)` that reads `params.data_path` with
# pandas) was lost from this snippet. TODO: confirm against the original post.
return df

def simple_difference_estimate(y=None, z=None, **kwargs):
    """Pooled difference estimate: scores ~ Normal(alpha + theta * z, sigma).

    Explicit keyword parameters replace the opaque ``**kwargs`` lookup, so the
    model's inputs are visible in the signature; all existing call sites pass
    ``y=...`` and ``z=...`` by keyword, so they continue to work unchanged.

    Parameters
    ----------
    y : array or None
        Observed post-test scores; ``None`` samples from the prior predictive.
    z : array
        Treatment indicator (0/1) per observation.
    **kwargs
        Ignored; accepted so callers forwarding extra keys (e.g. ``grade``)
        still work.
    """
    alpha = numpyro.sample("alpha", dist.Normal(0, 1))  # control-group mean
    theta = numpyro.sample("theta", dist.Normal(0, 1))  # treatment effect
    sigma = numpyro.sample("sigma", dist.HalfNormal())  # residual scale
    # Record the linear predictor so posterior summaries expose it directly.
    loc = numpyro.deterministic("loc", alpha + theta * z)
    numpyro.sample("scores", dist.Normal(loc=loc, scale=sigma), obs=y)

# NOTE(review): the enclosing `def ...(**kwargs):` line and the
# `numpyro.plate` block that defines `loc` and `grade` were lost from this
# snippet — as shown, `loc` and `grade` on the last line are undefined.
# TODO: confirm against the original post.
y = kwargs["y"]
z = kwargs["z"]
alpha = numpyro.sample("alpha", dist.Normal(0, 1))
theta = numpyro.sample("theta", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.HalfNormal())

# Indexing `sigma[grade - 1]` suggests sigma was sampled inside a plate as a
# per-grade vector — presumably alpha/theta too; verify against the source.
numpyro.sample("scores", dist.Normal(loc, sigma[grade - 1]), obs=y)

def inference(rng_key, params, model, y, z, grade):
    """Run NUTS/MCMC for `model` and print the posterior summary.

    rng_key: JAX PRNG key for the sampler.
    params: sampler settings (num_warmup, num_samples, num_chains).
    model: a NumPyro model callable.
    y, z, grade: data arrays intended for the model.
    """
    kernel = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(
        kernel,
        num_warmup=params.num_warmup,
        num_samples=params.num_samples,
        num_chains=params.num_chains,
    )

    # NOTE(review): an `mcmc.run(rng_key, y=y, z=z, grade=grade)` call appears
    # to be missing here — without it `print_summary` has no samples to report.
    # Likely lost in extraction; TODO: confirm against the original post.
    mcmc.print_summary()

def main(params):
rng_key = jax.random.PRNGKey(params.seed)
_, *sub_key = jax.random.split(rng_key, num=10)

sep = 80 * "-"
# Inference Using for loop for each grade
inference(
params=params,
model=simple_difference_estimate,
rng_key=sub_key[g],
)
del sub_key[g]
print(sep)

# Inference Using the plate for each grade
print(sep)
print("\nInference using plate")
inference(
params=params,