Discrete Latent Variable and Relaxation Question

I am still getting the hang of Pyro and SVI in general, and of inference with discrete latent variables in particular. I spent some time exploring enumeration [1-2], but struggled to grasp the dimensionality of enumerated variables. In particular, after reading [2] I still could not work out how I would know what to set as first_available_dim. On top of that, @martinjankowiak’s dire comment at [3] largely put me off this approach. Specifically:

unfortunately for all practical purposes i think it’s essentially impossible to get SVI to work in these kinds of models, especially if you want to push to large p

So I am looking at a relaxed Bernoulli [4] instead. My biggest question is simply whether I am missing something, as this seems to work too well. The model:

import os
import torch
from torch.distributions import constraints
import numpy as np
import pandas as pd

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive, TraceEnum_ELBO, config_enumerate, infer_discrete

def model(predictors, h_pred, pheno):
    n_predictors = predictors.shape[1]
    gamma_0 = pyro.sample("g0", dist.Normal(0., 1.))
    covs = pyro.sample("covs", dist.Normal(torch.zeros(n_predictors), torch.ones(n_predictors)).to_event(1))
    pi = pyro.sample("pi", dist.Beta(2., 2.))
    z = pyro.sample("z", dist.RelaxedBernoulliStraightThrough(torch.tensor(1.0), probs=pi))
    l = (gamma_0 + torch.matmul(predictors, covs.squeeze()) + h_pred * z).squeeze(-1)
    with pyro.plate("pheno", predictors.shape[0]):
        p = pyro.sample("obs", dist.Bernoulli(logits=l), obs=pheno)
    return z,pi

def guide(predictors, h_pred, pheno):
    n_predictors = predictors.shape[1]
    g0_loc = pyro.param("g0_loc", torch.tensor(0.))
    g0_scale = pyro.param("g0_scale", torch.tensor(1.), constraint=constraints.positive)
    gamma_0 = pyro.sample("g0", dist.Normal(g0_loc, g0_scale))
    covs_loc = pyro.param("covs_loc", torch.zeros(n_predictors))
    covs = pyro.sample("covs", dist.Normal(covs_loc, 1.).to_event(1))  # to_event replaces the deprecated independent
    alpha = pyro.param("alpha", torch.tensor(2.), constraint=constraints.positive)
    beta = pyro.param("beta", torch.tensor(2.), constraint=constraints.positive)
    pi = pyro.sample("pi", dist.Beta(alpha, beta))
    z = pyro.sample("z", dist.RelaxedBernoulliStraightThrough(temperature=torch.tensor(1.0), probs=pi))
    l = (gamma_0 + torch.matmul(predictors, covs.squeeze()) + h_pred * z).squeeze(-1)
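
To spell out what I think the straight-through relaxation at the "z" site is doing, here is a minimal sketch in plain PyTorch. It is my own reconstruction of the trick, not Pyro's actual implementation, and the names soft, hard, and z_st are only for illustration. The forward value is a hard 0/1 sample, while the gradient flows through the underlying soft (relaxed) sample:

```python
import torch
from torch.distributions import RelaxedBernoulli

torch.manual_seed(0)

probs = torch.tensor(0.3, requires_grad=True)
temperature = torch.tensor(1.0)

# Soft samples in (0, 1), reparameterized so they are differentiable w.r.t. probs.
soft = RelaxedBernoulli(temperature, probs=probs).rsample(torch.Size([5]))

# Straight-through trick (illustrative): the forward value is the rounded sample,
# but the detach() makes the backward pass see only the soft sample.
hard = soft.round()
z_st = soft + (hard - soft).detach()

z_st.sum().backward()
print(z_st)        # entries are 0. or 1.
print(probs.grad)  # gradient reaches probs through the soft sample
```

If that picture is right, z enters the linear predictor as an exact 0 or 1, which may be part of why the results look so much like the unrelaxed model.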


So z is a latent Bernoulli variable governed by the parameter pi. As near as I can tell, I get sensible results from using the relaxed Bernoulli directly, as if I were working with an unrelaxed one:

predictors = torch.rand([100, 3])
coefs = torch.tensor([1., -5, 0.2])
hidden_feat = np.zeros(100)
hidden_feat[51:] = 5 ## i.e. hidden_feat has a large impact on the outcome variable
hidden_feat = torch.tensor(hidden_feat)
pheno = torch.bernoulli(torch.sigmoid(0.2 + torch.matmul(predictors, coefs) + hidden_feat))

num_steps = 10000
lr0 = 0.01
gamma = 0.1
lrd =  gamma ** (1/num_steps)
optim = pyro.optim.ClippedAdam({'lr': lr0, 'lrd': lrd})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

for i in range(num_steps):
    loss = svi.step(predictors, hidden_feat, pheno)
    if i % 1000 == 0:
        # Report the loss (negative ELBO estimate) every 1000 steps
        print(f"step {i}: loss = {loss:.2f}")

predictive = Predictive(model, guide=guide, num_samples=1000,
                        return_sites=['g0', 'covs', 'pi', 'z'])
samples = predictive(predictors, hidden_feat, pheno)
np.mean(samples['covs'].numpy(), axis=0)
np.mean(samples['pi'].numpy(), axis=0)
np.mean(samples['z'].numpy(), axis=0)
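
As a sanity check on the learning-rate schedule used above, here is the arithmetic behind lrd. It assumes that ClippedAdam multiplies the learning rate by lrd once per optimizer step, so that after num_steps steps the rate has decayed from lr0 to roughly lr0 * gamma:

```python
num_steps = 10000
lr0 = 0.01
gamma = 0.1

# Per-step multiplicative decay factor, as in the training code above.
lrd = gamma ** (1 / num_steps)

# Assumption: the learning rate is multiplied by lrd once per step.
final_lr = lr0 * lrd ** num_steps
print(final_lr)  # approximately lr0 * gamma
```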

I just want to be sure I am using this distribution correctly in both the model and the guide. If there are better ways to attack this toy problem, I am open to them, and to further education on enumeration, which I am still working to understand.

[1] Inference with Discrete Latent Variables — Pyro Tutorials 1.8.4 documentation
[2] Tensor shapes in Pyro — Pyro Tutorials 1.8.4 documentation
[3] Discrete Inference with TraceGraph_ELBO - #5 by nmancuso
[4] Distributions — Pyro documentation