Hi everyone,
I apologize if these are trivial questions, but so far Pyro is an inscrutable black box to me, and I have serious trouble debugging usefully whenever I run into a problem.
I am currently trying to get a supposedly simple Gaussian mixture model with Bernoulli observations to work, using enumeration over the discrete assignment variable. I have three versions of the model, all of which I feel should work, but only one of them does. The other two throw two different exceptions. I’d greatly appreciate any help in understanding why the latter two versions fail and how to fix them!
IMPORTS; TOY DATA
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDiagonalNormal
X_train = torch.randn(20, 2)
y_train = torch.bernoulli(torch.ones(20)*0.5)
X_test = torch.randn(10, 2)
y_test = torch.bernoulli(torch.ones(10)*0.5)
VERSION 1, WORKS
Not that it should be very relevant to why the models below fail, but the general idea is to sample (x-dependent) probability logits from a GMM, and then to draw observations (y) from a Bernoulli distribution based on those logits. It’s a risk model in the context of binary observations. In this first version, the means of the different GMM components differ, but they all share the same scale. I have adapted this from the GMM tutorial.
@config_enumerate
def model(X, y, num_components=3):
    intercept = pyro.sample("intercept", dist.Normal(0.0, 5.0))
    beta = pyro.sample("beta", dist.Normal(0.0, 5.0))
    risk_logits = intercept + X[:, 0] * beta
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(num_components)))
    scale = pyro.sample("scale", dist.LogNormal(0, num_components))
    with pyro.plate('components', num_components):
        locs = pyro.sample('locs', dist.Normal(0., 0.5))
    with pyro.plate("data", y.shape[0]):
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        risk_noise = pyro.sample('risk_noise', dist.Normal(locs[assignment], scale))
        observation = pyro.sample("obs", dist.Bernoulli(logits=risk_logits + risk_noise), obs=y)

guide = AutoDiagonalNormal(poutine.block(model, hide=['assignment']))
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide, X_train, y_train);  # runs through
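For reference, here is a plain-torch sanity check of the shapes in the generative process I intend for version 1 (no Pyro, no inference; all variable names here are ad hoc):

```python
import torch

N, K = 20, 3
X = torch.randn(N, 2)

intercept = torch.randn(())
beta = torch.randn(())
risk_logits = intercept + X[:, 0] * beta  # shape (N,)

weights = torch.distributions.Dirichlet(0.5 * torch.ones(K)).sample()  # (K,)
locs = torch.distributions.Normal(0.0, 0.5).sample((K,))               # one loc per component
scale = torch.distributions.LogNormal(0.0, float(K)).sample()          # shared scalar scale

assignment = torch.distributions.Categorical(weights).sample((N,))     # (N,) component indices
risk_noise = torch.distributions.Normal(locs[assignment], scale).sample()  # (N,)

y = torch.bernoulli(torch.sigmoid(risk_logits + risk_noise))  # binary observations, shape (N,)
assert y.shape == (N,)
```

Everything broadcasts cleanly here, which is why I expected the Pyro versions below to be straightforward as well.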
VERSION 2, FAILS
This is exactly the same as above, except that the scales are now component-specific as well, in addition to the component-specific locations we already had.
@config_enumerate
def model2(X, y, num_components=3):
    intercept = pyro.sample("intercept", dist.Normal(0.0, 5.0))
    beta = pyro.sample("beta", dist.Normal(0.0, 5.0))
    risk_logits = intercept + X[:, 0] * beta
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(num_components)))
    with pyro.plate('components', num_components):
        locs = pyro.sample('locs', dist.Normal(0., 0.5))
        scales = pyro.sample('scales', dist.LogNormal(0., num_components))  # now we have component-wise scales as well
    with pyro.plate("data", y.shape[0]):
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        risk_noise = pyro.sample('risk_noise', dist.Normal(locs[assignment], scales[assignment]))  # component-wise scale
        observation = pyro.sample("obs", dist.Bernoulli(logits=risk_logits + risk_noise), obs=y)

guide2 = AutoDiagonalNormal(poutine.block(model2, hide=['assignment']))
elbo2 = TraceEnum_ELBO(max_plate_nesting=1)
elbo2.loss(model2, guide2, X_train, y_train);  # fails
Fails with RuntimeError: shape '[20]' is invalid for input of size 18
and prints these trace shapes:
Trace Shapes:
Param Sites:
AutoDiagonalNormal.loc 28
AutoDiagonalNormal.scale 28
Sample Sites:
_AutoDiagonalNormal_latent dist | 28
value | 28
components dist |
value 3 |
data dist |
value 20 |
intercept dist |
value |
beta dist |
value |
weights dist | 3
value | 3
locs dist 3 |
value 3 |
scales dist 3 |
value 3 |
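For context on what I think enumeration does to these shapes: under @config_enumerate, Pyro replaces the sampled assignment with a tensor enumerating all component values along an extra dimension to the left of the data plate, which changes what locs[assignment] looks like. A plain-torch sketch of that indexing, assuming Pyro's parallel-enumeration convention:

```python
import torch

K, N = 3, 20
locs = torch.randn(K)

# Non-enumerated case: one sampled component index per data point.
assignment = torch.randint(0, K, (N,))
assert locs[assignment].shape == (N,)

# Parallel enumeration: all K component values are substituted at once,
# broadcast along a new dimension left of the data plate (size 1 there).
enum_assignment = torch.arange(K).reshape(K, 1)  # shape (K, 1)
assert locs[enum_assignment].shape == (K, 1)

# A distribution built from this then broadcasts against data-shaped
# tensors, e.g. (K, 1) noise plus (N,) logits gives a (K, N) result.
noise = torch.randn(K, 1)
logits = torch.randn(N)
assert (noise + logits).shape == (K, N)
```

So I would expect both locs[assignment] and scales[assignment] to end up with the same enumerated shape, which makes it all the more confusing to me that only version 2 fails.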
VERSION 3; FAILS DIFFERENTLY
This is exactly the same as version 1 above, except that the observations are now attached directly to the Normal distribution instead of to the additional Bernoulli distribution we had in version 1. (This doesn’t make much sense for binary observations; it’s just for debugging purposes.)
@config_enumerate
def model3(X, y, num_components=3):
    intercept = pyro.sample("intercept", dist.Normal(0.0, 5.0))
    beta = pyro.sample("beta", dist.Normal(0.0, 5.0))
    risk_logits = intercept + X[:, 0] * beta
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(num_components)))
    scale = pyro.sample("scale", dist.LogNormal(0, num_components))
    with pyro.plate('components', num_components):
        locs = pyro.sample('locs', dist.Normal(0., 0.5))
    with pyro.plate("data", y.shape[0]):
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        risk_noise = pyro.sample('risk_noise', dist.Normal(locs[assignment], scale), obs=y)  # observations now here
        # no Bernoulli observations anymore

guide3 = AutoDiagonalNormal(poutine.block(model3, hide=['assignment']))
elbo3 = TraceEnum_ELBO(max_plate_nesting=1)
elbo3.loss(model3, guide3, X_train, y_train);  # fails
Fails with a rather uninformative AssertionError related to the autoguide:
File ~\Anaconda3\envs\dst\lib\site-packages\pyro\infer\autoguide\guides.py:745, in AutoContinuous._unpack_latent(self, latent)
743 pos += size
744 if not torch._C._get_tracing_state():
--> 745 assert pos == latent.size(-1)
I would greatly appreciate any help in understanding what is going on here, and also any general strategies that could help me debug this type of issue myself in the future.