[New to Pyro] Is my implementation correct?

Hi, I am new to Pyro and trying to create this model after going through a few tutorials. This example by @gbernstein helped me get started, thank you!

I have a discrete latent variable, Z ~ Ber(t).
I have a set of variables Xj (j = 1, 2, 3) that depend on Z. Each Xj can take one of 3 values, and the probabilities of P(Xj | Z) have a Dir(a) prior.
X is observed. Let’s say data_x has a shape of [3,100], i.e. 100 observations each for X1, X2, X3.

The code below works:

import torch
import pyro
import pyro.distributions as dist
from pyro.ops.indexing import Vindex
from torch.distributions import constraints
from tqdm import tqdm

def model():
    p_Z = pyro.sample('p_Z', dist.Beta(1, 1))
    # repeat 3 times for 3 Xj
    p_X = pyro.sample('p_X', dist.Dirichlet(torch.ones(2,3).repeat(3,1,1)).to_event(2))
    with pyro.plate('data_plate', 100, 10) as i:
        obs = data_x[:,i]
        Z = pyro.sample('Z', dist.Bernoulli(p_Z), infer={"enumerate": "parallel"})
        for s_ind in pyro.plate('source_plate', 3):
            X = pyro.sample('X_{}'.format(s_ind), dist.Categorical(Vindex(p_X)[s_ind,Z.type(torch.long)]), obs=obs[s_ind])

def guide():
    t = pyro.param('t', 5*torch.ones(2), constraint=constraints.positive)
    pyro.sample('p_Z', dist.Beta(t[0], t[1]))
    x = pyro.param('x', torch.ones(2,3).repeat(3,1,1), constraint=constraints.positive)
    pyro.sample('p_X', dist.Dirichlet(x).to_event(2))

loss_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=2)
svi = pyro.infer.SVI(model,
                     guide,
                     pyro.optim.Adam({'lr': .01}),
                     loss=loss_func)
losses = []
for _ in tqdm(range(num_steps)):
    loss = svi.step()
    losses.append(loss)

  • Is this implementation right? Is there a more Pyronic way of doing this?
  • I have a sequential plate now. How do I vectorize it? I was not able to get the shape of the X tensor to agree with what was expected.
  • I didn’t quite understand why p_X had to use to_event(2) for this to work, and not to_event(1).
  • How do I calculate p(Z|X) for a new observation?

Hi @sunil.su,

Your model and mean field guide look good, and I think we can speed it up by moving some plates around and vectorizing the inner plate. For reference see the tensor shapes tutorial.

First I think you could convert one of the .to_event()s to a pyro.plate. Since your plates overlap, let’s create both at the start and then reuse them later. Also the infer={"enumerate": "parallel"} is implied by the @pyro.infer.config_enumerate decorator, so we can omit it:

@pyro.infer.config_enumerate
def model():
    data_plate = pyro.plate("data_plate", 100, 10, dim=-1)
    j_plate = pyro.plate("j", 3, dim=-2)

    p_Z = pyro.sample('p_Z', dist.Beta(1, 1))
    with j_plate:
        # We need to_event to treat the 2 as an event size.
        p_X = pyro.sample('p_X', dist.Dirichlet(torch.ones(2, 3)).to_event(1))
    with data_plate as i:
        Z = pyro.sample('Z', dist.Bernoulli(p_Z))
        with j_plate:
            p_X_given_Z = Vindex(p_X)[..., Z.long(), :]
            pyro.sample('X', dist.Categorical(p_X_given_Z),
                        obs=data_x[:, i])

Now since you’re using a mean field guide (with no dependencies between variables), you could in theory replace it with, say, an AutoNormal guide. But to update your particular guide we could introduce a plate:

def guide():
    t = pyro.param('t', 5*torch.ones(2), constraint=constraints.positive)
    pyro.sample('p_Z', dist.Beta(t[0], t[1]))
    # Note below we need a spacer dim of 1 since the data dim (of size 100) is right
    # of the j dim. It's a little more common to keep the data dim on the left.
    x = pyro.param('x', torch.ones(3,1,2,3), constraint=constraints.positive)
    with pyro.plate("j", 3, dim=-2):
        pyro.sample('p_X', dist.Dirichlet(x).to_event(1))

I haven’t run this model, but let me know if it works :smile: One thing you might try is to see if an AutoLowRankMultivariateNormal guide improves on your custom guide, since it can model correlations among variables.

@fritzo Thanks a ton for the feedback and suggestions! Your vectorized solution works fine.
I had the data dim to the right because that was the only way I could get my code to work :slight_smile: I need to revisit the tensor shape tutorial and get a better intuition.
I do need to model the correlation among variables, so I will look for examples that use AutoLowRankMultivariateNormal.

@fritzo how do I map the AutoNormal (and AutoLowRankMultivariateNormal) parameters back to probability space for p_X? Thanks!

@sunil.su the AutoNormal parameters should have clear param names to help you map them back to the probability space for p_X. The AutoLowRankMultivariateNormal has more obfuscated param reshaping. However, both guides provide the methods .median() and .quantiles() to help consume the results, and both return a dict of samples when you call them via guide().
