I’m trying to port a model from PyMC3 to NumPyro, but I’m not sure how to construct a guide. The model is a variant of nonnegative matrix factorization (NMF), Y = BQ, where both B and Q are positive matrices. It differs from standard NMF in that I do not observe the matrix Y directly, only a linear mapping applied to the columns of Y. I place a Dirichlet prior on the columns of B and a half-normal prior on Q. The PyMC3 code looks like this:
import numpy as np
import pymc3 as pm
import theano.tensor as tt

with pm.Model() as model:
    PositiveNormal = pm.Bound(pm.Normal, lower=0.0)
    BT = pm.Dirichlet("BT", a=0.8 * np.ones(npix), shape=(K, npix))
    QT = PositiveNormal("QT", mu=0.0, sigma=1e3, shape=(L, K))
    # Pixel basis -> Ylm basis
    BT_ylm = tt.dot(BT, PInvT)
    # NMF Y=BQ
    YT = tt.dot(QT, BT_ylm)
    # Map columns of Y to the observables
    f = tt.batched_dot(Apad, YT)
    pm.Potential("obs", -0.5 * tt.sum((f - fpad) ** 2 * ivarpad))
    # Fit ADVI
    advi = pm.ADVI()
    res_pymc3 = advi.fit(n=30000, obj_optimizer=pm.adam(learning_rate=5e-03))
In this case, PyMC3 automatically transforms the Dirichlet and half-normal variables to the real line and, I assume, uses an independent Gaussian over the transformed space as the VI approximation. This works pretty well.
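To make the idea of a Gaussian over the transformed space concrete, here is a minimal NumPy sketch of drawing one sample from such an approximation. It is only an illustration: the `stick_breaking` helper, the toy sizes `K`, `L`, `npix`, and the variational parameters are made up for this example, and PyMC3's internal transforms may differ in detail.

```python
import numpy as np

def stick_breaking(z):
    """Map unconstrained z of shape (..., n) to the simplex, shape (..., n + 1)."""
    n = z.shape[-1]
    # Offset chosen so that z = 0 maps to the uniform point on the simplex.
    offset = np.log(np.arange(n, 0, -1))
    v = 1.0 / (1.0 + np.exp(-(z - offset)))    # fraction of the remaining stick broken off
    remaining = np.cumprod(1.0 - v, axis=-1)   # stick left after each break
    before = np.concatenate(
        [np.ones(z.shape[:-1] + (1,)), remaining[..., :-1]], axis=-1
    )
    return np.concatenate([v * before, remaining[..., -1:]], axis=-1)

rng = np.random.default_rng(0)
K, L, npix = 2, 3, 5  # toy sizes (assumed for illustration)

# Mean-field Gaussian parameters, defined in unconstrained space
mu_B, sd_B = np.zeros((K, npix - 1)), 0.1 * np.ones((K, npix - 1))
mu_Q, sd_Q = np.zeros((L, K)), 0.1 * np.ones((L, K))

# Sample in unconstrained space, then map back to the constrained space
z_B = mu_B + sd_B * rng.standard_normal(mu_B.shape)
z_Q = mu_Q + sd_Q * rng.standard_normal(mu_Q.shape)
BT = stick_breaking(z_B)  # each row lies on the simplex
QT = np.exp(z_Q)          # strictly positive
```

Note that a simplex-valued row with `npix` entries only needs `npix - 1` unconstrained coordinates in this parameterization, which is one place where the shapes of a hand-written guide can go wrong.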
I’m not sure how to implement in NumPyro a guide on the transformed space, which PyMC3 sets up automatically. I tried the following:
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints

def model():
    BT = numpyro.sample("BT", dist.TransformedDistribution(
        dist.Dirichlet(0.8 * jnp.ones((K, npix))),
        dist.transforms.StickBreakingTransform()
    ))
    QT = numpyro.sample("QT", dist.TransformedDistribution(
        dist.HalfNormal(scale=1e3 * jnp.ones((L, K))),
        dist.transforms.AbsTransform()
    ))
    # Pixel basis -> Ylm basis
    BT_ylm = BT.dot(PInvT)
    # NMF Y=BQ
    YT = jnp.dot(QT, BT_ylm)
    # Map columns of Y to the observables
    f = jnp.einsum('ijk,ik->ij', Apad, YT)
    numpyro.sample("obs", dist.Normal(f.reshape(-1)[mask], ferrpad.reshape(-1)[mask]),
                   obs=fpad.reshape(-1)[mask])

def guide():
    mu_BT_transf = numpyro.param("mu_BT_transf", np.ones((K, npix)))
    sd_BT_transf = numpyro.param("sd_BT_transf", 0.1 * np.ones((K, npix)), constraint=constraints.positive)
    BT = numpyro.sample("BT", dist.Normal(mu_BT_transf, sd_BT_transf))
    mu_QT_transf = numpyro.param("mu_QT_transf", np.ones((L, K)))
    sd_QT_transf = numpyro.param("sd_QT_transf", 0.1 * np.ones((L, K)), constraint=constraints.positive)
    QT = numpyro.sample("QT", dist.Normal(mu_QT_transf, sd_QT_transf))
but I have no idea whether this is the right approach. I get shape errors when I run this model.