Assembling a matrix from sampled components

Hi all,

I’m trying to infer the parameters of a two-species generalized Lotka-Volterra (gLV) system using SVI in Pyro. For now I’m treating the growth rates and the initial condition as known and inferring only the interactions between the species. I’ve had similar models working in NumPyro, but I’m less familiar with Pyro and am not sure where I’m going wrong.
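
For reference, the dynamics I’m fitting (what the gLV function in the code below implements) are

$$
\frac{dx_i}{dt} = x_i \Big( r_i + \sum_{j=1}^{N} A_{ij}\, x_j \Big), \qquad i = 1, \dots, N,
$$

where $r$ holds the growth rates and $A$ is the $N \times N$ interaction matrix I’d like to infer.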

The main code further below works when the diagonal elements of the interaction matrix (a) are hard-coded, but it fails when those elements are sampled instead, i.e. when the hard-coded a is replaced with:

    with pyro.plate('n_taxa_', size=N, dim=-1):
        a = pyro.sample('a', dist.Normal(0.0, 0.1))

I’d prefer to sample the diagonal and off-diagonal elements separately and then assemble the interaction matrix, since I’d eventually like to use different distributions for the diagonal and off-diagonal elements, but I’m open to better ways of approaching this.
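
To make that concrete, here’s a minimal sketch of the sampling structure I’m aiming for (standalone, with placeholder Normal priors for both parts; the function and site names are just for illustration, and I’ve used out-of-place masked_scatter for the assembly instead of the in-place indexing in the code below, in case that’s part of the problem):

    import torch
    import pyro
    import pyro.distributions as dist

    def sample_interaction_matrix(N):
        # Diagonal elements, one prior ...
        with pyro.plate('n_diag', size=N, dim=-1):
            a = pyro.sample('a', dist.Normal(0.0, 0.1))

        # ... off-diagonal elements, eventually a different prior
        with pyro.plate('n_off_diag', size=N * (N - 1), dim=-1):
            b = pyro.sample('b', dist.Normal(0.0, 0.1))

        # Assemble the N x N interaction matrix: diagonal from a, off-diagonal from b
        mask = torch.eye(N, dtype=torch.bool)
        A = torch.zeros(N, N, dtype=b.dtype)
        A = A.masked_scatter(mask, a).masked_scatter(~mask, b)
        return A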

Main code (working, but a is hard-coded):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from torchdiffeq import odeint

def gLV(t, x, args):
    r, A = args
    return x * (r + A @ x)

def model(args, x_obs=None):
    # Growth rates r, initial condition x0, time points t; N is the number of species
    r, x0, t = args
    N = len(r)
    
    # Diagonal elements of A, hard-coded
    a = torch.Tensor([0.0, 0.0])

    # Sample off-diagonal elements of A
    with pyro.plate('n_off_diag', size=(N * (N-1)), dim=-1):
        b = pyro.sample("b", dist.Normal(0.0, 0.1))

    # Assemble A: diagonal elements from a, off-diagonal elements from b
    A = torch.zeros(size=(N, N), dtype=b.dtype)
    mask = torch.eye(N, dtype=torch.bool)
    A[mask] = a
    A[~mask] = b

    func = lambda t, y : gLV(t, y, [r, A])
    x = odeint(func, x0, t, method='rk4', options=dict(step_size=0.1))
    x = torch.nan_to_num(x, 1e5)

    with pyro.plate('time', size=x.shape[0], dim=-2):
        with pyro.plate('n_taxa', size=N, dim=-1):
            pyro.sample("obs", dist.Normal(x, 0.1), obs=x_obs)

if __name__ == "__main__":
    # Observation times
    t = torch.arange(0.0, 5.0, 1.0)
    x0_true = torch.Tensor([1., 1.])

    r_true = torch.Tensor([0.3, -0.6])
    A_true = torch.Tensor([
        [0.0, 0.01],
        [-0.02, 0.],
        ]
    )

    # For passing arguments to odeint function
    func = lambda t, y : gLV(t, y, [r_true, A_true])
    x = odeint(func, x0_true, t, rtol=1e-8, atol=1e-10, method='rk4', options=dict(step_size=0.1))

    guide = pyro.infer.autoguide.AutoDiagonalNormal(model)

    # Optimizer
    adam_params = {"lr": 0.001}
    optimizer = Adam(adam_params)

    # Inference
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    n_steps = 100
    for step in range(n_steps):
        svi.step([r_true, x0_true, t], x)

        if step % 50 == 0:
            print(step)
        
    print(guide.median())

Here are the last few lines of the error when a is sampled:

ValueError: Expected parameter loc (Parameter of shape (4,)) of distribution Normal(loc: torch.Size([4]), scale: torch.Size([4])) to satisfy the constraint Real(), but found invalid values:
Parameter containing:
tensor([nan, nan, nan, nan], requires_grad=True)
           Trace Shapes:  
            Param Sites:  
  AutoDiagonalNormal.loc 4
AutoDiagonalNormal.scale 4
           Sample Sites:  

Thanks.