Help understanding an epidemic model?

Hi! In trying to learn Pyro, I ran across the following model from:

@config_enumerate
def reparameterized_discrete_model(args, data):
    # Sample global parameters.
    rate_s, prob_i, rho = global_model(args.population)

    # Sequentially sample time-local variables.
    S_curr = torch.tensor(args.population - 1.0)
    I_curr = torch.tensor(1.0)
    for t, datum in enumerate(data):
        # Sample reparameterizing variables.
        # When reparameterizing to a factor graph, we ignored density via
        # .mask(False). Thus distributions are used only for initialization.
        S_prev, I_prev = S_curr, I_curr
        S_curr = pyro.sample(
            "S_{}".format(t), dist.Binomial(args.population, 0.5).mask(False)
        )
        I_curr = pyro.sample(
            "I_{}".format(t), dist.Binomial(args.population, 0.5).mask(False)
        )

        # Now we reverse the computation.
        S2I = S_prev - S_curr
        I2R = I_prev - I_curr + S2I
        pyro.sample(
            "S2I_{}".format(t),
            dist.ExtendedBinomial(S_prev, -(rate_s * I_prev).expm1()),
            obs=S2I,
        )
        pyro.sample("I2R_{}".format(t), dist.ExtendedBinomial(I_prev, prob_i), obs=I2R)
        pyro.sample("obs_{}".format(t), dist.ExtendedBinomial(S2I, rho), obs=datum)

I have a few questions about what the strategy is here:

  1. What are we actually enumerating over? The discrete variables either have a provided observation (so I’m assuming they just add log Dist(obs | pars) to the logpdf) or are masked, in which case I’m assuming they are not part of the enumeration?

  2. The comments say:

“By reparameterizing, we have converted to coordinates that make the model Markov.”

It seems to me that the original model is perhaps already Markov, as I can’t see the current state depending on anything other than the previous state, if the state is taken to be (S, I, R). So I’m wondering in what sense the new model is “more Markov”.

  3. Why do we initialize with masked draws from a Binomial instead of just a fixed value, if it doesn’t contribute to the logpdf, and is therefore not enumerated over?
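
For concreteness, here is how I read the “reverse the computation” step, as a plain-Python sketch with no Pyro involved. The helper names forward_step and reverse_step are mine, not Pyro’s, and the numbers are made up.

```python
def forward_step(S_prev, I_prev, S2I, I2R):
    """Original coordinates: transitions are given, next states are derived."""
    S_curr = S_prev - S2I
    I_curr = I_prev + S2I - I2R
    return S_curr, I_curr


def reverse_step(S_prev, I_prev, S_curr, I_curr):
    """Reparameterized coordinates: states are given, transitions are derived.

    Mirrors the two lines under "Now we reverse the computation" in the model.
    """
    S2I = S_prev - S_curr
    I2R = I_prev - I_curr + S2I
    return S2I, I2R


# Round trip: the two coordinate systems describe the same step.
S_curr, I_curr = forward_step(9, 1, S2I=2, I2R=1)
assert reverse_step(9, 1, S_curr, I_curr) == (2, 1)

# A state pair that no valid transition can produce yields a negative count,
# which a Binomial with support restricted to [0, n] could not score.
print(reverse_step(9, 1, 10, 0))  # (-1, 0)
```

If I have this right, the second case would be why the observed sites use ExtendedBinomial rather than Binomial, but I’d appreciate confirmation.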

If there is a reference to this strategy somewhere else, that would also be very helpful. Thanks!

I don’t think there’s a detailed write-up anywhere, but you can probably recover some of the logic by reading pyro.contrib.epidemiology.models in the Pyro documentation.
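
On question 1, as far as I understand it, .mask(False) only zeroes out the log-density of the S_t and I_t sites. They are still latent discrete sites, so they are still enumerated, and the obs= sites act as factors that score each enumerated value. Below is a rough plain-Python sketch of that sum for a single time step. It is not Pyro’s implementation: ext_binom_logpmf and step_likelihood are hypothetical helpers, and the parameter values are made up for illustration.

```python
import math
from itertools import product


def ext_binom_logpmf(k, n, p):
    """Binomial log-pmf that returns -inf (rather than failing) outside 0 <= k <= n."""
    if n < 0 or k < 0 or k > n:
        return -math.inf
    if p <= 0.0:
        return 0.0 if k == 0 else -math.inf
    if p >= 1.0:
        return 0.0 if k == n else -math.inf
    return math.log(math.comb(n, k)) + k * math.log(p) + (n - k) * math.log1p(-p)


def step_likelihood(S_prev, I_prev, datum, population, rate_s, prob_i, rho):
    """Likelihood of one datum, summing over every candidate (S_curr, I_curr)."""
    p_inf = -math.expm1(rate_s * I_prev)  # rate_s is assumed negative here
    total = 0.0
    for S_curr, I_curr in product(range(population + 1), repeat=2):
        # The masked sites add nothing to the density; only the factors do.
        S2I = S_prev - S_curr
        I2R = I_prev - I_curr + S2I
        logp = (
            ext_binom_logpmf(S2I, S_prev, p_inf)
            + ext_binom_logpmf(I2R, I_prev, prob_i)
            + ext_binom_logpmf(datum, S2I, rho)
        )
        total += math.exp(logp)  # impossible states contribute exp(-inf) = 0
    return total


lik = step_likelihood(9, 1, datum=2, population=10,
                      rate_s=-0.3, prob_i=0.2, rho=0.5)
print(lik)
```

Each step loops over (population + 1)^2 candidate states, so the cost grows quickly with population size. That would suggest this approach is only practical for small populations.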
