[New to Pyro] Is my implementation correct?

Hi, I am new to Pyro and trying to create this model after going through a few tutorials. This example by @gbernstein helped me get started, thank you!

I have a discrete latent variable, `Z ~ Ber(t)`.
I have a set of variables `Xj (j = 1, 2, 3)` that depend on Z, with `P(Xj | Z) ~ Dir(a)`, and each `Xj` can take one of 3 values.
X is observed; let’s say `data_x` has a shape of `[3,100]`: 100 observations each for X1, X2, X3.
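To make the setup concrete, here is a rough sketch of the generative process I have in mind (the values of `t` and the concentrations `a` are just placeholders):

```python
import torch

torch.manual_seed(0)
t = torch.tensor(0.4)    # P(Z = 1); placeholder value
a = torch.ones(3, 2, 3)  # Dirichlet concentrations, one (2, 3) table per j

Z = torch.bernoulli(t.expand(100)).long()              # [100]
p_X = torch.distributions.Dirichlet(a).sample()        # [3, 2, 3]
# data_x[j, n] ~ Categorical(p_X[j, Z[n]])
data_x = torch.distributions.Categorical(p_X[:, Z]).sample()  # [3, 100]
```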

The code below runs:

```python
import torch
import torch.distributions.constraints as constraints
from tqdm import tqdm

import pyro
import pyro.distributions as dist
from pyro.ops.indexing import Vindex


@pyro.infer.config_enumerate
def model():
    p_Z = pyro.sample('p_Z', dist.Beta(1, 1))
    # repeat 3 times for 3 Xj
    p_X = pyro.sample('p_X', dist.Dirichlet(torch.ones(2, 3).repeat(3, 1, 1)).to_event(2))

    with pyro.plate('data_plate', 100, 10) as i:
        obs = data_x[:, i]
        Z = pyro.sample('Z', dist.Bernoulli(p_Z), infer={"enumerate": "parallel"})
        for s_ind in pyro.plate('source_plate', 3):
            X = pyro.sample('X_{}'.format(s_ind),
                            dist.Categorical(Vindex(p_X)[s_ind, Z.type(torch.long)]),
                            obs=obs[s_ind])


def guide():
    t = pyro.param('t', 5 * torch.ones(2), constraint=constraints.positive)
    pyro.sample('p_Z', dist.Beta(t[0], t[1]))
    # Dirichlet concentrations must stay positive.
    x = pyro.param('x', torch.ones(2, 3).repeat(3, 1, 1), constraint=constraints.positive)
    pyro.sample('p_X', dist.Dirichlet(x).to_event(2))


loss_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=2)
svi = pyro.infer.SVI(model,
                     guide,
                     optim=pyro.optim.Adam({"lr": 0.01}),
                     loss=loss_func)
losses = []
for _ in tqdm(range(num_steps)):
    loss = svi.step()
    losses.append(loss)
```
• Is this implementation right? Is there a more Pyronic way of doing this?
• I have a sequential plate now. How do I vectorize it? I was not able to get the X tensor shape to agree with what was expected.
• I didn’t quite understand why `p_X` had to be `to_event(2)` for this to work and not `to_event(1)`.
• How do I calculate `p(Z|X)` for a new observation?

Hi @sunil.su,

Your model and mean field guide look good, and I think we can speed it up by moving some plates around and vectorizing the inner plate. For reference see the tensor shapes tutorial.

First I think you could convert one of the `.to_event()`s to a `pyro.plate`. Since your plates overlap, let’s create both at the start and then reuse them later. Also the `infer={"enumerate": "parallel"}` is implied by the `@pyro.infer.config_enumerate`, so we can omit it.

```python
@pyro.infer.config_enumerate
def model():
    data_plate = pyro.plate("data_plate", 100, 10, dim=-1)
    j_plate = pyro.plate("j", 3, dim=-2)

    p_Z = pyro.sample('p_Z', dist.Beta(1, 1))
    with j_plate:
        # We need to_event to treat the 2 as an event size.
        p_X = pyro.sample('p_X', dist.Dirichlet(torch.ones(2, 3)).to_event(1))
    with data_plate as i:
        Z = pyro.sample('Z', dist.Bernoulli(p_Z))
        with j_plate:
            p_X_given_Z = Vindex(p_X)[..., Z.long(), :]
            pyro.sample('X', dist.Categorical(p_X_given_Z),
                        obs=data_x[:, i])
```
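To expand on the `to_event` comment above: `.to_event(k)` absorbs the rightmost k batch dims into the event shape, and it corresponds to `torch.distributions.Independent`. A quick pure-torch shape check:

```python
import torch
from torch.distributions import Dirichlet, Independent

base = Dirichlet(torch.ones(3, 2, 3))  # batch_shape [3, 2], event_shape [3]

# Equivalent of .to_event(2) in your original model: with no plate over j,
# both the 3 and the 2 must become event dims.
d2 = Independent(base, 2)  # batch [], event [3, 2, 3]

# Equivalent of .to_event(1) in the vectorized model: j_plate handles the 3,
# so only the 2 needs to be an event dim.
d1 = Independent(Dirichlet(torch.ones(2, 3)), 1)  # batch [], event [2, 3]
```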

Now since you’re using a mean field guide (with no dependencies between variables), you could in theory replace it with, say, an AutoNormal guide; but to update your particular guide we can introduce a plate:

```python
def guide():
    t = pyro.param('t', 5 * torch.ones(2), constraint=constraints.positive)
    pyro.sample('p_Z', dist.Beta(t[0], t[1]))
    # Note below we need a spacer dim of 1 since the data dim (of size 100) is right
    # of the j dim. It's a little more common to keep the data dim on the left.
    x = pyro.param('x', torch.ones(3, 1, 2, 3), constraint=constraints.positive)
    with pyro.plate("j", 3, dim=-2):
        pyro.sample('p_X', dist.Dirichlet(x).to_event(1))
```

I haven’t run this model, but let me know if it works. One thing you might try is an AutoLowRankMultivariateNormal guide, to see if it improves on your custom guide, since it can model correlations among variables.
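As for computing `p(Z|X)` for a new observation: the general tool is `pyro.infer.infer_discrete`, but with point estimates of `p_Z` and `p_X` (e.g. from your learned `t` and `x` params, or an autoguide's `.median()`) you can simply enumerate the two values of Z by Bayes' rule. A sketch with made-up numbers:

```python
import torch

# Made-up point estimates, for illustration only:
p_Z = torch.tensor(0.3)  # P(Z = 1)
p_X = torch.tensor([     # p_X[j, z, k] = P(Xj = k | Z = z)
    [[0.20, 0.30, 0.50], [0.60, 0.30, 0.10]],
    [[0.10, 0.80, 0.10], [0.30, 0.40, 0.30]],
    [[0.25, 0.25, 0.50], [0.70, 0.20, 0.10]],
])

def posterior_Z(x_new):
    """p(Z | X1..X3 = x_new), enumerating Z in {0, 1}."""
    prior = torch.stack([1 - p_Z, p_Z])           # [2]
    # lik[z] = prod_j P(Xj = x_new[j] | Z = z)
    lik = p_X[torch.arange(3), :, x_new].prod(0)  # [2]
    unnorm = prior * lik
    return unnorm / unnorm.sum()

post = posterior_Z(torch.tensor([0, 1, 2]))
```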

@fritzo Thanks a ton for the feedback and suggestions! Your vectorized solution works fine.
I had the data dim to the right because that was the only way I could get my code to work. I need to revisit the tensor shapes tutorial and get a better intuition.
I do need to model the correlation among variables, so I will look for examples using AutoLowRankMultivariateNormal.

@fritzo how do I map the AutoNormal (and AutoLowRankMultivariateNormal) parameters back to probability space for `p_X`? Thanks!

@sunil.su the `AutoNormal` parameters should have clear param names to help you map them back to the probability space for `p_X`. The `AutoLowRankMultivariateNormal` has more obfuscated param reshaping. However both guides provide methods `.median()` and `.quantiles()` to help consume the results, and both return a dict of samples when you call them via `guide()`.
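To illustrate what that mapping does: the auto guides keep loc/scale params in unconstrained space and transform samples onto each site's support; for a simplex site like a Dirichlet's value this is a stick-breaking bijection. A minimal pure-torch sketch of that transform (whether this matches the guide's internal param layout exactly is an assumption; `.median()` already returns constrained values):

```python
import torch
from torch.distributions import biject_to, constraints

# A K-simplex needs K - 1 unconstrained numbers; here K = 3.
unconstrained = torch.randn(2)              # e.g. a guide's loc for one row of p_X
transform = biject_to(constraints.simplex)  # stick-breaking bijection
probs = transform(unconstrained)            # shape [3], on the simplex
```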
