Do operation returns nothing (action-abduction-inference)

Hi, I’m relatively new to Pyro and have therefore been relying heavily on tutorials, namely this one covering normalizing flows for counterfactual inference, which is what I am trying to implement.

When trying to generate counterfactuals, my code fails at the do step, where I set a binary variable to its opposite, and my counterfactual model has an empty trace. A minimal working example is below:

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

import pandas as pd
import numpy as np
import torch

def ProposedModelSmall(data, gamma_shift=0., C=0.):
    gamma_shift = torch.tensor(gamma_shift)
    C = torch.tensor(C)

    lambda1 = torch.tensor(8.679151808881061e-24)
    lambda2 = torch.tensor(-0.15230709314346313)
    lambda3 = torch.tensor(-0.8194252848625183)
    lambda4 = torch.tensor(0.849686861038208)
    
    with pyro.plate("data", data.size(0)):
        u1 = pyro.sample("u1", dist.Normal(torch.tensor(0.), torch.tensor(0.6)))
        u2 = pyro.sample("u2", dist.Normal(torch.tensor(0.), torch.tensor(0.8)))
        u3 = pyro.sample("u3", dist.Normal(torch.tensor(0.), torch.tensor(0.9)))
        
        a_prob = torch.sigmoid(gamma_shift*lambda1*C + u1)
        a_val = pyro.sample("a", dist.Bernoulli(probs=a_prob), obs=data[:,0])

        r_val = pyro.sample("r", dist.Normal(lambda2*a_val + u2, 1.), obs=data[:,1]) 

        y_prob = torch.sigmoid(lambda3*a_val + lambda4*r_val + u3)
        y_val = pyro.sample("y", dist.Bernoulli(probs=y_prob), obs=data[:,2])

# Sample data
samples = pd.DataFrame(columns=['A', 'R', 'Y'])
samples['A'] = [1., 0., 1.]
samples['R'] = [1.25, 0.87, 0.43]
samples['Y'] = [0., 1., 1.]

def infer_exogenous(obs, model):
    input_obs = torch.tensor(np.array([[obs[k] for k in ['a', 'r', 'y']]]))
    
    cond_sample = pyro.condition(lambda: model(input_obs), obs)
    cond_trace = pyro.poutine.trace(cond_sample).get_trace()
    
    exogenous = {k: cond_trace.nodes[k]['value'] for k in  ['u1', 'u2', 'u3']}
    return exogenous

def counterfactual(model, obs):
    # Infer state of world (ie learn noise)
    exogenous = infer_exogenous(obs, model)
        
    # Find counterfactual A value (a') for this sample
    input_obs = torch.tensor(np.array([[obs[k] for k in ['a', 'r', 'y']]]))
    cf_a = 0 if input_obs.numpy()[0][0] == 1. else 1.
    
    # Compute counterfactual sample
    state_of_world = poutine.condition(lambda: model(input_obs), data=exogenous) # Conditioned model works as intended
    cf = poutine.do(poutine.condition(lambda: model(input_obs), data=exogenous), data={'a': torch.tensor(cf_a)})() # Performing do operation on conditioned model fails
    
    return cf, cf_a, state_of_world

# Get row of data and learn state of world that produced it
obs = {k.lower(): torch.tensor(samples.iloc[0][k]) for k in ['A', 'R', 'Y']} 

# Infer noise/state of the world and generate counterfactual
cf, cf_a, state_of_world = counterfactual(ProposedModelSmall, obs)
print('\nCounterfactual', cf)
print('\nCounterfactual A=a_prime', cf_a)

What I have done so far as part of the abduction-action-prediction algorithm:

  • Learned the lambda variables using SVI and an AutoDelta guide, which is why the model has obs statements (I hardcoded the learned values above to keep the minimal example short)
  • Inferred the exogenous noise variables (the three u variables) for a given data point
  • Conditioned the model on the inferred exogenous noise (the model conditioned on exogenous noise has a nonempty trace that is in accordance with the conditioning)

However, I believe I am not doing the last part of the algorithm correctly: the do operation, where I want to get the (A, R, Y) tuple when do(A=a') is performed.

In short, what I would expect from the code above yet haven’t been able to fix:

  • The cf = pyro.poutine.do(...)() variable in the code to not be None
  • The trace, tr = poutine.trace(cf).get_trace() of the counterfactual model to have a nonempty node dict (tr.nodes != OrderedDict())

Thanks for taking the time to look this over!

Hi @zakhaal, your model function ProposedModelSmall does not return anything, so neither will the intervened version do(model). What value would you like cf to have?

I’m not sure what you mean by your second point, since it does not seem to correspond to anything in your example code. Can you elaborate? What exactly is being traced?

Also, condition specifies an inference problem but does not compute the solution to that problem, so if this is how you are computing values of u in your full example then it is almost certainly incorrect. I suggest reviewing the official introduction to Pyro, which may clear up some of these conceptual issues.

Hi, thanks so much for your reply! You are exactly right: I wasn’t returning any values from my model, yet I expected something to be returned, which doesn’t make sense. I was able to fix this after reviewing what the do and condition functions actually do. This also cleared up my misconceptions about the trace, as the do operator returns a (mutilated) model which I can input data into; tracing doesn’t make sense in that case.

Regarding your third point, you are correct that I am not solving the inference problem here. To be clearer about the setup for anyone interested: I’m assuming that the inference problem (learning the optimal values of the lambda variables) has already been solved. In my case, I used SVI and an AutoDelta guide to learn them, and stored the learned values in a dictionary instead of hardcoding them in the model. In the updated code below, I represent the lambda variables as sample statements in the model, and when I perform conditioning to compute values of u, I also condition on the learned lambda values in the counterfactual function, so that the inference isn’t bunk and is based on parameters that have been learned.

In case a working example can benefit anyone in the future, here is the working code for what I wanted to do (abduction, action, and prediction for twin-world counterfactuals):

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

import pandas as pd
import numpy as np
import torch

def ProposedModelSmall(a=None, r=None, y=None, gamma_shift=0., C=0.):
    gamma_shift = torch.tensor(gamma_shift)
    C = torch.tensor(C)

    lambda1 = pyro.sample("lambda1", dist.Normal(0., 1.))
    lambda2 = pyro.sample("lambda2", dist.Normal(0., 1.))
    lambda3 = pyro.sample("lambda3", dist.Normal(0., 1.))
    lambda4 = pyro.sample("lambda4", dist.Normal(0., 1.))
    
    with pyro.plate("data", a.size(0)):
        u1 = pyro.sample("u1", dist.Normal(torch.tensor(0.), torch.tensor(0.6)))
        u2 = pyro.sample("u2", dist.Normal(torch.tensor(0.), torch.tensor(0.8)))
        u3 = pyro.sample("u3", dist.Normal(torch.tensor(0.), torch.tensor(0.9)))
        
        a_prob = torch.sigmoid(gamma_shift*lambda1*C + u1)
        a_val = pyro.sample("a", dist.Bernoulli(probs=a_prob), obs=a)

        r_val = pyro.sample("r", dist.Normal(lambda2*a_val + u2, 1.), obs=r) 

        y_prob = torch.sigmoid(lambda3*a_val + lambda4*r_val + u3)
        y_val = pyro.sample("y", dist.Bernoulli(probs=y_prob), obs=y)
        
        return a_val, r_val, y_val

# Sample data
samples = pd.DataFrame(columns=['A', 'R', 'Y'])
samples['A'] = [1., 0., 1.]
samples['R'] = [1.25, 0.87, 0.43]
samples['Y'] = [0., 1., 1.]

def infer_exogenous(obs, model):
    input_a = torch.tensor(np.array([obs['a']]))
    input_r = torch.tensor(np.array([obs['r']]))
    input_y = torch.tensor(np.array([obs['y']]))
    
    cond_sample = pyro.condition(lambda: model(input_a, input_r, input_y), obs)
    cond_trace = pyro.poutine.trace(cond_sample).get_trace()
    
    exogenous = {k: cond_trace.nodes[k]['value'] for k in  ['u1', 'u2', 'u3']}
    return exogenous


def counterfactual(model, obs, learned_params):
    # Infer state of world (ie learn noise)
    exogenous = infer_exogenous(obs, model)
    exogenous_and_learned = {**exogenous, **learned_params}
    
    # Find counterfactual A value (a') for this sample
    input_a = torch.tensor(np.array([obs['a']]))
    cf_a = 0. if input_a.numpy()[0] == 1. else 1.  # flip the binary treatment; floats keep the intervened value's dtype consistent
    
    # Compute counterfactual sample
    cf_model = pyro.do(pyro.condition(model, data=exogenous_and_learned), data={'a': torch.tensor(cf_a)})
    
    return cf_model, cf_a

# Get sample
obs = {k.lower(): torch.tensor(samples.iloc[0][k]) for k in ['A', 'R', 'Y']} 

# Infer noise/state of the world and generate counterfactual
learned_params = {'lambda1': torch.tensor(8.68e-24), 'lambda2': torch.tensor(-0.15), 'lambda3': torch.tensor(-0.81), 'lambda4': torch.tensor(0.85)}
cf, cf_a = counterfactual(ProposedModelSmall, obs, learned_params)

input_a = torch.tensor(np.array([obs['a']]))
print(cf(input_a))  # Prints counterfactual sample 

print(obs) # Prints original sample for comparison

Hi, thanks for sharing the working code. I found it very useful for my application. One question: how would this work for counterfactuals on continuous-valued variables, such as R in your case?

I’m glad you were able to find something useful in my code!

Can you elaborate a little more on your question? I’m interpreting it as asking what makes counterfactual evaluation difficult for continuous variables, and the short answer is that it’s definitely possible. Provided you know the distribution of the exogenous variables and have samples from it (and, if your setup is similar to mine, you have adequately solved the inference problem), you shouldn’t have any difficulty generating continuous-valued twin-world counterfactuals for your application. In my setup the noise is at the individual level (different u values for each (A, R, T, Y) tuple), so generating individual-level counterfactuals is hard, because I only have one data point from which to infer the exogenous noise. As a result, counterfactual predictions of R may differ from the “true” counterfactual values of R, i.e. the values I would get if I knew all the exogenous terms for each data point.

Thank you very much for the response. What I meant is: how can I define a counterfactual (CF) query on a continuous variable? For instance, in Pearl’s Causality book, with X as the cause and Y as the effect (both binary), the probability of necessity is PN = P(Y_{x’} = y’ | X = x, Y = y), i.e. the probability that Y would have been y’ had X been x’, given that we actually observed X = x and Y = y. This is how you defined your CF queries on the A variable, which takes the values 0 and 1.
Now imagine I have a cause variable like Rain with continuous values and an effect variable Wetness that is discrete (0, 1, 2 or 3, depending on the level of wetness). In reality, Rain = 10 mm/h and Wetness = 1. How would you then change the line

cf_a = 0 if input_a.numpy()[0] == 1. else 1.

when calculating cf_Rain? Would you then take any random Rain value other than 10?
I hope this clarifies my question a bit more.
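For binary X and Y, the PN query above can be estimated with the same abduction, action and prediction steps. The sketch below uses plain PyTorch and a purely hypothetical structural model: X = U_x and Y = 1 if X + U_y > 0.5, with U_y ~ Normal(0, 1). Because X = U_x, the evidence on X pins down U_x directly, so only U_y needs abducting. Abduction keeps the noise draws consistent with the evidence X = 1, Y = 1 (rejection sampling), action sets X = 0, and prediction recomputes Y with the same noise. For this toy model PN also has a closed form, used as a check.

```python
import torch

THRESHOLD = 0.5  # hypothetical structural equation: Y = 1[X + U_y > THRESHOLD]


def outcome(x, u_y):
    return (x + u_y > THRESHOLD).float()


torch.manual_seed(0)
num_samples = 1_000_000
u_y = torch.randn(num_samples)  # prior draws of the exogenous noise U_y

# Abduction: keep only noise values consistent with the evidence X = 1, Y = 1.
x_obs, y_obs = 1.0, 1.0
consistent = outcome(torch.full_like(u_y, x_obs), u_y) == y_obs
u_post = u_y[consistent]

# Action: do(X = 0). Prediction: recompute Y with the same abducted noise.
y_cf = outcome(torch.zeros_like(u_post), u_post)
pn_estimate = (y_cf == 0.0).float().mean()

# Closed form for this toy model: P(U_y <= 0.5 | U_y > -0.5).
std_normal = torch.distributions.Normal(0.0, 1.0)
cdf_lo = std_normal.cdf(torch.tensor(-0.5))
cdf_hi = std_normal.cdf(torch.tensor(0.5))
pn_exact = (cdf_hi - cdf_lo) / (1.0 - cdf_lo)

print(float(pn_estimate), float(pn_exact))
```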

My second question: could you give a hint on how to infer the lambda (model) parameters from the data?

Sorry for the newbie questions, and thank you in advance.

Whatever counterfactual value you choose for your continuous variable depends on your application. For example, if you want to compare wetness (or other quantities derived from rainfall) at an observed location with a location that has minimal rainfall, you could look at counterfactuals with very low rainfall (Rain = 1 mm/h). It really depends on what your project is aiming to do.

As for inferring the lambdas, I personally used an AutoDelta guide. The autoguides are simple to set up and use once you have your model defined, and you can find an example of their usage here. Great questions, and I hope you’ll be able to get everything running quickly!
