SVI for SIR gives inf loss

I was trying to implement a prediction model using SIR, based on the "Example: Epidemiological inference via HMC" tutorial (Pyro Tutorials 1.7.0 documentation), but with SVI instead of HMC. However, it keeps giving me an inf loss. Here are my guide and model; I wonder if I have misunderstood anything.

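```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints


def guide(population, recovery_time, data):
    # Variational parameters for R0. dist.Normal takes a scale (std dev),
    # so "var_R0" is actually used as a standard deviation here.
    mu_R0 = pyro.param("mu_R0", torch.tensor(2.0))
    var_R0 = pyro.param("var_R0", torch.tensor(0.5**2), constraint=constraints.positive)

    # Beta concentrations must be positive floats to be learnable.
    alpha_rho = pyro.param("alpha_rho", torch.tensor(5.0), constraint=constraints.positive)
    beta_rho = pyro.param("beta_rho", torch.tensor(1.5), constraint=constraints.positive)

    # Only R0 and rho are guided; the S2I_t / I2R_t sites appear only in the model.
    R0 = pyro.sample("R0", dist.Normal(mu_R0, var_R0))
    rho = pyro.sample("rho", dist.Beta(alpha_rho, beta_rho))


def model(population, recovery_time, data):
    tau = recovery_time
    R0 = pyro.sample("R0", dist.Normal(2.0, 0.5**2))  # basic reproduction number
    rho = pyro.sample("rho", dist.Beta(5, 1.5))  # response (reporting) rate

    rate_s = -R0 / (tau * population)  # -beta / population, where beta = R0 / tau
    prob_i = 1 / (1 + tau)  # probability of recovering in one time step
    S = torch.tensor(population - 1.0)  # susceptible count
    I = torch.tensor(1.0)  # infected count
    for t, datum in enumerate(data):
        # abs() only hides the sign: if R0 < 0 this "probability" can exceed 1.
        S2I = pyro.sample(f"S2I_{t}", dist.Binomial(S, abs(-(rate_s * I).expm1())))
        I2R = pyro.sample(f"I2R_{t}", dist.Binomial(I, prob_i))
        S = pyro.deterministic(f"S_{t}", S - S2I)
        I = pyro.deterministic(f"I_{t}", I + S2I - I2R)
        # Observed cases: a rho-thinned count of the new infections S2I.
        pyro.sample(f"obs_{t}", dist.ExtendedBinomial(S2I, rho), obs=datum)
```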

Thank you for your help.

Hi @vltanh, I'd recommend starting with Pyro's built-in framework for epidemiological forecasting, `pyro.contrib.epidemiology`, rather than the low-level inference demonstration example you pointed to. There are good high-level tutorials for `contrib.epidemiology`, starting with the intro tutorial. Inference via SVI is built in through the `.fit_svi()` method.