I am working on the dSprites dataset and using SVI inference to do counterfactual evaluations. The Importance sampling algorithm and MCMC didn't work for us, so we decided to turn this into an optimization problem by using SVI instead. Here is what my VAE looks like:
class Encoder(nn.Module):
def __init__(self, image_dim, label_dim, z_dim):
super(Encoder, self).__init__()
self.image_dim = image_dim
self.label_dim = label_dim
self.z_dim = z_dim
# set up the four linear transformations used
self.fc1 = nn.Linear(self.image_dim+self.label_dim, 1000)
self.fc2 = nn.Linear(1000, 1000)
self.fc31 = nn.Linear(1000, z_dim) # mu values
self.fc32 = nn.Linear(1000, z_dim) # sigma values
# setup the non-linearities
self.softplus = nn.Softplus()
def forward(self, xs, ys):
# define the forward computation on the image xs and label ys
# first shape the mini-batch to have pixels in the rightmost dimension
xs = xs.reshape(-1, self.image_dim)
#now concatenate the image and label
inputs = torch.cat((xs,ys), -1)
# then compute the hidden units
hidden1 = self.softplus(self.fc1(inputs))
hidden2 = self.softplus(self.fc2(hidden1))
# then return a mean vector and a (positive) square root covariance
# each of size batch_size x z_dim
z_loc = self.fc31(hidden2)
z_scale = torch.exp(self.fc32(hidden2))
return z_loc, z_scale
class Decoder(nn.Module):
def __init__(self, image_dim, label_dim, z_dim):
super(Decoder, self).__init__()
# set up the four linear transformations used
hidden_dim = 1000
self.fc1 = nn.Linear(z_dim+label_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, hidden_dim)
self.fc4 = nn.Linear(hidden_dim, image_dim)
# setup the non-linearities
self.softplus = nn.Softplus()
self.sigmoid = nn.Sigmoid()
def forward(self, zs, ys):
# define the forward computation on the latent z and label y
# first concatenate z and y
inputs = torch.cat((zs, ys),-1)
# then compute the hidden units
hidden1 = self.softplus(self.fc1(inputs))
hidden2 = self.softplus(self.fc2(hidden1))
hidden3 = self.softplus(self.fc3(hidden2))
# return the parameter for the output Bernoulli
# of size batch_size x image_dim (64*64 = 4096)
loc_img = self.sigmoid(self.fc4(hidden3))
return loc_img
class CVAE(nn.Module):
def __init__(self, config_enum=None, use_cuda=False, aux_loss_multiplier=None):
super(CVAE, self).__init__()
self.image_dim = 64**2
self.label_shape = np.array((1,3,6,40,32,32))
self.label_names = np.array(('color', 'shape', 'scale', 'orientation', 'posX', 'posY'))
self.label_dim = np.sum(self.label_shape)
self.z_dim = 50
self.use_cuda = use_cuda
# define and instantiate the neural networks representing
# the parameters of various distributions in the model
self.setup_networks()
def setup_networks(self):
self.encoder = Encoder(self.image_dim, self.label_dim, self.z_dim)
self.decoder = Decoder(self.image_dim, self.label_dim, self.z_dim)
# using GPUs for faster training of the networks
if self.use_cuda:
self.cuda()
def model(self, xs, ys):
"""
The model corresponds to the following generative process:
p(z) = normal(0,I)              # latent code
p(x|y,z) = bernoulli(loc(y,z)) # an image
loc is given by a neural network `decoder`
:param xs: a batch of scaled vectors of pixels from an image
:param ys: a batch of the class labels, i.e. the dSprites
           latent-factor values corresponding to the image(s)
:return: None
"""
# register this pytorch module and all of its sub-modules with pyro
pyro.module("cvae", self)
batch_size = xs.size(0)
options = dict(dtype=xs.dtype, device=xs.device)
with pyro.plate("data"):
prior_loc = torch.zeros(batch_size, self.z_dim, **options)
prior_scale = torch.ones(batch_size, self.z_dim, **options)
zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))
# the labels ys are always observed here; remap_y scores them
# against constant priors and returns their one-hot encoding
loc = self.decoder.forward(zs, self.remap_y(ys))
pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)
# return the loc so we can visualize it later
return loc
def guide(self, xs, ys):
"""
The guide corresponds to the following:
q(z|x,y) = normal(loc(x,y),scale(x,y))       # infer the latent code from an image and its label
loc, scale are given by a neural network `encoder`
:param xs: a batch of scaled vectors of pixels from an image
:param ys: a batch of the corresponding labels
:return: None
"""
# inform Pyro that the variables in the batch of xs are conditionally independent
with pyro.plate("data"):
# sample (and score) the latent code with the variational
# distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
loc, scale = self.encoder.forward(xs, self.remap_y(ys))
pyro.sample("z", dist.Normal(loc, scale).to_event(1))
def remap_y(self, ys):
new_ys = []
options = dict(dtype=ys.dtype, device=ys.device)
for i, label_length in enumerate(self.label_shape):
prior = torch.ones(ys.size(0), label_length, **options) / (1.0 * label_length)
new_ys.append(pyro.sample("y_%s" % self.label_names[i], dist.OneHotCategorical(prior),
obs=torch.nn.functional.one_hot(ys[:,i].to(torch.int64), int(label_length))))
new_ys = torch.cat(new_ys, -1)
return new_ys.to(torch.float32)
def reconstruct_image(self, xs, ys):
# backward
sim_z_loc, sim_z_scale = self.encoder.forward(xs, self.remap_y(ys))
zs = dist.Normal(sim_z_loc, sim_z_scale).to_event(1).sample()
# forward
loc = self.decoder.forward(zs, self.remap_y(ys))
return dist.Bernoulli(loc).to_event(1).sample()
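For context, the CVAE is trained beforehand with a standard Pyro SVI loop along these lines (a minimal sketch, not our exact script; `train_loader`, the learning rate, and the epoch count are placeholder assumptions):

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

cvae = CVAE(use_cuda=True)
svi = SVI(cvae.model, cvae.guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())

for epoch in range(10):
    epoch_loss = 0.0
    for xs, ys in train_loader:  # hypothetical loader of (image, label) batches
        xs = xs.reshape(-1, 64 * 64).cuda()  # flatten the binary dSprites images
        ys = ys.cuda()
        epoch_loss += svi.step(xs, ys)  # one gradient step on -ELBO
    print("epoch {}: loss {:.1f}".format(epoch, epoch_loss))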
This is what my Structural Causal Model looks like:
class SCM():
def __init__(self, vae, mu, sigma):
self.vae = vae
self.image_dim = vae.image_dim
self.z_dim = vae.z_dim
mu = mu.cpu()
sigma = sigma.cpu()
# these are used for f_X
self.label_dims = vae.label_shape
def f_X(Y, Z, N):
zs = Z.cuda()
# convert the labels to one hot
ys = [torch.tensor([0])]
ys.append(torch.nn.functional.one_hot(torch.tensor(Y[0]), int(self.label_dims[1])))
ys.append(torch.nn.functional.one_hot(torch.tensor(Y[1]), int(self.label_dims[2])))
ys.append(torch.nn.functional.one_hot(torch.tensor(Y[2]), int(self.label_dims[3])))
ys.append(torch.nn.functional.one_hot(torch.tensor(Y[3]), int(self.label_dims[4])))
ys.append(torch.nn.functional.one_hot(torch.tensor(Y[4]), int(self.label_dims[5])))
ys = torch.cat(ys).to(torch.float32).reshape(1,-1).cuda()
p = vae.decoder.forward(zs, ys)
return (N < p.cpu()).type(torch.float)
def f_Y(N):
m = torch.distributions.gumbel.Gumbel(torch.zeros(N.size(0)), torch.ones(N.size(0)))
return torch.argmax(torch.add(torch.log(N), m.sample())).item()
def f_Z(N):
return N * sigma + mu
def model(noise):
N_X = pyro.sample( 'N_X', noise['N_X'] )
# There are 5 Y variables and they will be
# denoted using the index in the sequence
# that they are stored in as vae.label_names:
# ['shape', 'scale', 'orientation', 'posX', 'posY']
N_Y_1 = pyro.sample( 'N_Y_1', noise['N_Y_1'] )
N_Y_2 = pyro.sample( 'N_Y_2', noise['N_Y_2'] )
N_Y_3 = pyro.sample( 'N_Y_3', noise['N_Y_3'] )
N_Y_4 = pyro.sample( 'N_Y_4', noise['N_Y_4'] )
N_Y_5 = pyro.sample( 'N_Y_5', noise['N_Y_5'] )
N_Z = pyro.sample( 'N_Z', noise['N_Z'] )
Z = pyro.sample('Z', dist.Normal( f_Z( N_Z ), 1e-1) )
Y_1_mu = f_Y(N_Y_1)
Y_2_mu = f_Y(N_Y_2)
Y_3_mu = f_Y(N_Y_3)
Y_4_mu = f_Y(N_Y_4)
Y_5_mu = f_Y(N_Y_5)
Y_1 = pyro.sample('Y_1', dist.Normal( Y_1_mu, 1e-1) )
Y_2 = pyro.sample('Y_2', dist.Normal( Y_2_mu, 1e-1) )
Y_3 = pyro.sample('Y_3', dist.Normal( Y_3_mu, 1e-1) )
Y_4 = pyro.sample('Y_4', dist.Normal( Y_4_mu, 1e-1) )
Y_5 = pyro.sample('Y_5', dist.Normal( Y_5_mu, 1e-1) )
Y_mu = (Y_1_mu, Y_2_mu, Y_3_mu, Y_4_mu, Y_5_mu)
X = pyro.sample('X', dist.Normal( f_X( Y_mu, Z, N_X ), 1e-1) )
noise_samples = N_X, (N_Y_1, N_Y_2, N_Y_3, N_Y_4, N_Y_5), N_Z
variable_samples = X, (Y_1, Y_2, Y_3, Y_4, Y_5), Z
return variable_samples, noise_samples
self.model = model
self.init_noise = {
'N_X' : dist.Uniform(torch.zeros(vae.image_dim), torch.ones(vae.image_dim)),
'N_Z' : dist.Normal(torch.zeros(vae.z_dim), torch.ones(vae.z_dim)),
'N_Y_1' : dist.Uniform(torch.zeros(self.label_dims[1]), torch.ones(self.label_dims[1])),
'N_Y_2' : dist.Uniform(torch.zeros(self.label_dims[2]), torch.ones(self.label_dims[2])),
'N_Y_3' : dist.Uniform(torch.zeros(self.label_dims[3]), torch.ones(self.label_dims[3])),
'N_Y_4' : dist.Uniform(torch.zeros(self.label_dims[4]), torch.ones(self.label_dims[4])),
'N_Y_5' : dist.Uniform(torch.zeros(self.label_dims[5]), torch.ones(self.label_dims[5]))
}
def update_noise_svi(self, obs_data):
# assume all noise variables are normally distributed and use SVI
# to find the mu and sigma of each distribution under the
# conditioning outlined in obs_data
def guide(noise):
# create params with constraints
mu = {'N_X': pyro.param('N_X_mu', 0.5*torch.ones(self.image_dim),
constraint = constraints.interval(0., 1.)),
'N_Z': pyro.param('N_Z_mu', torch.zeros(self.z_dim),
constraint = constraints.interval(-3., 3.)),
'N_Y_1': pyro.param('N_Y_1_mu', 0.5*torch.ones(self.label_dims[1]),
constraint = constraints.interval(0., 1.)),
'N_Y_2': pyro.param('N_Y_2_mu', 0.5*torch.ones(self.label_dims[2]),
constraint = constraints.interval(0., 1.)),
'N_Y_3': pyro.param('N_Y_3_mu', 0.5*torch.ones(self.label_dims[3]),
constraint = constraints.interval(0., 1.)),
'N_Y_4': pyro.param('N_Y_4_mu', 0.5*torch.ones(self.label_dims[4]),
constraint = constraints.interval(0., 1.)),
'N_Y_5': pyro.param('N_Y_5_mu', 0.5*torch.ones(self.label_dims[5]),
constraint = constraints.interval(0., 1.))
}
sigma = {'N_X': pyro.param('N_X_sigma', 0.1*torch.ones(self.image_dim),
constraint = constraints.interval(0.0001, 0.5)),
'N_Z': pyro.param('N_Z_sigma', torch.ones(self.z_dim),
constraint = constraints.interval(0.0001, 3.)),
'N_Y_1': pyro.param('N_Y_1_sigma', 0.1*torch.ones(self.label_dims[1]),
constraint = constraints.interval(0.0001, 0.5)),
'N_Y_2': pyro.param('N_Y_2_sigma', 0.1*torch.ones(self.label_dims[2]),
constraint = constraints.interval(0.0001, 0.5)),
'N_Y_3': pyro.param('N_Y_3_sigma', 0.1*torch.ones(self.label_dims[3]),
constraint = constraints.interval(0.0001, 0.5)),
'N_Y_4': pyro.param('N_Y_4_sigma', 0.1*torch.ones(self.label_dims[4]),
constraint = constraints.interval(0.0001, 0.5)),
'N_Y_5': pyro.param('N_Y_5_sigma', 0.1*torch.ones(self.label_dims[5]),
constraint = constraints.interval(0.0001, 0.5))
}
for noise_term in noise.keys():
pyro.sample(noise_term, dist.Normal(mu[noise_term], sigma[noise_term]))
obs_model = pyro.condition(self.model, obs_data)
pyro.clear_param_store()
svi = SVI(
model= obs_model,
guide= guide,
optim= Adam({"lr": 1e-3}),
loss=Trace_ELBO()
)
num_steps = 1000
samples = defaultdict(list)
for t in range(num_steps):
svi.step(self.init_noise)
# now determine new noise variables
for noise in self.init_noise.keys():
mu = '{}_mu'.format(noise)
sigma = '{}_sigma'.format(noise)
samples[mu].append(pyro.param(mu).detach())
samples[sigma].append(pyro.param(sigma).detach())
means = {k: torch.mean(torch.stack(v), 0) for k, v in samples.items()}
updated_noise = {
'N_X': dist.Normal(means['N_X_mu'], means['N_X_sigma']),
'N_Z': dist.Normal(means['N_Z_mu'], means['N_Z_sigma']),
'N_Y_1': dist.Normal(means['N_Y_1_mu'], means['N_Y_1_sigma']),
'N_Y_2': dist.Normal(means['N_Y_2_mu'], means['N_Y_2_sigma']),
'N_Y_3': dist.Normal(means['N_Y_3_mu'], means['N_Y_3_sigma']),
'N_Y_4': dist.Normal(means['N_Y_4_mu'], means['N_Y_4_sigma']),
'N_Y_5': dist.Normal(means['N_Y_5_mu'], means['N_Y_5_sigma']),
}
return updated_noise
def __call__(self):
return self.model(self.init_noise)
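For context, the end goal is the usual abduction-action-prediction recipe: condition on the observed labels, infer the noise posterior with update_noise_svi, intervene on one label, and push the updated noise back through the model. Roughly (a sketch of the intended usage only; the poutine.do intervention is our planned next step, and cond_data is the conditioning dict built below):

import pyro.poutine as poutine

# 1. abduction: infer the noise posterior under the observed evidence
cond_noise = scm.update_noise_svi(cond_data)
# 2. action: intervene on a label, e.g. pin shape (Y_1) to class 2
cf_model = poutine.do(scm.model, data={"Y_1": torch.tensor(2.0)})
# 3. prediction: run the intervened model with the updated noise
(x_cf, ys_cf, z_cf), _ = cf_model(cond_noise)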
But when I run this,
x, y = get_specific_data(cuda=True)
mu, sigma = vae.encoder.forward(x,vae.remap_y(y))
scm = SCM(vae, mu.cpu(), sigma.cpu())
cond_data = {}
for i in range(1, 6):
cond_data["Y_{}".format(i)] = torch.tensor(y[0,i].cpu()).to(torch.float32)
cond_noise = scm.update_noise_svi(cond_data)
I get the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-278-e22c3f786618> in <module>
4
5
----> 6 cond_noise = scm.update_noise_svi(cond_data)
<ipython-input-276-8176258ce037> in update_noise_svi(self, obs_data)
128 samples = defaultdict(list)
129 for t in range(num_steps):
--> 130 svi.step(self.init_noise)
131
132 # now determine new noise variables
~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
97 # get loss and compute gradients
98 with poutine.trace(param_only=True) as param_capture:
---> 99 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
100
101 params = set(site["value"].unconstrained()
~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
123 loss = 0.0
124 # grab a trace from the generator
--> 125 for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
126 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
127 loss += loss_particle / self.num_particles
~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, *args, **kwargs)
166 else:
167 for i in range(self.num_particles):
--> 168 yield self._get_trace(model, guide, *args, **kwargs)
~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, *args, **kwargs)
50 """
51 model_trace, guide_trace = get_importance_trace(
---> 52 "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
53 if is_validation_enabled():
54 check_if_enumerated(guide_trace)
~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, *args, **kwargs)
54 for site in model_trace.nodes.values():
55 if site["type"] == "sample":
---> 56 check_site_shape(site, max_plate_nesting)
57 for site in guide_trace.nodes.values():
58 if site["type"] == "sample":
~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/util.py in check_site_shape(site, max_plate_nesting)
260 '- enclose the batched tensor in a with plate(...): context',
261 '- .to_event(...) the distribution being sampled',
--> 262 '- .permute() data dimensions']))
263
264 # Check parallel dimensions on the left of max_plate_nesting.
ValueError: at site "N_X", invalid log_prob shape
Expected [], actual [4096]
Try one of the following fixes:
- enclose the batched tensor in a with plate(...): context
- .to_event(...) the distribution being sampled
- .permute() data dimensions
Please let me know what the issue is; I am new to Pyro and not sure what is causing it.
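From the fix list in the error message, my own guess is that every vector-valued pyro.sample site needs its rightmost dimension declared as an event dimension, both in the guide and in the noise priors, e.g. something like this (untested):

# in the guide of update_noise_svi
for noise_term in noise.keys():
    pyro.sample(noise_term,
                dist.Normal(mu[noise_term], sigma[noise_term]).to_event(1))

# and matching .to_event(1) on the priors in self.init_noise, e.g.
'N_X': dist.Uniform(torch.zeros(vae.image_dim),
                    torch.ones(vae.image_dim)).to_event(1),

Presumably the vector-valued sites 'Z' and 'X' inside model would need the same treatment, but I am not sure whether this is the right approach.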
Also, this is for a project in my course on Causal Inference in Machine Learning. Here is the link to the course: GitHub - altdeep/causalML: The open source repository for the Causal Modeling in Machine Learning Workshop at Altdeep.ai @ www.altdeep.ai/courses/causalML
Thank you for the support!