Hi all. I’m building a latent variable model. I have some part of the data (call them X) that I consider to be the causes of the latent state z and some part of the data (call them y) that I consider to be expressions (omissions) of that latent space. So the decoder neural net takes the latent state z and returns the loc and scale for the numerical omissions and the loc for the binary omissions (P(y|z)). The encoder neural net takes the causes X of the latent state and returns the loc and scale for the latent state z (P(z|x)). Bellow is the code for the encoder, decoder, model and guide. All this is influenced by the VAE example (Variational Autoencoders — Pyro Tutorials 1.8.4 documentation) and I train it just like in the example.
class Decoder(nn.Module):
def __init__(self, z_dim, hidden_dim, num_obs_dim, cat_obs_dim):
super().__init__()
self.fc1 = nn.Linear(z_dim, hidden_dim)
self.fc21 = nn.Linear(hidden_dim, num_obs_dim)
self.fc22 = nn.Linear(hidden_dim, num_obs_dim)
self.fc2cat = nn.Linear(hidden_dim, cat_obs_dim)
self.softplus = nn.Softplus()
self.sigmoid = nn.Sigmoid()
def forward(self, z):
hidden = self.softplus(self.fc1(z))
loc_cat = self.sigmoid(self.fc2cat(hidden))
loc_num = self.sigmoid(self.fc21(hidden))
scale_num = torch.exp(self.fc22(hidden))
return loc_cat, loc_num, scale_num
class Encoder(nn.Module):
def __init__(self, z_dim, hidden_dim, obs_dim):
super().__init__()
self.fc1 = nn.Linear(obs_dim, hidden_dim)
self.fc21 = nn.Linear(hidden_dim, z_dim)
self.fc22 = nn.Linear(hidden_dim, z_dim)
self.softplus = nn.Softplus()
def forward(self, x):
hidden = self.softplus(self.fc1(x))
z_loc = self.fc21(hidden)
z_scale = torch.exp(self.fc22(hidden))
return z_loc, z_scale
def model(data):
z_dim=1
y = data[:,-3:]
X = data[:,:-3]
pyro.module("decoder", decoder)
with pyro.plate("data", data.shape[0]):
z_loc = X.new_zeros(torch.Size((X.shape[0], z_dim)))
z_scale = X.new_ones(torch.Size((X.shape[0], z_dim)))
z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
loc_cat, loc_num, scale_num = decoder.forward(z)
cat_obs = pyro.sample("cat_obs", dist.Bernoulli(loc_cat).to_event(1), obs=y[:,-1])
num_obs = pyro.sample("num_obs", dist.LogNormal(loc_num, scale_num).to_event(1), obs=y[:,:-1])
return z, cat_obs, num_obs
def guide(data):
y = data[:,-3:]
X = data[:,:-3]
pyro.module("encoder", encoder)
with pyro.plate("data", X.shape[0]):
z_loc, z_scale = encoder(X)
pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
What I find strange is that when I run the predictive for a test set there is no variability in the predictions for the observed variables y and they totally agree with the ground truth, while there is variability for the latent z. Bellow is how I run the predictive. Am I misunderstanding something? Is there a bug in my code? Any help would be much appreciated.
predictive = pyro.infer.Predictive(model, guide=guide, num_samples=800)
svi_samples = predictive(test_set.dataset[test_set.indices])