Hi, I’m relatively new to Pyro and therefore have been relying heavily on tutorials, namely this one covering normalizing flows for counterfactual inference. It performs counterfactual inference which I am trying to implement.
When trying to generate counterfactuals, my code fails at the `do` step of setting a binary variable to its opposite, and my counterfactual model has an empty trace. A minimal working example is below:
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import pandas as pd
import numpy as np
import torch
def ProposedModelSmall(data, gamma_shift=0., C=0.):
    """Structural causal model over (A, R, Y) with exogenous noise u1-u3.

    Args:
        data: tensor of shape (N, 3) holding the observed columns [A, R, Y].
        gamma_shift: scalar shift applied to A's logit (scaled by lambda1 * C).
        C: scalar covariate multiplying the gamma shift.

    Returns:
        Tuple ``(a_val, r_val, y_val)`` of the endogenous site values.
        Returning them is essential: the original body returned ``None``,
        so any counterfactual query that calls the model got ``None`` back.
    """
    # as_tensor avoids the copy/warning torch.tensor() emits when a tensor
    # is passed in (e.g. when a caller forwards an existing tensor).
    gamma_shift = torch.as_tensor(gamma_shift)
    C = torch.as_tensor(C)
    # Structural coefficients, hard-coded here from a previous SVI fit to
    # keep the example self-contained.
    lambda1 = torch.tensor(8.679151808881061e-24)
    lambda2 = torch.tensor(-0.15230709314346313)
    lambda3 = torch.tensor(-0.8194252848625183)
    lambda4 = torch.tensor(0.849686861038208)
    with pyro.plate("data", data.size(0)):
        # Exogenous noise ("state of the world").
        u1 = pyro.sample("u1", dist.Normal(torch.tensor(0.), torch.tensor(0.6)))
        u2 = pyro.sample("u2", dist.Normal(torch.tensor(0.), torch.tensor(0.8)))
        u3 = pyro.sample("u3", dist.Normal(torch.tensor(0.), torch.tensor(0.9)))
        # A ~ Bernoulli(sigmoid(gamma_shift*lambda1*C + u1))
        a_prob = torch.sigmoid(gamma_shift * lambda1 * C + u1)
        a_val = pyro.sample("a", dist.Bernoulli(probs=a_prob), obs=data[:, 0])
        # R ~ Normal(lambda2*A + u2, 1)
        r_val = pyro.sample("r", dist.Normal(lambda2 * a_val + u2, 1.), obs=data[:, 1])
        # Y ~ Bernoulli(sigmoid(lambda3*A + lambda4*R + u3))
        y_prob = torch.sigmoid(lambda3 * a_val + lambda4 * r_val + u3)
        y_val = pyro.sample("y", dist.Bernoulli(probs=y_prob), obs=data[:, 2])
    # NOTE(review): hard-wiring obs= here means these sites are always
    # observed; for do()-interventions to resample downstream sites, consider
    # making the observations optional — TODO confirm against the tutorial.
    return a_val, r_val, y_val
# Sample data: three observed rows of (A, R, Y).
samples = pd.DataFrame(
    {
        'A': [1., 0., 1.],
        'R': [1.25, 0.87, 0.43],
        'Y': [0., 1., 1.],
    }
)
def infer_exogenous(obs, model):
    """Read off the exogenous noise (u1, u2, u3) for one observation.

    ``obs`` maps the lowercase site names 'a', 'r', 'y' to 0-d tensors; the
    model is run conditioned on those values and the noise sites are taken
    from the resulting trace.

    NOTE(review): since the model also fixes a/r/y via ``obs=``, this trace
    records *prior* draws of u1-u3 rather than posterior (abducted) values —
    confirm a real inference step over the u's is run elsewhere if exact
    abduction is required.
    """
    row = torch.tensor(np.array([[obs[name] for name in ['a', 'r', 'y']]]))
    conditioned = pyro.condition(lambda: model(row), obs)
    trace = pyro.poutine.trace(conditioned).get_trace()
    return {site: trace.nodes[site]['value'] for site in ('u1', 'u2', 'u3')}
def counterfactual(model, obs):
    """Run abduction -> action -> prediction for a single observation.

    Args:
        model: the Pyro model (e.g. ``ProposedModelSmall``).
        obs: dict mapping site names 'a', 'r', 'y' to 0-d tensors.

    Returns:
        cf: whatever ``model`` returns under do(a=cf_a) with noise held fixed.
        cf_a: the flipped treatment value (always a float).
        state_of_world: the noise-conditioned model (a callable, not yet run).
    """
    # Abduction: recover the exogenous noise for this data point.
    exogenous = infer_exogenous(obs, model)
    input_obs = torch.tensor(np.array([[obs[k] for k in ['a', 'r', 'y']]]))
    # Action: flip the binary treatment. Both branches must be floats so
    # torch.tensor(cf_a) gets a floating dtype in either branch — the
    # original `0 if ... else 1.` produced an int64 tensor when A was 1.
    cf_a = 0. if input_obs[0, 0].item() == 1. else 1.
    # Prediction: rerun the model with the noise fixed and a set to cf_a.
    # Build the conditioned model once and reuse it for the do() call.
    state_of_world = poutine.condition(lambda: model(input_obs), data=exogenous)
    # NOTE(review): 'a' is declared with obs= inside the model; confirm that
    # poutine.do actually overrides the observed value here — if not, the
    # observation must be dropped when running the counterfactual.
    cf = poutine.do(state_of_world, data={'a': torch.tensor(cf_a)})()
    return cf, cf_a, state_of_world
# Get row of data and learn state of world that produced it
first_row = samples.iloc[0]
obs = {name.lower(): torch.tensor(first_row[name]) for name in ('A', 'R', 'Y')}
# Infer noise/state of the world and generate counterfactual
cf, cf_a, state_of_world = counterfactual(ProposedModelSmall, obs)
print('\nCounterfactual', cf)
print('\nCounterfactual A=a_prime', cf_a)
What I have done so far as part of the abduction–action–prediction algorithm:
- Learned the `lambda` variables (using SVI and an AutoDelta guide) — hence the `obs` statements in the model (I hard-coded their values above to keep the minimal example short).
- Inferred the exogenous noise variables (the three `u` variables) for a given data point.
- Conditioned the model on the inferred exogenous noise (the model conditioned on exogenous noise has a nonempty trace that agrees with the conditioning).
However, the last part of the algorithm — the `do` operation, where I want to obtain the tuple of (A, R, Y) values when do(A=a') is performed — is the step I believe I am not doing correctly.
In short, what I would expect from the code above, yet haven't been able to achieve:
- The `cf = pyro.poutine.do(...)()` variable in the code to not be `None`.
- The trace, `tr = poutine.trace(cf).get_trace()`, of the counterfactual model to have a nonempty node dict (`tr.nodes != OrderedDict()`).
Thanks for taking the time to look this over!