TraceEnum_ELBO

Hi all, I see that TraceEnum_ELBO has made its way into NumPyro. I’m curious whether it can handle a model we typically use.

Specifically, let \beta = \gamma \cdot b, where \gamma \sim \text{Multi}(1, \pi) (i.e., a one-hot encoding of a Categorical) and b \sim \text{Norm}(0, \sigma^2_b). We have the linear regression model y = X\beta + \epsilon (where X has shape (n, p)) and would like to compute an approximate posterior Q(\beta) = Q(\gamma) Q(b | \gamma). I’ve toyed around with using infer={"enumerate": "parallel"} for \gamma in the prior, but this seems to prohibit its inference in the approximate posterior (it’s marginalized out). Is it possible to define a guide using enumeration in this setting that results in p conditional parameters for b given \gamma (i.e., a posterior mean and variance for each possible selected column of X)?
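
For concreteness, the generative model described above can be sketched in plain NumPy. This is only an illustration under assumed settings: the sizes `n=400`, `p=20`, the noise scale `sigma_e=0.9`, and the helper name `simulate` are hypothetical choices, and `sigma_b` is treated as a standard deviation.

```python
import numpy as np


def simulate(n=400, p=20, sigma_b=1.0, sigma_e=0.9, seed=0):
    """Draw one dataset from y = X @ (gamma * b) + eps (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, p))           # design matrix, shape (n, p)
    pi = np.ones(p) / p                   # uniform prior over which column is selected
    gamma = rng.multinomial(1, pi)        # one-hot selection vector, shape (p,)
    b = rng.normal(0.0, sigma_b)          # scalar effect size
    beta = gamma * b                      # exactly one nonzero coefficient
    y = X @ beta + rng.normal(0.0, sigma_e, size=n)
    return X, y, gamma, b, beta


X, y, gamma, b, beta = simulate()
print(X.shape, y.shape, int(gamma.argmax()))
```

Only one entry of `beta` is nonzero, so the posterior of interest has `p` possible "selected column" configurations, each with its own conditional distribution for `b`.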

Not sure if I understand your question correctly, but you should be able to use enumeration in the guide. Can you show the code you have?

Sure thing, here is some pseudocode that captures most of what I’d like to model.

def model(X, y):
    n_dim, p_dim = X.shape

    pi = jnp.ones(p_dim) / float(p_dim)
    sigma_b = 1e-3

    gamma = npro.sample("gamma", dist.Categorical(pi))
    b = npro.sample("b", dist.Normal(0., sigma_b))

    sigma_e = npro.param("sigma_e", 0.9, constraint=constraints.positive)
    with npro.plate("N", n_dim):
        g = X[:, gamma] * b
        npro.sample("y", dist.Normal(g, sigma_e), obs=y)

def guide(X, y):
    n_dim, p_dim = X.shape
    
    # npro.param reads the keyword `constraint` (singular)
    alpha = npro.param("alpha", jnp.ones(p_dim) / float(p_dim), constraint=constraints.simplex)
    gamma = npro.sample("gamma", dist.Categorical(alpha), infer={"enumerate": "parallel"})

    post_sigma_b = npro.param("post_sigma_b", jnp.ones(p_dim) * 1e-3, constraint=constraints.positive)
    post_mu_b = npro.param("post_mu_b", jnp.zeros(p_dim))
    b = npro.sample("b", dist.Normal(Vindex(post_mu_b)[..., gamma], Vindex(post_sigma_b)[..., gamma]))

It all looks good to me. Are there any issues with inference using TraceEnum_ELBO?

Thanks @ordabayev. For one, NumPyro is not happy about the different shapes of the prior for b (i.e., shape=()) and the approximate conditional posterior for b (i.e., shape=(p,)). Fixing that by using a flat, shared prior over b still results in nonsensical inference. For example,

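import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.contrib.indexing import Vindex
from numpyro.distributions import constraints
from numpyro.infer import SVI, TraceEnum_ELBO


def model(X, y, l_dim=1):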
    n_dim, p_dim = X.shape

    pi = jnp.ones(p_dim) / float(p_dim)
    sigma_b = jnp.ones(p_dim) * 1e-3

    gamma = numpyro.sample("gamma", dist.Categorical(pi))
    b = numpyro.sample("b", dist.Normal(0.0, sigma_b[..., gamma]))

    sigma_e = numpyro.param("sigma_e", 0.9, constraint=constraints.positive)
    with numpyro.plate("N", n_dim):
        g = Vindex(X)[:, gamma] * b
        numpyro.sample("y", dist.Normal(g, sigma_e), obs=y)

    return


def guide(X, y, l_dim=1):
    n_dim, p_dim = X.shape

    alpha = numpyro.param(
        "alpha", jnp.ones(p_dim) / float(p_dim), constraints=constraints.simplex
    )
    gamma = numpyro.sample(
        "gamma", dist.Categorical(alpha), infer={"enumerate": "parallel"}
    )

    post_sigma_b = numpyro.param(
        "post_sigma_b", jnp.ones(p_dim) * 1e-3, constraints=constraints.positive
    )
    post_mu_b = numpyro.param("post_mu_b", jnp.zeros(p_dim))
    b = numpyro.sample(
        "b",
        dist.Normal(Vindex(post_mu_b)[..., gamma], Vindex(post_sigma_b)[..., gamma]),
    )
    return

#[...]
adam = optim.Adam(step_size=0.005)
svi = SVI(model, guide, adam, TraceEnum_ELBO(max_plate_nesting=10))

results = svi.run(
        rng_key_run,
        args.epochs,
        X=X,
        y=y,
        l_dim=args.l_dim,
        progress_bar=True,
        stable_update=True,
)

Here is some summarized output showing nan due to incorrect posterior std inference:

Param Sites:
      sigma_e
Sample Sites:
   gamma dist     |
        value     |
     log_prob     |
       b dist     |
        value     |
     log_prob     |
      N plate 400 |
       y dist 400 |
        value 400 |
     log_prob 400 |

results.losses = Array([15483.41834474,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan], dtype=float64, weak_type=True)

results.params["post_sigma_b"] = Array([-0.004, -0.004,  0.006, -0.004,  0.006, -0.004,  0.006,  0.006,
        0.006, -0.004, -0.004,  0.006, -0.004, -0.004,  0.006, -0.004,
       -0.004, -0.004,  0.006,  0.006], dtype=float64)
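
A side observation on the numbers above: every entry of post_sigma_b equals either 1e-3 - 0.005 or 1e-3 + 0.005, which is what a single Adam step with step_size=0.005 does to an unconstrained parameter initialized at 1e-3. That suggests the positivity constraint was not in effect. One possible cause, offered as a guess and not confirmed in the thread, is that the guide passes `constraints=` while the model's `sigma_e` uses `constraint=`. The sketch below only reproduces the arithmetic of the first Adam update in NumPy; the gradient values are hypothetical, and the defaults `b1=0.9`, `b2=0.999`, `eps=1e-8` are assumed.

```python
import numpy as np


def adam_first_step(param, grad, step_size=0.005, b1=0.9, b2=0.999, eps=1e-8):
    """First Adam update starting from zero moment estimates."""
    m = (1 - b1) * grad                   # first moment after one step
    v = (1 - b2) * grad**2                # second moment after one step
    m_hat = m / (1 - b1)                  # bias correction gives back grad
    v_hat = v / (1 - b2)                  # bias correction gives back grad**2
    return param - step_size * m_hat / (np.sqrt(v_hat) + eps)


init = np.full(4, 1e-3)                        # same initial value as post_sigma_b
grad = np.array([3.0, -0.2, 50.0, -1e-2])      # hypothetical gradients of any magnitude
updated = adam_first_step(init, grad)
print(np.round(updated, 3))
```

Whatever the gradient magnitude, the first step moves each entry by roughly step_size in the direction opposite to the gradient's sign, so an unconstrained scale initialized below step_size can become negative immediately and produce nan in the next loss evaluation.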

I’ve uploaded the entire simulation and sample code here.

I didn’t look at your model in great detail, but one thing you might try is to initialize post_sigma_b to a smaller value, e.g.

post_sigma_b = numpyro.param(
    "post_sigma_b", jnp.ones(p_dim) * 1e-5, constraints=constraints.positive
)

Thanks @martinjankowiak. I ended up manually converting probabilities to logits and stds to log_std, which got rid of the nans during inference. However, the fitted posterior still does not look like a sensible result.
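
The manual reparameterization mentioned here can be sketched as follows. This is an illustrative NumPy version rather than the code from the uploaded example: the names `alpha_logits` and `log_post_sigma_b` are hypothetical, and the random perturbation stands in for arbitrary optimizer updates.

```python
import numpy as np


def softmax(logits):
    """Map unconstrained logits to a probability vector."""
    z = logits - logits.max()             # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()


p_dim = 20
rng = np.random.default_rng(0)

# Unconstrained parameters that an optimizer can move freely.
alpha_logits = np.zeros(p_dim)                    # uniform probabilities at init
log_post_sigma_b = np.full(p_dim, np.log(1e-3))   # std of 1e-3 at init

# Stand-in for arbitrary gradient updates, including large ones.
alpha_logits = alpha_logits + rng.normal(scale=5.0, size=p_dim)
log_post_sigma_b = log_post_sigma_b + rng.normal(scale=5.0, size=p_dim)

# Constrained values used inside the guide.
alpha = softmax(alpha_logits)             # stays on the simplex
post_sigma_b = np.exp(log_post_sigma_b)   # stays strictly positive
print(alpha.sum(), post_sigma_b.min() > 0)
```

Because the optimizer only ever touches the unconstrained values, no update can push the standard deviations negative or the probabilities off the simplex, which is consistent with the nans disappearing.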

It would be great if someone could check my understanding of what is going on. Is it the case that infer={"enumerate": "parallel"} in the guide results in computing Q(b) = \sum_j Q(b | \gamma_j = 1), so that the model sees samples of b drawn from this averaged guide distribution, with E_Q[b] = \sum_j E_Q[b | \gamma_j = 1]?

The posterior mean I’d like for Q(\beta=b \cdot \gamma) would be E_Q[\gamma] \cdot [E_Q[b | \gamma_1 = 1], \dotsc, E_Q[b | \gamma_p = 1]].

Since gamma is enumerated, b is sampled separately for each enumerated value of gamma (it is not sampled from the averaged distribution).
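
What "sampled for each enumerated value" means for shapes can be pictured with plain NumPy indexing. This is an analogue only: the exact batch shapes inside NumPyro depend on max_plate_nesting, and the sizes and parameter values below are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 5, 3
X = rng.normal(size=(n, p))
post_mu_b = np.array([-1.0, 0.0, 2.0])
post_sigma_b = np.array([0.1, 0.2, 0.3])

# Enumeration places all p values of gamma along a new leading dimension.
# Shape (p, 1) lets it broadcast against the data plate of size n.
gamma = np.arange(p).reshape(p, 1)

# Indexing the guide parameters with gamma gives one (mean, std) pair per value.
mu = post_mu_b[gamma]                     # shape (p, 1)
sigma = post_sigma_b[gamma]               # shape (p, 1)
b = rng.normal(mu, sigma)                 # one draw of b per enumerated gamma

# Row j of g is the predictor obtained when column j of X is selected.
g = X[np.arange(n), gamma] * b            # shape (p, n)
print(b.shape, g.shape)
```

Each row corresponds to one hypothesis about which column is selected, with its own draw of b, so the p conditional distributions are kept separate rather than mixed into a single one.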

I don’t think that the expectation for beta could be factorized like that. You probably have to do it in general form:

E_q(gamma)q(b|gamma)[b * gamma] ~ SUM_j q(gamma_j) * 1/N * SUM_i [b_ij * gamma_j]

where gamma_j are the enumerated values of gamma, q(gamma_j) is the guide probability of the j-th value, b_ij are Monte Carlo samples from q(b|gamma_j), and i ranges from 1 to N.
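
A minimal NumPy sketch of this estimate follows. It assumes each enumerated value gamma_j is weighted by its guide probability q(gamma_j) (called `alpha` here), and all parameter values are hypothetical stand-ins for fitted guide parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
p, N = 4, 200_000

# Hypothetical fitted guide parameters.
alpha = np.array([0.1, 0.2, 0.3, 0.4])            # q(gamma_j)
post_mu_b = np.array([1.0, -2.0, 0.5, 3.0])       # mean of q(b | gamma_j)
post_sigma_b = np.array([0.5, 0.1, 1.0, 0.2])     # std of q(b | gamma_j)

gamma_enum = np.eye(p)                            # row j is the one-hot gamma_j
b = rng.normal(post_mu_b, post_sigma_b, size=(N, p))  # b[i, j] ~ q(b | gamma_j)

# Sum over enumerated gamma_j, weighted by q(gamma_j), of the Monte Carlo
# average of b_ij times the one-hot vector gamma_j.
b_bar = b.mean(axis=0)                            # (1/N) * sum_i b_ij, shape (p,)
beta_mc = (alpha[:, None] * b_bar[:, None] * gamma_enum).sum(axis=0)

# Closed form for this Gaussian guide, for comparison.
beta_exact = alpha * post_mu_b
print(np.round(beta_mc, 3), beta_exact)
```

Because gamma is one-hot, coordinate j of the estimate only receives contributions from samples drawn under gamma_j, so for this guide it converges to alpha_j times the conditional mean of b given gamma_j.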