Putting a prior on the mean of a matrix normal distribution

Hi, I am working with a network graph dataset G, represented as an adjacency matrix, and I want to use the Matrix Normal (MN) distribution as the likelihood in my model. The MN distribution takes three inputs: a mean matrix, a row covariance (U), and a column covariance (V). The quantities in my model are:

  1. R̂: the observed data for the MN distribution,
  2. U and V: the row and column covariance matrices of the MN distribution, respectively,
  3. G: the adjacency matrix of the graph.

The mean of the MN distribution in my model is computed as (I − G)^−1, where G is the observed graph matrix loaded from G.csv. My objective is to place a mixture normal prior on every edge of the observed G, whether or not that edge is present.

For example, if G is a graph with D = 10, it has 10 × 10 = 100 possible edges. I want to apply the mixture normal prior to all 100 edges and then compute (I − G)^−1 using this G with the prior applied. The resulting (I − G)^−1 then serves as the mean of the MN distribution, as in the equation below.
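
For concreteness, here is a toy version of the mean construction (the numbers in G below are invented purely for illustration):

import jax.numpy as jnp

D = 3
G = jnp.array([[0.0, 0.2, 0.0],
               [0.1, 0.0, 0.3],
               [0.0, 0.4, 0.0]])  # made-up 3-node weighted adjacency matrix

R_mean = jnp.linalg.inv(jnp.eye(D) - G)  # (I - G)^-1, used as the MN mean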

The issue with my current implementation (shown further below) is that it replaces G entirely with a random matrix sampled from the mixture normal prior, instead of applying the prior to each edge of the observed G loaded from the file.
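
What I am after is something along these lines, where the prior perturbs the observed entries rather than replacing them. This is only a sketch and I am not sure it is the right construction: G_obs stands for the matrix loaded from G.csv, and edge_prior for the mixture normal prior defined in the model below.

G_obs = jnp.asarray(G_df.values)
# one draw from the mixture prior per edge, added to the observed value
eps = numpyro.sample('eps', edge_prior.expand((D, D)))
G = G_obs + eps  # prior centered on the observed graph, not replacing it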

Could someone suggest how to modify my model to address this?

R̂ ∼ MN((I − G)^−1, U, V)

import numpy as np
import pandas as pd
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Load the observed data, covariances, and the graph adjacency matrix.
Rhat_df = pd.read_csv("R_hat.csv")
U_df = pd.read_csv("U.csv", header=None)
V_df = pd.read_csv("V.csv", header=None)
G_df = pd.read_csv("G.csv")

D = G_df.shape[0]

# Empirical mixture weights: fraction of zero vs. non-zero edges in G.
num_zero_entries = np.sum(G_df.values == 0)
total_entries = D * D
pi_0 = num_zero_entries / total_entries
pi_1 = 1 - pi_0

# Cholesky factors of the row and column covariances for the MatrixNormal.
U_lower = jnp.linalg.cholesky(U_df.values)
V_lower = jnp.linalg.cholesky(V_df.values)

def model(obs_data, D=10, epsilon=1e-5, pi_array=np.array([pi_0, pi_1]), n_components=100):
    # Two-component mixture prior per edge: a narrow "spike" for (near-)zero
    # edges and a scale mixture of normals as the "slab" for non-zero edges.
    mixing_dist = dist.Categorical(probs=pi_array)
    component_dists = [
        dist.Normal(loc=0.0, scale=0.1),
        dist.MixtureSameFamily(
            dist.Categorical(probs=jnp.ones(n_components) / n_components),
            dist.Normal(loc=jnp.zeros(n_components), scale=jnp.linspace(0.1, 10, n_components)),
        ),
    ]
    # This draws a fresh (D, D) matrix from the prior; the observed G never enters.
    G = numpyro.sample('G', dist.MixtureGeneral(mixing_dist, component_dists).expand((D, D)))

    I = jnp.eye(D)
    # Small diagonal jitter so that I - G stays safely invertible.
    I_minus_G = I - G + epsilon * I
    R_mean = jnp.linalg.inv(I_minus_G)

    numpyro.sample(
        'R_hat_obs',
        dist.MatrixNormal(loc=R_mean, scale_tril_row=U_lower, scale_tril_column=V_lower),
        obs=obs_data,
    )
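
For reference, I fit the model with a standard NUTS setup along these lines (sketch; the warmup and sample counts are arbitrary):

import jax
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), obs_data=jnp.asarray(Rhat_df.values))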