In Pyro, how do I infer continuous parameters in a training step, then a discrete latent variable conditional on the first step?

This is an embarrassingly basic question. But I’m developing course lectures and want to make sure I teach students best practices instead of what merely seems right to me.

Imagine this were the true DGP:

```python
def true_dgp(jenny_inclination, brian_inclination, window_strength):
    # input variables are numbers between 0 and 1
    jenny_throws_ball = jenny_inclination > 0.2
    brian_throws_ball = brian_inclination > 0.8
    if jenny_throws_ball and brian_throws_ball:
        strength_of_impact = 0.8
    elif jenny_throws_ball or brian_throws_ball:
        strength_of_impact = 0.6
    else:
        strength_of_impact = 0.0
    window_breaks = window_strength < strength_of_impact
    return jenny_throws_ball, brian_throws_ball, window_breaks
```

Above, jenny_inclination, brian_inclination, and window_strength are uniformly distributed between 0 and 1.
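Because the true DGP is fully specified, the exact answer to the query used later, P(Jenny threw ball | window is broken), can be computed by enumeration and kept as a yardstick for any approximate method taught in the lecture. A minimal sketch, assuming the three inputs are independent; `true_joint` and `p_break` are hypothetical helper names. With Uniform(0, 1) inputs, Jenny throws with probability 0.8, Brian with probability 0.2, and the window breaks with probability equal to strength_of_impact.

```python
from itertools import product

# Implied by the true DGP with independent Uniform(0, 1) inputs:
# P(jenny_inclination > 0.2) = 0.8 and P(brian_inclination > 0.8) = 0.2.
P_JENNY = 0.8
P_BRIAN = 0.2

def p_break(jenny, brian):
    """P(window_strength < strength_of_impact), which equals strength_of_impact
    when window_strength ~ Uniform(0, 1)."""
    if jenny and brian:
        return 0.8
    if jenny or brian:
        return 0.6
    return 0.0

def true_joint(jenny, brian, window):
    """Exact probability of one (jenny, brian, window) outcome under the true DGP."""
    pj = P_JENNY if jenny else 1 - P_JENNY
    pb = P_BRIAN if brian else 1 - P_BRIAN
    pw = p_break(jenny, brian)
    return pj * pb * (pw if window else 1 - pw)

# P(jenny_throws_ball | window_breaks): sum the joint over the unobserved variables.
numerator = sum(true_joint(True, b, True) for b in (False, True))
denominator = sum(true_joint(j, b, True) for j, b in product((False, True), repeat=2))
p_jenny_given_break = numerator / denominator
print(round(p_jenny_given_break, 4))  # 0.9552
```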

Here is my model of the DGP:

```python
import pyro
import pyro.distributions

def model():
    # Priors that represent uncertainty about the parameters in our model
    ρ_a = pyro.sample("ρ_a", pyro.distributions.Beta(1, 1))
    ρ_b = pyro.sample("ρ_b", pyro.distributions.Beta(1, 1))

    # Our symbolic explanation generator of the DGP
    jenny_throws_ball = pyro.sample(
        "jenny_throws_ball", pyro.distributions.Bernoulli(ρ_a)
    )
    brian_throws_ball = pyro.sample(
        "brian_throws_ball", pyro.distributions.Bernoulli(ρ_b)
    )

    # Probability that the window breaks, given who threw
    ρ_strength_of_impact = 0.0
    if jenny_throws_ball and brian_throws_ball:
        ρ_strength_of_impact = 0.8
    elif jenny_throws_ball or brian_throws_ball:
        ρ_strength_of_impact = 0.4

    # Every sample site needs its own unique name
    window_breaks = pyro.sample(
        "window_breaks",
        pyro.distributions.Bernoulli(ρ_strength_of_impact)
    )
    return jenny_throws_ball, brian_throws_ball, window_breaks
```

In my lecture notes, I use true_dgp to sample 100 observations of jenny_throws_ball, brian_throws_ball, and window_breaks; call this training_data. Then we are interested in a post-training query, such as P(Jenny threw ball | window is broken), so evidence = {'window_breaks': torch.tensor(1.)}.
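The data-generation step described above can be sketched as follows, assuming independent Uniform(0, 1) inputs and a fixed seed for reproducibility; `simulate_training_data` is a hypothetical helper. The booleans are converted to float tensors because Pyro's Bernoulli expects float 0./1. values, which is also why the evidence uses `torch.tensor(1.0)`.

```python
import random
import torch

def true_dgp(jenny_inclination, brian_inclination, window_strength):
    # Same logic as the true DGP above; inputs are numbers between 0 and 1
    jenny_throws_ball = jenny_inclination > 0.2
    brian_throws_ball = brian_inclination > 0.8
    if jenny_throws_ball and brian_throws_ball:
        strength_of_impact = 0.8
    elif jenny_throws_ball or brian_throws_ball:
        strength_of_impact = 0.6
    else:
        strength_of_impact = 0.0
    window_breaks = window_strength < strength_of_impact
    return jenny_throws_ball, brian_throws_ball, window_breaks

def simulate_training_data(num_obs=100, seed=0):
    """Draw num_obs rows from true_dgp with independent Uniform(0, 1) inputs."""
    rng = random.Random(seed)
    rows = [true_dgp(rng.random(), rng.random(), rng.random()) for _ in range(num_obs)]
    names = ("jenny_throws_ball", "brian_throws_ball", "window_breaks")
    # One float tensor of 0./1. per site name, ready to pass to pyro.condition
    return {name: torch.tensor([float(row[i]) for row in rows])
            for i, name in enumerate(names)}

training_data = simulate_training_data()
evidence = {"window_breaks": torch.tensor(1.0)}
```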

What’s the best way to do this in Pyro? My instinct says condition(condition(model, training_data), evidence) (or use the obs argument of pyro.sample for the first condition) and then importance sampling?

Though that doesn’t seem wise if the space of the continuous parameters that get optimized in the first condition is complex. Moreover, an ideal workflow would learn the continuous “weight” parameters once, then use that object as an input to an inference algorithm that answers conditional queries on jenny_throws_ball, brian_throws_ball, and window_breaks as they come up later.

I was thinking one could use NUTS or SVI to infer the continuous parameters, then use that “posterior” object in a guide function targeting the other variables, perhaps with importance sampling?
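The two-stage idea can be sketched without a custom guide: draw ρ once from p(ρ | training data), then answer each later query by averaging the model's probability of a new row over those draws and enumerating the three binary variables. This is self-normalized importance sampling with the posterior as the proposal, with the discrete variables summed out exactly. A minimal sketch in plain Python, assuming fully observed throws, so the Beta(1, 1) priors give a conjugate Beta posterior; with NUTS, the draws returned by `MCMC(...).get_samples()` would play the same role. The helpers `posterior_rho_samples`, `joint`, and `query`, and the training counts, are hypothetical.

```python
import random
from itertools import product

# P(window breaks) given the number of throwers, as in the question's model.
IMPACT = {0: 0.0, 1: 0.4, 2: 0.8}

def posterior_rho_samples(jenny_obs, brian_obs, num_samples=4000, seed=0):
    """Stage 1, run once: draws of (ρ_a, ρ_b) from p(ρ | training data).

    Beta(1, 1) priors with fully observed throws give Beta(1 + #throws, 1 + #non-throws).
    window_breaks adds no information about ρ: given the throws, it does not depend on ρ.
    """
    rng = random.Random(seed)
    n, k_j, k_b = len(jenny_obs), sum(jenny_obs), sum(brian_obs)
    return [(rng.betavariate(1 + k_j, 1 + n - k_j), rng.betavariate(1 + k_b, 1 + n - k_b))
            for _ in range(num_samples)]

def joint(rho_a, rho_b, jenny, brian, window):
    """p(jenny, brian, window | ρ) for one new row."""
    p_w = IMPACT[jenny + brian]
    return ((rho_a if jenny else 1 - rho_a) * (rho_b if brian else 1 - rho_b)
            * (p_w if window else 1 - p_w))

def query(rho_samples, target, evidence):
    """Stage 2, per query: P(target | evidence, training data) for a new row."""
    names = ("jenny", "brian", "window")
    numerator = denominator = 0.0
    for values in product((0, 1), repeat=3):
        row = dict(zip(names, values))
        if any(row[k] != v for k, v in evidence.items()):
            continue
        # Monte Carlo estimate of p(row | data) = E_ρ[p(row | ρ)]
        p_row = sum(joint(a, b, *values) for a, b in rho_samples) / len(rho_samples)
        denominator += p_row
        if all(row[k] == v for k, v in target.items()):
            numerator += p_row
    return numerator / denominator

# Hypothetical training counts: Jenny threw in 80 of 100 rows, Brian in 20.
rho_samples = posterior_rho_samples([1] * 80 + [0] * 20, [1] * 20 + [0] * 80)
p_jenny_given_break = query(rho_samples, target={"jenny": 1}, evidence={"window": 1})
p_brian_given_break = query(rho_samples, target={"brian": 1}, evidence={"window": 1})
```

The same `rho_samples` are reused for every query; only the cheap enumeration in stage 2 runs again.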

That seems pretty complicated. Perhaps I should do something with the Predictive class? I feel like I’m missing something obvious, sorry.

Hi @osazuwa! I assume DGP means “data generating process”? I guess each query like P(Jenny threw ball | window is broken) involves different evidence and this requires inference. From Pyro’s perspective there’s no difference between data and evidence, so just as you’ll need fresh inference when data changes, you’ll need fresh inference when evidence changes.

I guess there are a few tactics to make query-time inference cheap. One tactic is to use importance sampling, as you suggest. Another tactic is to either enumerate over evidence (essentially precomputing posteriors for all possible evidence) or to amortize. Amortization is probably the most scalable approach if your evidence space is large: you could train a VAE on (data, evidence) pairs where data is fixed and evidence is sampled from a distribution over the kinds of queries you expect (e.g. a uniform distribution).
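Enumerating over evidence is cheap in this example because a new row has only three binary variables: each can be unobserved, 0, or 1, giving 27 evidence patterns. A sketch in plain Python that precomputes P(variable = 1 | evidence) for every possible pattern, assuming Beta(1, 1) priors and fully observed training throws with hypothetical counts (80 of 100 and 20 of 100). For this particular model, the posterior predictive of one new row equals the model evaluated at the posterior means of ρ_a and ρ_b, because the row's probability is linear in each of them and they are independent a posteriori. `predictive_joint` and `precompute_answers` are hypothetical names.

```python
from itertools import product

NAMES = ("jenny_throws_ball", "brian_throws_ball", "window_breaks")
IMPACT = {0: 0.0, 1: 0.4, 2: 0.8}  # P(window breaks | number of throwers), from the model

def predictive_joint(rho_a, rho_b):
    """p(new row | data) for all 8 outcomes, given the posterior means of ρ_a and ρ_b."""
    table = {}
    for j, b, w in product((0, 1), repeat=3):
        p_w = IMPACT[j + b]
        table[(j, b, w)] = ((rho_a if j else 1 - rho_a) * (rho_b if b else 1 - rho_b)
                            * (p_w if w else 1 - p_w))
    return table

def precompute_answers(joint):
    """P(variable = 1 | evidence) for each evidence pattern (None means unobserved)."""
    answers = {}
    for pattern in product((None, 0, 1), repeat=3):
        consistent = {x: p for x, p in joint.items()
                      if all(e is None or e == v for e, v in zip(pattern, x))}
        total = sum(consistent.values())
        if total == 0:  # impossible evidence: the window broke with no throwers
            continue
        answers[pattern] = {name: sum(p for x, p in consistent.items() if x[i] == 1) / total
                            for i, name in enumerate(NAMES)}
    return answers

# Posterior means (1 + k) / (2 + n) for the hypothetical counts.
joint = predictive_joint(rho_a=81 / 102, rho_b=21 / 102)
answers = precompute_answers(joint)
# At query time, P(jenny_throws_ball = 1 | window_breaks = 1) is a dictionary lookup.
print(round(answers[(None, None, 1)]["jenny_throws_ball"], 3))  # 0.958
```

The table is built once; a new piece of evidence only changes which key is looked up, which is the trade-off against amortization when the evidence space is small.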