Porting a Probabilistic NMF model from PyMC3 to NumPyro

I’m trying to port a model from PyMC3 but I’m not sure how to construct a guide. The model is a variant of Nonnegative Matrix Factorization Y=BQ where both B and Q are positive matrices. It differs from standard NMF because I do not observe the matrix Y directly, only a linear mapping of the columns of Y. I place a Dirichlet prior on the columns of B and a half-normal prior on Q. The PyMC3 code looks like this:

import numpy as np
import pymc3 as pm
import theano.tensor as tt

with pm.Model() as model:
    PositiveNormal = pm.Bound(pm.Normal, lower=0.0)
    BT = pm.Dirichlet("BT", a=0.8 * np.ones(npix), shape=(K, npix))
    QT = PositiveNormal("QT", mu=0.0, sigma=1e3, shape=(L, K))
    # Pixel basis -> Ylm basis
    BT_ylm = tt.dot(BT, PInvT)

    # NMF Y=BQ
    YT = tt.dot(QT, BT_ylm)

    # Map columns of Y to the observables
    f = tt.batched_dot(Apad, YT)    
    pm.Potential("obs", -0.5 * tt.sum((f - fpad) ** 2 * ivarpad))
    
    # Fit ADVI
    advi = pm.ADVI()
    res_pymc3 = advi.fit(n=30000, obj_optimizer=pm.adam(learning_rate=5e-03))

In this case, PyMC3 automatically transforms the Dirichlet and half-normal distributions to the real line, and I assume an independent Gaussian over the transformed space for the VI approximation. This works pretty well.

I’m not sure how to implement a guide on the transformed space in NumPyro, which PyMC3 does automatically. I tried doing the following:

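import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints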

def model():
    BT = numpyro.sample("BT", dist.TransformedDistribution(
        dist.Dirichlet(0.8*jnp.ones((K, npix))), 
        dist.transforms.StickBreakingTransform()
    ))
    QT = numpyro.sample("QT", dist.TransformedDistribution(
        dist.HalfNormal(scale=1e3*jnp.ones((L, K))), 
        dist.transforms.AbsTransform()
    ))
    
    # Pixel basis -> Ylm basis
    BT_ylm =  BT.dot(PInvT)

    # NMF Y=BQ
    YT = jnp.dot(QT, BT_ylm)
    
    # Map columns of Y to the observables
    f = jnp.einsum('ijk,ik->ij', Apad, YT)
    numpyro.sample("obs", dist.Normal(f.reshape(-1)[mask], ferrpad.reshape(-1)[mask]),
        obs=fpad.reshape(-1)[mask])
    
def guide():
    mu_BT_transf = numpyro.param("mu_BT_transf", np.ones((K, npix)))
    sd_BT_transf = numpyro.param("sd_BT_transf", 0.1*np.ones((K, npix)), constraint=constraints.positive)
    BT = numpyro.sample("BT", dist.Normal(mu_BT_transf, sd_BT_transf))
    
    mu_QT_transf = numpyro.param("mu_QT_transf", np.ones((L, K)))
    sd_QT_transf = numpyro.param("sd_QT_transf", 0.1*np.ones((L, K)), constraint=constraints.positive)
    QT = numpyro.sample("QT", dist.Normal(mu_QT_transf, sd_QT_transf))

but I have no idea whether this is the right approach, and I get shape errors when I run this model.

Hi @fbartolic, if you use autoguides such as AutoDiagonalNormal, they will automatically do the transform job for you.

I tried using the AutoDiagonalNormal guide, but for some reason the NumPyro model doesn’t seem to converge, despite the fact that the two models are exactly the same. I also use the same optimizer and learning rate and (I think) the same initial parameters for the VI distribution. The PyMC3 model converges easily even with the default initialization. Here’s a reproducible example in case you want to have a look @fehiepsi.

In general, I don’t want to use the autoguides because I want to impose more structure on the guide. I was trying out ADVI just to get a simple model working.

@fbartolic I think you can resolve the issue by either

  • using guide.sample_posterior(random.PRNGKey(2), params, sample_shape=(500,)) to get posterior samples, or

  • specifying params=params in Predictive.

NumPyro does not have a global param store as in Pyro so this is needed to tell Predictive which parameters are used for prediction.

I want to impose more structure on the guide

If you want to impose more structure, then you should apply transforms (if needed) such as StickBreakingTransform or AbsTransform in the guide, not in the model.

Yep, that’s right, thanks! It occurred to me that I wasn’t sampling the guide distribution properly.

If you want to impose more structure, then you should apply transforms (if needed) such as StickBreakingTransform or AbsTransform in the guide, not in the model.

I see, thank you!