Beginner question. Take the toy example from the Pyro tutorial "Example: Toy Mixture Model With Discrete Enumeration" (Pyro Tutorials 1.8.6 documentation). There, the model and the guide are two separate functions. What would be different if we put everything into the model, like this:
@pyro.infer.config_enumerate
def model(prior, obs, num_obs):
    """Toy mixture model that samples its Beta-distributed probabilities itself.

    The Beta concentrations are learnable ``pyro.param``s ("a", "b", "c"),
    initialised from ``prior``. The probabilities ``p_A``, ``p_B`` and ``p_C``
    are latent sample sites drawn from those Betas. Inside ``data_plate``:
    ``A`` is observed, ``B`` is a latent Bernoulli enumerated in parallel,
    and ``C`` is observed.

    Args:
        prior: mapping with keys "A", "B", "C" holding positive initial
            values for the params "a", "b", "c". "A" is indexed as ``a[0]``
            and ``a[1]``; "B" and "C" are indexed as ``[:, 0]`` and ``[:, 1]``.
            Presumably "A" has shape (2,) and "B"/"C" have shape (2, 2).
            TODO confirm against the caller.
        obs: mapping with observed values for keys "A" and "C". Presumably
            0/1 tensors of length ``num_obs``; verify against the caller.
        num_obs: size of ``data_plate``, i.e. the number of observations.
    """
    a = pyro.param("a", prior["A"], constraint=constraints.positive)
    p_A = pyro.sample("p_A", dist.Beta(a[0], a[1]))

    b = pyro.param("b", prior["B"], constraint=constraints.positive)
    p_B = pyro.sample("p_B", dist.Beta(b[:, 0], b[:, 1]).to_event(1))

    c = pyro.param("c", prior["C"], constraint=constraints.positive)
    # Must be bound as lower-case ``p_C``: the "C" site below reads ``p_C``.
    p_C = pyro.sample("p_C", dist.Beta(c[:, 0], c[:, 1]).to_event(1))

    with pyro.plate("data_plate", num_obs):
        A = pyro.sample("A", dist.Bernoulli(p_A.expand(num_obs)), obs=obs["A"])
        # Vindex used to ensure proper indexing into the enumerated sample sites.
        # ``config_enumerate`` already defaults to parallel enumeration; the
        # explicit ``infer`` dict is kept to make that intent visible.
        B = pyro.sample(
            "B",
            dist.Bernoulli(Vindex(p_B)[A.type(torch.long)]),
            infer={"enumerate": "parallel"},
        )
        pyro.sample(
            "C",
            dist.Bernoulli(Vindex(p_C)[B.type(torch.long)]),
            obs=obs["C"],
        )
and then use an empty guide:
def noguide(prior, obs, num_obs):
    """Guide that does nothing.

    It accepts the same arguments as ``model`` so it can be passed to SVI
    alongside it. It ignores all of them, registers no ``pyro.param`` or
    ``pyro.sample`` sites, and returns ``None``.

    NOTE(review): with an empty guide, Pyro presumably draws the model's
    latent sites (``p_A``, ``p_B``, ``p_C``) from their model distributions
    and warns that they are missing from the guide. Confirm this against
    the Pyro SVI / ELBO documentation.
    """
    return None