Implementation & normalizing flow in matrix normal distribution

Hi, I’m working on a model where the likelihood follows a matrix normal distribution, X ~ MN_{n,p}(M, U, V), with n × p mean M, n × n row covariance U and p × p column covariance V. I’m using conjugate priors:

M ~ MN
U ~ Inverse Wishart
V ~ Inverse Wishart

As a result, I believe the posterior distribution should also follow a matrix normal distribution.
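To make the conjugacy concrete for M alone, here is a minimal NumPy sketch. It assumes U and V are held fixed, a single observation X, and a prior M ~ MN(M0, U0, V) that shares the likelihood's column covariance V. The shared V is my assumption: with an unrelated prior column covariance, the posterior of vec(M) is still Gaussian but generally not matrix normal. Under these assumptions, M | X, U, V ~ MN(Mn, Un, V) with Un = (U0^{-1} + U^{-1})^{-1} and Mn = Un (U0^{-1} M0 + U^{-1} X). The helper name `posterior_M_given_UV` is hypothetical. The second half of the block cross-checks the result against the dense Gaussian update on vec(M).

```python
import numpy as np

def posterior_M_given_UV(X, M0, U0, U):
    """Hypothetical helper: conditional posterior of M given one X ~ MN(M, U, V).

    Assumes the prior M ~ MN(M0, U0, V) shares the column covariance V,
    so the posterior is MN(Mn, Un, V). Returns (Mn, Un).
    """
    U0_inv = np.linalg.inv(U0)
    U_inv = np.linalg.inv(U)
    # Row precisions add
    Un = np.linalg.inv(U0_inv + U_inv)
    # Precision-weighted combination of prior mean and observation
    Mn = Un @ (U0_inv @ M0 + U_inv @ X)
    return Mn, Un

def random_spd(k, rng):
    """Random well-conditioned symmetric positive definite k x k matrix."""
    G = rng.standard_normal((k, k))
    return G @ G.T + k * np.eye(k)

def vec(A):
    """Column-stacking vectorization."""
    return A.reshape(-1, order="F")

rng = np.random.default_rng(0)
n, p = 3, 2
U0, U, V = random_spd(n, rng), random_spd(n, rng), random_spd(p, rng)
M0 = rng.standard_normal((n, p))
X = rng.standard_normal((n, p))
Mn, Un = posterior_M_given_UV(X, M0, U0, U)

# Dense check: MN(M, U, V) has vec-covariance kron(V, U)
P0 = np.linalg.inv(np.kron(V, U0))   # prior precision of vec(M)
PL = np.linalg.inv(np.kron(V, U))    # likelihood precision
Sn = np.linalg.inv(P0 + PL)
mu_n = Sn @ (P0 @ vec(M0) + PL @ vec(X))
print(np.allclose(vec(Mn), mu_n), np.allclose(np.kron(V, Un), Sn))
```

So the matrix normal form applies to M conditional on U and V, not to the joint posterior over (M, U, V).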

  1. Is there a way to implement the matrix normal distribution in Pyro?
  2. If I replace the conjugate priors with non-conjugate ones, so that the posterior no longer has a known closed form, can I use normalizing flows in Pyro to approximate a matrix-normal-like posterior?

I don’t think there’s a matrix normal distribution in Pyro, but there is one in NumPyro:
https://num.pyro.ai/en/stable/distributions.html#numpyro.distributions.continuous.MatrixNormal
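For reference, this is the density that any matrix normal implementation has to evaluate, with M the n × p mean, U the n × n row covariance and V the p × p column covariance. Equivalently, the column-stacked vec(X) is multivariate normal with covariance V ⊗ U. That equivalence gives an easy, if O((np)^3), way to test an implementation for small n and p.

```latex
p(X \mid M, U, V) = \frac{\exp\!\left(-\tfrac{1}{2}\,\operatorname{tr}\!\left[V^{-1}(X - M)^{\top} U^{-1}(X - M)\right]\right)}{(2\pi)^{np/2}\,\lvert V\rvert^{n/2}\,\lvert U\rvert^{p/2}},
\qquad
\operatorname{vec}(X) \sim \mathcal{N}\!\left(\operatorname{vec}(M),\; V \otimes U\right)
```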

That’s great news. I first tried to implement it in Pyro, with something like the code below (this is my first time using Pyro). Is it OK in my case to use NumPyro instead of Pyro for the normalizing flow? I’m not sure what the main difference between NumPyro and Pyro is.

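import math
import torch
from torch.distributions import constraints
from pyro.distributions import TorchDistribution

class MatrixNormal(TorchDistribution):
    # X ~ MN(M, U, V): M is the n x p mean, U the n x n row cov, V the p x p column cov
    arg_constraints = {'M': constraints.independent(constraints.real, 2),
                       'U': constraints.positive_definite,
                       'V': constraints.positive_definite}
    support = constraints.independent(constraints.real, 2)
    has_rsample = True

    def __init__(self, M, U, V, validate_args=None):
        self.M, self.U, self.V = M, U, V
        # Cholesky factors U = A A^T and V = B B^T, reused by rsample and log_prob
        self.A = torch.linalg.cholesky(U)
        self.B = torch.linalg.cholesky(V)
        # Batch shape is the broadcast of the leading (non-matrix) dims of M, U and V
        batch_shape = torch.broadcast_shapes(M.shape[:-2], U.shape[:-2], V.shape[:-2])
        super().__init__(batch_shape, M.shape[-2:], validate_args=validate_args)

    def rsample(self, sample_shape=torch.Size()):
        # X = M + A Z B^T with Z iid standard normal gives vec(X) ~ N(vec(M), kron(V, U));
        # the default sample() calls this under torch.no_grad()
        shape = self._extended_shape(sample_shape)
        Z = torch.randn(shape, dtype=self.M.dtype, device=self.M.device)
        return self.M + self.A @ Z @ self.B.transpose(-1, -2)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        n, p = self.event_shape
        # R = (A^{-1} (X - M) B^{-T})^T, so ||R||_F^2 = tr[V^{-1} (X - M)^T U^{-1} (X - M)]
        R = torch.linalg.solve_triangular(self.A, value - self.M, upper=False)
        R = torch.linalg.solve_triangular(self.B, R.transpose(-1, -2), upper=False)
        maha = R.pow(2).sum((-2, -1))
        logdet_U = 2 * self.A.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        logdet_V = 2 * self.B.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        return -0.5 * (n * p * math.log(2 * math.pi) + p * logdet_U + n * logdet_V + maha)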
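The sampling step needs draws with vec(X) ~ N(vec(M), V ⊗ U). One standard way to get them is X = M + A Z Bᵀ, with Cholesky factors U = A Aᵀ and V = B Bᵀ and Z filled with i.i.d. standard normals. That this is what a `matrix_normal_rvs` helper would do is an assumption, since the helper isn't shown. Below is a NumPy sketch with a hypothetical `sample_matrix_normal` helper, plus a Monte Carlo check of the first two moments:

```python
import numpy as np

def sample_matrix_normal(M, U, V, num_samples, rng):
    """Hypothetical helper: draws of shape (num_samples, n, p) from MN(M, U, V)."""
    n, p = M.shape
    A = np.linalg.cholesky(U)  # U = A A^T
    B = np.linalg.cholesky(V)  # V = B B^T
    Z = rng.standard_normal((num_samples, n, p))
    # Batched matmul: each draw is M + A @ Z_i @ B^T
    return M + A @ Z @ B.T

rng = np.random.default_rng(0)
M = np.array([[1.0, -1.0], [0.0, 2.0], [3.0, 0.5]])
U = np.array([[2.0, 0.5, 0.0], [0.5, 1.0, 0.3], [0.0, 0.3, 1.5]])
V = np.array([[1.0, -0.4], [-0.4, 0.8]])
draws = sample_matrix_normal(M, U, V, 200_000, rng)

# Column-stack each draw, then compare empirical moments with vec(M) and kron(V, U)
vecs = draws.transpose(0, 2, 1).reshape(len(draws), -1)
mean_err = np.abs(vecs.mean(axis=0) - M.reshape(-1, order="F")).max()
cov_err = np.abs(np.cov(vecs, rowvar=False) - np.kron(V, U)).max()
print(mean_err, cov_err)
```

Working with the two Cholesky factors costs O(n^3 + p^3) rather than factoring the np × np matrix V ⊗ U. The same expression written with torch.linalg.cholesky and torch.randn should keep samples differentiable with respect to M, U and V.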

The main difference between Pyro and NumPyro is the backend: Pyro is built on PyTorch, NumPyro on JAX. Both support normalizing flows.
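On question 2, it may help to see that the matrix normal is already the simplest normalizing flow. It is a standard normal Z pushed through the affine bijection Z ↦ M + A Z Bᵀ, where U = A Aᵀ and V = B Bᵀ. A normalizing flow uses the same change-of-variables density, log p(X) = log p(Z) - log|det J|, but composes more flexible bijections. The NumPy sketch below checks that the change-of-variables density agrees with the dense Gaussian on vec(X). The helper names are hypothetical and are not Pyro or NumPyro APIs.

```python
import numpy as np

def std_normal_logpdf(z):
    """Log density of i.i.d. N(0, 1) entries, summed over the array."""
    return -0.5 * (z.size * np.log(2 * np.pi) + np.sum(z ** 2))

def affine_flow_logpdf(X, M, A, B):
    """Log density of X = M + A Z B^T with Z ~ i.i.d. N(0, 1), by change of variables."""
    n, p = X.shape
    # Inverse transform: Z = A^{-1} (X - M) B^{-T}
    Z = np.linalg.solve(B, np.linalg.solve(A, X - M).T).T
    # The map vec(Z) -> vec(X) is kron(B, A), with |det| = |det B|^n * |det A|^p
    _, logabsdet_A = np.linalg.slogdet(A)
    _, logabsdet_B = np.linalg.slogdet(B)
    return std_normal_logpdf(Z) - (p * logabsdet_A + n * logabsdet_B)

def dense_mvn_logpdf(x, mu, S):
    """Reference: multivariate normal log density with a dense covariance S."""
    d = x - mu
    _, logdet = np.linalg.slogdet(S)
    return -0.5 * (x.size * np.log(2 * np.pi) + logdet + d @ np.linalg.solve(S, d))

rng = np.random.default_rng(1)
n, p = 3, 2
M = rng.standard_normal((n, p))
A = rng.standard_normal((n, n)) + 3 * np.eye(n)  # any invertible matrices work
B = rng.standard_normal((p, p)) + 3 * np.eye(p)
X = rng.standard_normal((n, p))

lp_flow = affine_flow_logpdf(X, M, A, B)
# Same density as a Gaussian on column-stacked vec(X) with covariance kron(V, U)
U, V = A @ A.T, B @ B.T
lp_dense = dense_mvn_logpdf(X.reshape(-1, order="F"), M.reshape(-1, order="F"), np.kron(V, U))
print(np.isclose(lp_flow, lp_dense))
```

One design choice this suggests is to use a Kronecker-structured affine map like this as one layer of a flow-based guide. The guide can then represent a matrix normal exactly, and the other layers add flexibility on top.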

I don’t know how high-dimensional your setup is, but there is no guarantee that a normalizing flow will work well in practice.