Hi, I’m working on a model where the likelihood is matrix normal, X ~ MN_{n,p}(M, U, V), with n × n row covariance U and p × p column covariance V. I’m using conjugate priors:
M ~ MN
U ~ Inverse Wishart
V ~ Inverse Wishart
As a result, I believe the posterior of M (conditional on U and V) should also be matrix normal.
Is there a way to implement the matrix normal distribution in Pyro?
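One way to get a working matrix normal without writing a new distribution class is the identity vec(X) ~ N(vec(M), V ⊗ U), with vec meaning column stacking, which lets the built-in multivariate normal do the work. A minimal sketch under that assumption; `vec`, `unvec` and `matrix_normal_as_mvn` are hypothetical helper names. It uses `torch.distributions.MultivariateNormal`; inside a Pyro model, `pyro.distributions.MultivariateNormal` takes the same `loc` and `scale_tril` arguments.

```python
import torch
from torch.distributions import MultivariateNormal


def vec(A):
    # Column-stacking vec(): (..., n, p) -> (..., n*p)
    return A.transpose(-1, -2).reshape(*A.shape[:-2], -1)


def unvec(x, n, p):
    # Inverse of vec(): (..., n*p) -> (..., n, p)
    return x.reshape(*x.shape[:-1], p, n).transpose(-1, -2)


def matrix_normal_as_mvn(M, U, V):
    # vec(X) ~ N(vec(M), V kron U). The kron of the two lower-triangular
    # Cholesky factors is lower triangular and satisfies L L^T = V kron U.
    # Assumes unbatched U and V (torch.kron would also mix batch dims).
    L_U = torch.linalg.cholesky(U)
    L_V = torch.linalg.cholesky(V)
    return MultivariateNormal(vec(M), scale_tril=torch.kron(L_V, L_U))


# Usage: 4 draws of a 3 x 2 matrix
n, p = 3, 2
d = matrix_normal_as_mvn(torch.zeros(n, p), torch.eye(n),
                         torch.tensor([[2.0, 0.5], [0.5, 1.0]]))
X = unvec(d.sample((4,)), n, p)   # shape (4, 3, 2)
log_p = d.log_prob(vec(X))        # shape (4,)
```

The dense np × np Kronecker factor costs O((np)^2) memory, so for large n and p a dedicated class that samples M + L_U Z L_V^T should scale better.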
If I replace the conjugate priors with non-conjugate priors, so that the posterior no longer has a closed form, is it possible to use normalizing flows in Pyro to approximate a matrix-normal-like posterior?
That’s great news. I tried implementing it in Pyro first, roughly as below. This is my first time using Pyro. Would it be OK to use NumPyro instead of Pyro for the normalizing flow in my case? I’m not sure what the main difference between NumPyro and Pyro is.
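import math
import torch
from torch.distributions import constraints
from pyro.distributions import TorchDistribution
from pyro.distributions.util import broadcast_shape

class MatrixNormal(TorchDistribution):
    # X ~ MN_{n,p}(M, U, V): row covariance U (n x n), column covariance V (p x p)
    arg_constraints = {'M': constraints.independent(constraints.real, 2), 'U': constraints.positive_definite, 'V': constraints.positive_definite}
    support = constraints.independent(constraints.real, 2)
    has_rsample = True

    def __init__(self, M, U, V, validate_args=None):
        self.M, self.U, self.V = M, U, V
        # Infer batch shape from broadcasted shapes of M, U, and V
        batch_shape = broadcast_shape(M.shape[:-2], U.shape[:-2], V.shape[:-2])
        event_shape = M.shape[-2:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        # Cholesky factors: U = L_U L_U^T, V = L_V L_V^T
        self.L_U = torch.linalg.cholesky(U)
        self.L_V = torch.linalg.cholesky(V)

    def rsample(self, sample_shape=torch.Size()):
        # X = M + L_U Z L_V^T with Z iid N(0, 1), i.e. vec(X) ~ N(vec(M), V kron U)
        Z = torch.randn(self._extended_shape(sample_shape), dtype=self.M.dtype, device=self.M.device)
        return self.M + self.L_U @ Z @ self.L_V.transpose(-1, -2)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        n, p = self.event_shape
        # tr(V^{-1} D^T U^{-1} D) = ||L_U^{-1} D L_V^{-T}||_F^2 with D = X - M
        A = torch.linalg.solve_triangular(self.L_U, value - self.M, upper=False)
        A = torch.linalg.solve_triangular(self.L_V, A.transpose(-1, -2), upper=False)
        maha = A.pow(2).sum((-2, -1))
        logdet_U = 2 * self.L_U.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        logdet_V = 2 * self.L_V.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        return -0.5 * (maha + n * p * math.log(2 * math.pi) + p * logdet_U + n * logdet_V)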
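A `log_prob` for the matrix normal can use the closed-form density, log p(X) = -1/2 tr[V^{-1} (X-M)^T U^{-1} (X-M)] - (np/2) log(2π) - (n/2) log|V| - (p/2) log|U|. A minimal sketch of that computation as a standalone function (the name is hypothetical), restricted to single unbatched matrices and using Cholesky factors with triangular solves rather than explicit inverses:

```python
import math
import torch


def matrix_normal_log_prob(X, M, U, V):
    """log MN_{n,p}(X | M, U, V) for single (unbatched) n x p matrices."""
    n, p = M.shape[-2:]
    L_U = torch.linalg.cholesky(U)
    L_V = torch.linalg.cholesky(V)
    D = X - M
    # A = L_U^{-1} D L_V^{-T} (stored transposed), so
    # tr(V^{-1} D^T U^{-1} D) = ||A||_F^2
    A = torch.linalg.solve_triangular(L_U, D, upper=False)
    A = torch.linalg.solve_triangular(L_V, A.transpose(-1, -2), upper=False)
    maha = A.pow(2).sum()
    # log|U| = 2 * sum(log diag(L_U)), likewise for V
    logdet_U = 2.0 * torch.log(torch.diagonal(L_U)).sum()
    logdet_V = 2.0 * torch.log(torch.diagonal(L_V)).sum()
    return -0.5 * (maha + n * p * math.log(2 * math.pi) + p * logdet_U + n * logdet_V)
```

Comparing the result with `torch.distributions.MultivariateNormal` evaluated on the column-stacked vec(X) with covariance `torch.kron(V, U)` is a quick correctness check.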