Variational inference for matrix normal likelihood with Gaussian mixture prior

Hi all,

I’m working on a model where the likelihood is a matrix normal distribution, and the prior over each edge of the graph is a spike-and-slab-inspired Gaussian mixture. The spike component is a narrow Gaussian centered at zero, and the slab is itself a mixture of wider Gaussians with several different variances.
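
For concreteness, the setup can be written roughly as below. The notation is assumed here for illustration only: \pi_0 is the spike weight, \sigma_0 the spike scale, w_k and \sigma_k the slab weights and scales, and the slab components are taken to be zero-mean. How G enters the matrix normal (through M, U, or V) is left unspecified.

```latex
% Per-edge prior: a narrow spike plus a K-component slab mixture (zero means assumed).
p(G_{ij}) = \pi_0 \, \mathcal{N}(G_{ij} \mid 0, \sigma_0^2)
          + (1 - \pi_0) \sum_{k=1}^{K} w_k \, \mathcal{N}(G_{ij} \mid 0, \sigma_k^2),
\qquad \sum_{k=1}^{K} w_k = 1, \quad \sigma_0 \ll \sigma_k .

% Matrix normal density for an n x p matrix X with mean M,
% row covariance U (n x n) and column covariance V (p x p).
\mathcal{MN}(X \mid M, U, V) =
\frac{\exp\!\left( -\tfrac{1}{2} \operatorname{tr}\!\left[ V^{-1} (X - M)^{\top} U^{-1} (X - M) \right] \right)}
     {(2\pi)^{np/2} \, |V|^{n/2} \, |U|^{p/2}}
```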

I want to perform variational inference to approximate the posterior over the graph matrix G. Since the prior is a non-conjugate Gaussian mixture and the likelihood is matrix normal, I’m wondering:

  • Do I need to derive a custom variational family for this setup, or is there something in Pyro/NumPyro that would work out of the box?
  • Is it okay to use a default like a mean-field Gaussian, or would that be too simple given the structure of the matrix normal?
  • Would black-box variational inference with Monte Carlo estimation of the ELBO work well here?
  • Has anyone tried using a matrix-variate variational distribution (like a matrix normal) instead of a factorized normal?
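
To make the black-box option concrete, here is a minimal sketch of a reparameterized Monte Carlo ELBO estimate for a single edge with a Gaussian variational factor q(g) = N(m, s^2). It is a toy, not the actual model: the scalar Gaussian likelihood stands in for the matrix normal, and the mixture weights, scales, observation `y`, and `noise_sd` are made-up values.

```python
import math
import random

# Hypothetical prior settings for one edge: (weight, standard deviation).
# The first entry is the narrow spike, the others are slab components.
PRIOR = [(0.5, 0.05), (0.3, 1.0), (0.2, 3.0)]


def log_normal(x, mean, sd):
    """Log density of N(mean, sd^2) evaluated at x."""
    z = (x - mean) / sd
    return -0.5 * z * z - math.log(sd) - 0.5 * math.log(2.0 * math.pi)


def log_prior(g):
    """Log density of the zero-mean Gaussian mixture, computed with log-sum-exp."""
    terms = [math.log(w) + log_normal(g, 0.0, sd) for w, sd in PRIOR]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def elbo_estimate(m, s, y, noise_sd, num_samples=2000, seed=0):
    """Monte Carlo ELBO for q(g) = N(m, s^2) using reparameterized draws."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        g = m + s * rng.gauss(0.0, 1.0)  # reparameterization: g = m + s * eps
        log_joint = log_normal(y, g, noise_sd) + log_prior(g)  # toy likelihood + prior
        total += log_joint - log_normal(g, m, s)  # subtract log q(g)
    return total / num_samples


# A factor placed near the data should score a higher ELBO than one placed far away.
elbo_near = elbo_estimate(m=1.9, s=0.5, y=2.0, noise_sd=0.5)
elbo_far = elbo_estimate(m=-5.0, s=0.5, y=2.0, noise_sd=0.5)
print(elbo_near, elbo_far)
```

In NumPyro, the off-the-shelf counterpart of this estimator would be `SVI` with `Trace_ELBO` and an autoguide such as `AutoNormal` (mean-field) or `AutoLowRankMultivariateNormal`; whether either is expressive enough for this posterior is exactly the open question.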

I already have the code working with NUTS, but it’s slow. I’m mainly looking for a faster and scalable variational alternative. Would love guidance on how to proceed or any examples if something similar has been done.

Thanks!

hard to say without further details, but given the spike-and-slab-like prior it sounds like you probably have a pretty multi-modal posterior. generally speaking, even if you work really hard to define a custom variational family, variational inference doesn’t do an amazing job of fully capturing such complex posteriors (note that mcmc can also struggle here, especially in high dimensions). with the right family and well-tuned optimization, you might hope to capture some of the main modes and/or the principal distributional characteristics of the main mode, but you’re unlikely to capture the nuances of every last mode.

so a lot of it depends on whether you’re ok with that. maybe if you’re just doing prediction with inferred model parameters, you don’t care that much about the details of the posterior. on the other hand, if you want to make strong statements about some scientific hypothesis being or not being supported by the data based on the details of your approximate posterior, you probably care more. it really depends.

Thank you. I guess the goal is to find the major mode and maybe the second and third highest modes. Surprisingly, NUTS works pretty well: in 100 dimensions it takes less than 30 minutes. I also thought about using a mode-seeking algorithm (Pathfinder) to initialize NUTS. Is that reasonable? Or parallel tempering? Still, it would be nice to have an analytical VI scheme, which I am still struggling with.
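
For reference, the parallel tempering idea can be sketched on a toy problem. This is only an illustration under assumed settings: the target is a made-up 1D bimodal density rather than the matrix normal model, the within-chain move is random-walk Metropolis rather than NUTS, and the inverse-temperature ladder `betas` is hand-picked.

```python
import math
import random


def log_target(x):
    """Toy bimodal log density: equal mix of N(-3, 0.5^2) and N(3, 0.5^2), up to a constant."""
    a = -0.5 * ((x + 3.0) / 0.5) ** 2
    b = -0.5 * ((x - 3.0) / 0.5) ** 2
    top = max(a, b)
    return top + math.log(math.exp(a - top) + math.exp(b - top))


def parallel_tempering(num_sweeps=5000, betas=(1.0, 0.3, 0.1, 0.03), seed=0):
    """Return the samples of the beta = 1 chain after num_sweeps sweeps."""
    rng = random.Random(seed)
    states = [-3.0 for _ in betas]  # every chain starts in the left mode
    cold_samples = []
    for _ in range(num_sweeps):
        # Random-walk Metropolis update within each tempered chain.
        for i, beta in enumerate(betas):
            step = 0.5 / math.sqrt(beta)  # wider proposals at higher temperature
            proposal = states[i] + step * rng.gauss(0.0, 1.0)
            log_accept = beta * (log_target(proposal) - log_target(states[i]))
            if math.log(1.0 - rng.random()) < log_accept:
                states[i] = proposal
        # Propose swapping the states of one random pair of neighbouring temperatures.
        i = rng.randrange(len(betas) - 1)
        log_swap = (betas[i] - betas[i + 1]) * (
            log_target(states[i + 1]) - log_target(states[i])
        )
        if math.log(1.0 - rng.random()) < log_swap:
            states[i], states[i + 1] = states[i + 1], states[i]
        cold_samples.append(states[0])
    return cold_samples


samples = parallel_tempering()
fraction_right = sum(x > 0.0 for x in samples) / len(samples)
print(fraction_right)  # share of cold-chain samples in the right-hand mode
```

The hot chains cross between the modes easily, and the swap moves pass those crossings down to the cold chain, which on its own would tend to stay in the mode where it started.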