I was trying to implement prediction model using SIR (based on Example: Epidemiological inference via HMC — Pyro Tutorials 1.8.4 documentation) but use SVI instead. However, it has been giving me inf loss. Here’s my guide and model, I wonder if I misunderstood anything.
def guide(population, recovery_time, data):
mu_R0 = pyro.param('mu_R0', torch.tensor(2.0))
var_R0 = pyro.param('var_R0', torch.tensor(0.5**2),
constraint=constraints.positive)
alpha_rho = pyro.param("alpha_rho", torch.tensor(5),
constraint=constraints.positive)
beta_rho = pyro.param("beta_rho", torch.tensor(1.5),
constraint=constraints.positive)
R0 = pyro.sample("R0", dist.Normal(mu_R0, var_R0))
rho = pyro.sample("rho", dist.Beta(alpha_rho, beta_rho))
def model(population, recovery_time, data):
tau = recovery_time
R0 = pyro.sample("R0", dist.Normal(2.0, 0.5**2))
rho = pyro.sample("rho", dist.Beta(5, 1.5))
rate_s = -R0 / (tau * population) # beta
prob_i = 1 / (1 + tau) # probability of recovering
S = torch.tensor(population - 1.0)
I = torch.tensor(1.0)
for t, datum in enumerate(data):
S2I = pyro.sample(f'S2I_{t}', dist.Binomial(S, abs(-(rate_s * I).expm1())))
I2R = pyro.sample(f'I2R_{t}', dist.Binomial(I, prob_i))
S = pyro.deterministic(f'S_{t}', S - S2I)
I = pyro.deterministic(f'I_{t}', I + S2I - I2R)
pyro.sample(f'obs_{t}', dist.ExtendedBinomial(S2I, rho), obs=datum)
Thank you for your help.