No variability in predictions of latent variable model

Hi all. I’m building a latent variable model. I consider one part of the data (call it X) to be the causes of the latent state z, and another part (call it y) to be expressions (emissions) of that latent state. The decoder neural net takes the latent state z and returns the loc and scale for the numerical emissions and the loc for the binary emissions, i.e. P(y|z). The encoder neural net takes the causes X and returns the loc and scale for the latent state z, i.e. P(z|X). Below is the code for the encoder, decoder, model and guide. All of this is based on the VAE example (Variational Autoencoders, Pyro Tutorials 1.8.3 documentation), and I train it just like in the example.

class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, num_obs_dim, cat_obs_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, num_obs_dim)
        self.fc22 = nn.Linear(hidden_dim, num_obs_dim)
        self.fc2cat = nn.Linear(hidden_dim, cat_obs_dim)
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        loc_cat = self.sigmoid(self.fc2cat(hidden))
        loc_num = self.sigmoid(self.fc21(hidden))
        scale_num = torch.exp(self.fc22(hidden))
        return loc_cat, loc_num, scale_num

class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, obs_dim):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        self.softplus = nn.Softplus()

    def forward(self, x):
        hidden = self.softplus(self.fc1(x))
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale
    

def model(data):
    z_dim=1
    y = data[:,-3:]
    X = data[:,:-3]
    pyro.module("decoder", decoder)
    with pyro.plate("data", data.shape[0]):
        z_loc = X.new_zeros(torch.Size((X.shape[0], z_dim)))
        z_scale = X.new_ones(torch.Size((X.shape[0], z_dim)))
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        loc_cat, loc_num, scale_num = decoder.forward(z)
        # keep the trailing dimension so obs matches the event shape of loc_cat
        cat_obs = pyro.sample("cat_obs", dist.Bernoulli(loc_cat).to_event(1), obs=y[:,-1:])
        num_obs = pyro.sample("num_obs", dist.LogNormal(loc_num, scale_num).to_event(1), obs=y[:,:-1])
        return z, cat_obs, num_obs
    
def guide(data):
    y = data[:,-3:]
    X = data[:,:-3]
    pyro.module("encoder", encoder)
    with pyro.plate("data", X.shape[0]):
        z_loc, z_scale = encoder(X)
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

What I find strange is that when I run Predictive on a test set, there is no variability in the predictions for the observed variables y and they agree exactly with the ground truth, while there is variability for the latent z. Below is how I run Predictive. Am I misunderstanding something? Is there a bug in my code? Any help would be much appreciated.

predictive = pyro.infer.Predictive(model, guide=guide, num_samples=800)
svi_samples = predictive(test_set.dataset[test_set.indices])

You’re passing in data=test_set.dataset[test_set.indices], and in the model data[:,-3:] is observed. Observed means fixed and known, not subject to change, so Predictive will just spit data[:,-3:] back at you. Change your model to take X and y as separate arguments, and pass in your test inputs as X with y=None during prediction. That way y will be sampled.

Thanks a lot. Such a silly mistake. Everything works as it should now.