Help with EIG for BNN model trained with SVI/autoguide

Hello,
I’m a biological engineering PhD student and am very excited about using Pyro, especially for Bayesian experimental design. I’m more of an experimentalist so preemptive apologies if the following is trivial, but I’ve spent quite a while trying to figure this out, and figured it wouldn’t hurt to ask.

I’ve trained a BNN using SVI with a multivariate normal autoguide (abbreviated code below).

I’m now trying to calculate the EIG towards designing new experiments, but am getting a shape error with the nmc_eig function. Not sure if my issue is how to pass the parameters contained in the layer/bias objects via the “target_sites” argument, I’m just guessing below based on what they seem to be called when I print the param_store.

Thanks very much!
Bryce

import torch
from torch import nn

import pyro
from pyro.nn import PyroSample
from pyro.nn import PyroModule
import pyro.distributions as dist
from pyro.infer.autoguide import AutoMultivariateNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.contrib.oed.eig import nmc_eig, posterior_eig

#%% Define model

class BNN_shallow(PyroModule):
    def __init__(self, in_features, out_features, hidden_layer_dimension):
        super().__init__()
        
        self.activation = nn.Tanh()
        self.layer1 = PyroModule[nn.Linear](in_features, hidden_layer_dimension)
        self.layer2 = PyroModule[nn.Linear](hidden_layer_dimension, out_features)
        
        prior_scale = 10**-1.5 # Found this using nested cross validation on the full dataset...I know, very un-Bayesian :/
        
        self.layer1.weight = PyroSample(dist.Normal(0., prior_scale).expand([hidden_layer_dimension,in_features]).to_event(2)) # the expand and .to_event dicate the number of dimensions for the sampling, i think to_event might specify number of independent sampling dimenstions but i forget
        self.layer1.bias = PyroSample(dist.Normal(0., prior_scale).expand([hidden_layer_dimension]).to_event(1))
        self.layer2.weight = PyroSample(dist.Normal(0., prior_scale).expand([out_features,hidden_layer_dimension]).to_event(2))
        self.layer2.bias = PyroSample(dist.Normal(0., prior_scale).expand([out_features]).to_event(1))

    def forward(self,x,y=None):

        x = self.activation(self.layer1(x))
        mu = self.layer2(x).squeeze()
                
        sigmaEmperical = torch.tensor([0.031, 0.037, 0.028]) # standard deviation for each species, was estimated from experimental replicates
        
        with pyro.plate('data',x.shape[0],dim=-2): 
            obs = pyro.sample('obs',dist.Normal(mu, sigmaEmperical),obs=y)

        return mu

# 1/8th of my data for brevity:   
# partial experimental design - presence/absence of 8 different resource variables (columns) in the culture media, rows are conditions
predictorsTruncated = torch.tensor([[0,0,0,0,0,0,0,1],[0,0,0,0,0,1,1,0],[0,0,0,0,1,0,1,0],[0,0,0,0,1,1,0,1],
                 [0,0,0,1,0,0,1,0],[0,0,0,1,0,1,0,1],[0	,0,0,1,0,1,0,1],[0,0,0,1,0,1,0,1]],dtype=torch.float)

# partial response matrix - growth of 3-species (columns) microbial community measured in each resource combination condition (rows)
responsesTruncated = torch.tensor([[0.06,0.11,0.14],[0.04,0.04,0.14],[0.08,0.01,0.11],[0.13,0.08,0.06],
                              [0.18,0.03,0.18],[0.67,0.15,0.05],[0.67,	0.15,0.05],[0.47,0.67,0.1]],dtype=torch.float)

# number of inputs and outputs
in_dim = predictorsTruncated.shape[1]
out_dim = responsesTruncated.shape[1]

model = BNN_shallow(in_dim, out_dim, 4) # intialize model
guide = AutoMultivariateNormal(model) # intiliaze guide
adam = pyro.optim.Adam({"lr": 0.01}) # initialize optimizer
svi = SVI(model, guide, adam, loss = Trace_ELBO()) # initialize stochastic variational inference

# train model on truncated data
pyro.clear_param_store()
num_iterations = 700
for j in range(num_iterations):
    loss = svi.step(predictorsTruncated,responsesTruncated)
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j+1, loss / len(responsesTruncated)))    
        
# first two rows are conditions in the truncated data above, latter is another condition from the full factorial (eig should be higher as a sanity check?)
candidate_designs = torch.tensor([[0,0,0,0,0,0,0,1],[0,0,0,0,0,1,1,0],[0,1,0,0,0,0,1,1]],dtype=torch.float)

# Tried this and a few similar things without success 
eig = nmc_eig(model, candidate_designs, observation_labels=["obs"], target_labels=['AutoMultivariateNormal'], N=2500, M=50)```