Conditional reparameterization

I would like to apply a conditional reparameterization to my model, e.g., to add an offset to the sample a conditional on another sample z, as discussed in Add explicit reparametrizer. by tillahoffmann · Pull Request #1754 · pyro-ppl/numpyro · GitHub. I’ve reproduced the relevant code below. (edited to fix variable names)

import numpyro
import numpyro.infer.reparam
from numpyro import distributions as dists


def model():
    # Sample the conditioning variable and a learnable coefficient for the offset.
    z = numpyro.sample("z", dists.Normal())
    alpha = numpyro.param("alpha", 0.)
    # Shift the base distribution of `a` by an offset that depends on `z`.
    transform = dists.transforms.AffineTransform(alpha * z, 1)
    reparam = numpyro.infer.reparam.ExplicitReparam(transform)
    with numpyro.plate("n", 10), numpyro.handlers.reparam(config={"a": reparam}):
        a = numpyro.sample("a", dists.Normal())
    # observation model ...

This of course strongly couples the model definition and inference such that I have to update the model every time I want to try a different reparameterization. Is it possible to apply such conditional reparameterizations without integrating them into the model? I couldn’t see anything obvious in the Reparam API (numpyro/numpyro/infer/reparam.py at f478772b6abee06b7bfd38c11b3f832e01e089f5 · pyro-ppl/numpyro · GitHub). Passing a trace object to a custom Reparam implementation might do the trick, because the trace is updated with each sample statement:

>>> with numpyro.handlers.seed(rng_seed=42), numpyro.handlers.trace() as trace:
...     print(trace.keys())
...     numpyro.sample("a", numpyro.distributions.Normal())
...     print(trace.keys())
...     numpyro.sample("b", numpyro.distributions.Normal())
...     print(trace.keys())
odict_keys([])
odict_keys(['a'])
odict_keys(['a', 'b'])
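
To make that idea concrete, here is a rough sketch of what I have in mind: a hypothetical Reparam subclass (not part of numpyro) that receives the externally created trace dict and computes an additive offset from sites that have already executed. It ignores plate and event-shape handling and is only meant to illustrate the mechanism, not to be a finished implementation.

import numpyro
from numpyro import distributions as dists
from numpyro.infer.reparam import Reparam


class OffsetFromTraceReparam(Reparam):
    """Hypothetical reparametrizer: shifts the latent for a site by an offset
    computed from sites already recorded in an externally supplied trace."""

    def __init__(self, trace, offset_fn):
        self.trace = trace          # dict that fills up as the model executes
        self.offset_fn = offset_fn  # maps the trace to an additive offset

    def __call__(self, name, fn, obs):
        assert obs is None
        offset = self.offset_fn(self.trace)
        transform = dists.transforms.AffineTransform(offset, 1.0)
        # Latent the sampler/guide actually sees: the original distribution
        # pushed through the inverse transform (i.e., shifted by -offset).
        shifted = dists.TransformedDistribution(fn, transform.inv)
        x = numpyro.sample(f"{name}_shifted", shifted)
        # Map back so the site keeps its original distribution.
        return None, transform(x)


def toy_model():
    numpyro.sample("mu", dists.Normal(0.0, 100.0))
    numpyro.sample("a", dists.Normal(0.0, 1.0))


with numpyro.handlers.seed(rng_seed=42), numpyro.handlers.trace() as tr:
    # The offset for `a` is read off the trace entry for `mu` at runtime.
    config = {"a": OffsetFromTraceReparam(tr, lambda t: t["mu"]["value"])}
    with numpyro.handlers.reparam(config=config):
        toy_model()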

I thought I’d get your perspective first before going down the wrong rabbit hole. Thank you for your input!

This of course strongly couples the model definition and inference such that I have to update the model every time I want to try a different reparameterization.

Could you clarify? By the way, it seems that b is missing in your model.

I’ve updated the variable names above, sorry.

I may also have fallen victim to the XY problem. So here’s the actual problem.

Motivation

I’m building a tensor factorization model for an order-three tensor y such that

y_{ijt}\sim\mathsf{Normal}\left(\mu+a_i+b_j+z_t+A_{it}+B_{jt}+C_{ij},\sigma^2_{ij}\right)

with shrinkage priors on a, b, and C; Kalman-style priors on z and on each row of A and B; and a wide normal prior on \mu.
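
For reference, a stripped-down sketch of this structure in numpyro might look like the following. It is purely illustrative: the shrinkage priors are stand-in half-Cauchy scale hierarchies and the Kalman-style priors are plain Gaussian random walks, which is not exactly what I’m using.

import numpyro
from numpyro import distributions as dists


def factorization_model(y, n_i, n_j, n_t):
    # Grand mean with a wide prior.
    mu = numpyro.sample("mu", dists.Normal(0.0, 100.0))
    # Main effects with shrinkage-style scales.
    tau_a = numpyro.sample("tau_a", dists.HalfCauchy())
    tau_b = numpyro.sample("tau_b", dists.HalfCauchy())
    a = numpyro.sample("a", dists.Normal(0.0, tau_a).expand([n_i]).to_event(1))
    b = numpyro.sample("b", dists.Normal(0.0, tau_b).expand([n_j]).to_event(1))
    # Time effect and row-wise time effects as random walks ("Kalman-style").
    z = numpyro.sample("z", dists.GaussianRandomWalk(num_steps=n_t))
    A = numpyro.sample("A", dists.GaussianRandomWalk(num_steps=n_t).expand([n_i]).to_event(1))
    B = numpyro.sample("B", dists.GaussianRandomWalk(num_steps=n_t).expand([n_j]).to_event(1))
    # Interaction term with its own shrinkage-style scale.
    tau_C = numpyro.sample("tau_C", dists.HalfCauchy())
    C = numpyro.sample("C", dists.Normal(0.0, tau_C).expand([n_i, n_j]).to_event(2))
    # Cell-wise observation scales.
    sigma = numpyro.sample("sigma", dists.HalfCauchy().expand([n_i, n_j]).to_event(2))
    # Additive decomposition of the mean, broadcast to shape (n_i, n_j, n_t).
    y_hat = (
        mu
        + a[:, None, None] + b[None, :, None] + z[None, None, :]
        + A[:, None, :] + B[None, :, :] + C[:, :, None]
    )
    numpyro.sample("y", dists.Normal(y_hat, sigma[:, :, None]).to_event(3), obs=y)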

If the observation variance \sigma^2 is relatively small, i.e., my data are informative, then each contribution to \hat y = \mu+a_i+\ldots is anti-correlated with the others under the posterior. That’s because adding a constant \delta to \mu and subtracting the same constant from all elements of a leaves the likelihood unchanged. The parameters are only (weakly) identified by the priors.
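
Concretely, for any constant \delta,

\mu+a_i+b_j+\ldots=\left(\mu+\delta\right)+\left(a_i-\delta\right)+b_j+\ldots,

so the likelihood assigns exactly the same density to both parameter settings.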

That’s problematic for sampling the posterior because the sampler struggles to explore the whole posterior with a diagonal mass matrix. It’s also problematic for variational inference using an AutoDiagonalNormal guide because the posterior is much too narrow due to the correlation (an AutoLowRankMultivariateNormal helps a bit but doesn’t address the next challenge).

Even if we don’t care about getting the posterior width wrong, the correlation leads to very slow convergence of the variational parameters. That’s because the gradient of the sampled ELBO is dominated by the “high-dimensional valley” of the likelihood. For example, suppose you sample parameter values \theta=\left(\mu,a,b,\ldots\right) such that \hat y=\mu+\ldots is a little different from the data y. Then you get a really strong gradient pushing you back down into the valley but only a small gradient contribution from the priors. If we’re using something like the Adam optimizer with learning rate \eta, the steps taken are

\theta\rightarrow\theta-\eta\frac{\text{exponential moving first moment of gradients}}{\sqrt{\text{exponential moving second moment of gradients}}}.

The denominator is large (because we’re bouncing around the valley and picking up all that variance) but the first moment is small in magnitude because the only thing that’s pushing us to the solution is the weak prior gradients (subject to satisfying the valley constraint).

Whenever this happens, it’s also impossible to use the ELBO to diagnose convergence because the noise from sampling the likelihood completely dominates any improvement we get from getting closer to satisfying the priors. I’ve got a little two-parameter toy problem I can share if of interest.
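
For illustration, a minimal setup with this pathology could look something like the sketch below (just the kind of model I mean, not necessarily the exact toy problem): the likelihood only pins down \mu+a, so an AutoDiagonalNormal guide has to crawl along the valley driven by the weak prior gradients.

import jax
import numpyro
from numpyro import distributions as dists
from numpyro import optim
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal


def valley_model(y=None):
    # Only the sum mu + a is identified by the (very informative) likelihood;
    # the individual parameters are identified only by their priors.
    mu = numpyro.sample("mu", dists.Normal(0.0, 10.0))
    a = numpyro.sample("a", dists.Normal(0.0, 1.0))
    numpyro.sample("y", dists.Normal(mu + a, 0.01), obs=y)


guide = AutoDiagonalNormal(valley_model)
svi = SVI(valley_model, guide, optim.Adam(0.01), Trace_ELBO())
# The per-step ELBO estimates are dominated by likelihood noise, which makes
# convergence hard to diagnose from the loss trace alone.
result = svi.run(jax.random.PRNGKey(0), 5_000, y=1.0)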

A possible solution

To address both the underestimated posterior variance and the slow convergence, I thought I’d reparametrize my model by sampling each of the contributions with an offset conditional on the previously sampled contributions. E.g., I might have something like this.

from jax import numpy as jnp
import numpyro
import numpyro.infer.reparam


def model(n_a, ...):
    # Sample the grand mean `mu` and scale `sigma_a` of the effects `a`.
    mu = numpyro.sample("mu", numpyro.distributions.Normal(0, 100))
    sigma_a = numpyro.sample("sigma_a", numpyro.distributions.HalfCauchy())
    # Construct an explicit reparameterization for `a` as an affine transform of a
    # base distribution with a learnable offset that depends on the grand mean `mu`.
    config = {"a": numpyro.infer.reparam.ExplicitReparam(
        numpyro.distributions.transforms.AffineTransform(
            numpyro.param("xi_a_mu", jnp.zeros(n_a)) * mu,
            1,
        )
    )}
    # Apply the reparameterization and sample elements of `a` iid using a plate.
    with (
        numpyro.handlers.reparam(config=config),
        numpyro.plate("n_a", n_a),
    ):
        a = numpyro.sample("a", numpyro.distributions.Normal(0, sigma_a))
    ...

Now the variational parameters of the reparameterized distribution for a only encode deviations from \xi_{a\mid\mu} \times \mu. We initialize with \xi_{a\mid\mu}=0 so that we start with the standard setup. This seems to work for the two-parameter toy model I mentioned, addressing both the slow convergence and the underestimated uncertainty.
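
In other words, the site is rewritten as

a = \xi_{a\mid\mu}\,\mu + a_{\text{base}},

and the guide only has to learn the distribution of the residual a_{\text{base}} rather than track the anti-correlation with \mu directly.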

However, as is evident from the code snippet above, implementing this for the whole tensor factorization model is very verbose. Further, if I want to experiment with a different parameterization, I have to change the model (rather than just the reparameterization config). It “feels like” I should be able to decouple the two to experiment more easily with different reparameterization configurations.

But maybe I’m approaching this problem in a weird way? Thank you for reading the essay above!