Basic questions / error messages / variable enumeration

Hi everyone,

I apologize if these are trivial questions, but so far Pyro has been an inscrutable black box to me, and I am having serious trouble debugging usefully whenever I run into a problem.

I am currently trying to get a supposedly simple Gaussian mixture model with Bernoulli observations to work, using enumeration of the discrete mixture assignments. I have three versions of the model, all of which I believe should work, but only one of them does; the other two throw two different exceptions. I’d greatly appreciate any help in understanding why the latter two versions fail and how to fix them!

IMPORTS AND TOY DATA

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDiagonalNormal
X_train = torch.randn(20, 2)
y_train = torch.bernoulli(torch.ones(20)*0.5)
X_test = torch.randn(10, 2)
y_test = torch.bernoulli(torch.ones(10)*0.5)

VERSION 1, WORKS
Probably not relevant to why the later versions fail, but the general idea is to sample (x-dependent) probability logits from a GMM and then draw binary observations y from a Bernoulli distribution based on those logits; it’s a risk model for binary outcomes. In this first version, the GMM components have different means but all share a single scale. I adapted this from the GMM tutorial. (A minimal SVI sketch follows the snippet below.)

@config_enumerate
def model(X, y, num_components=3):
    
    intercept = pyro.sample("intercept", dist.Normal(0.0, 5.0))
    beta = pyro.sample("beta", dist.Normal(0.0, 5.0))
    risk_logits = intercept + X[:, 0] * beta
    
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(num_components)))
    scale = pyro.sample("scale", dist.LogNormal(0, num_components))
    with pyro.plate('components', num_components):
        locs = pyro.sample('locs', dist.Normal(0., 0.5))
        
    with pyro.plate("data", y.shape[0]):
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        risk_noise = pyro.sample('risk_noise', dist.Normal(locs[assignment], scale))
        observation = pyro.sample("obs", dist.Bernoulli(logits=risk_logits + risk_noise), obs=y)
        
# hide the enumerated discrete site from the autoguide: AutoDiagonalNormal only handles continuous latents
guide = AutoDiagonalNormal(poutine.block(model, hide=['assignment']))

elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide, X_train, y_train);   # runs through
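
For reference, a minimal SVI training loop for this model/guide pair might look like the sketch below (the optimizer settings are arbitrary placeholders):

from pyro.infer import SVI
from pyro.optim import Adam

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=TraceEnum_ELBO(max_plate_nesting=1))
for step in range(100):
    loss = svi.step(X_train, y_train)  # one gradient step on the ELBO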

VERSION 2, FAILS
This is exactly the same as version 1, except that the scales are now component-specific as well, in addition to the component-specific locations we already had.

@config_enumerate
def model2(X, y, num_components=3):
    
    intercept = pyro.sample("intercept", dist.Normal(0.0, 5.0))
    beta = pyro.sample("beta", dist.Normal(0.0, 5.0))
    risk_logits = intercept + X[:, 0] * beta
    
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(num_components)))
    with pyro.plate('components', num_components):
        locs = pyro.sample('locs', dist.Normal(0., 0.5))
        scales = pyro.sample('scales', dist.LogNormal(0., num_components)) # now we have component-wise scales as well
        
    with pyro.plate("data", y.shape[0]):
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        risk_noise = pyro.sample('risk_noise', dist.Normal(locs[assignment], scales[assignment])) # component-wise scale
        observation = pyro.sample("obs", dist.Bernoulli(logits=risk_logits + risk_noise), obs=y)
        
guide2 = AutoDiagonalNormal(poutine.block(model2, hide=['assignment']))

elbo2 = TraceEnum_ELBO(max_plate_nesting=1)
elbo2.loss(model2, guide2, X_train, y_train); # fails

This fails with RuntimeError: shape '[20]' is invalid for input of size 18 and prints the following trace shapes:

                  Trace Shapes:         
                   Param Sites:         
         AutoDiagonalNormal.loc    28   
       AutoDiagonalNormal.scale    28   
                  Sample Sites:         
_AutoDiagonalNormal_latent dist     | 28
                          value     | 28
                components dist     |   
                          value  3  |   
                      data dist     |   
                          value 20  |   
                 intercept dist     |   
                          value     |   
                      beta dist     |   
                          value     |   
                   weights dist     |  3
                          value     |  3
                      locs dist  3  |   
                          value  3  |   
                    scales dist  3  |   
                          value  3  |   

VERSION 3, FAILS DIFFERENTLY
This is exactly the same as version 1, except that the observations are now attached directly to the Normal distribution instead of to the downstream Bernoulli distribution we had in version 1. (This doesn’t make much sense given the binary observations; it’s purely for debugging purposes.)

@config_enumerate
def model3(X, y, num_components=3):
    
    intercept = pyro.sample("intercept", dist.Normal(0.0, 5.0))
    beta = pyro.sample("beta", dist.Normal(0.0, 5.0))
    risk_logits = intercept + X[:, 0] * beta
    
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(num_components)))
    scale = pyro.sample("scale", dist.LogNormal(0, num_components))
    with pyro.plate('components', num_components):
        locs = pyro.sample('locs', dist.Normal(0., 0.5))
        
    with pyro.plate("data", y.shape[0]):
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        risk_noise = pyro.sample('risk_noise', dist.Normal(locs[assignment], scale), obs=y) # observations now here
        # no Bernoulli observations anymore
        
guide3 = AutoDiagonalNormal(poutine.block(model3, hide=['assignment']))

elbo3 = TraceEnum_ELBO(max_plate_nesting=1)
elbo3.loss(model3, guide3, X_train, y_train);   # fails

This version fails with a not very informative AssertionError raised inside the autoguide:

File ~\Anaconda3\envs\dst\lib\site-packages\pyro\infer\autoguide\guides.py:745, in AutoContinuous._unpack_latent(self, latent)
    743     pos += size
    744 if not torch._C._get_tracing_state():
--> 745     assert pos == latent.size(-1)

I would greatly appreciate any help in understanding what is going on here, as well as pointers to general strategies that would help me debug this type of issue myself in the future.

Can you try splitting the models into separate files, or just adding pyro.clear_param_store() before you evaluate each new ELBO loss:

pyro.clear_param_store()
elbo2.loss(model2, guide2, X_train, y_train)

Pyro parameters are stored in a global parameter store, which needs to be cleared if you want to run inference on another model in the same script. Here, guide2 silently reuses the 28-dimensional AutoDiagonalNormal.loc and AutoDiagonalNormal.scale parameters created for the first guide (that’s the 28 in your trace shapes), even though model2’s latent space has a different size, hence the reshape error; the AssertionError in version 3 most likely has the same cause.
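
More generally, inspecting the store is a quick way to spot stale parameters, and you can print a model’s trace shapes without going through an ELBO. A minimal sketch (the parameter names you see depend on the guide):

# stale AutoDiagonalNormal.loc / .scale entries left over from an
# earlier guide are a telltale sign of this problem:
print(list(pyro.get_param_store().keys()))

# start fresh before touching the next model:
pyro.clear_param_store()
elbo3.loss(model3, guide3, X_train, y_train)

# for shape issues in general, print a model's trace shapes directly:
trace = poutine.trace(model3).get_trace(X_train, y_train)
print(trace.format_shapes())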


Ha, now I feel incredibly stupid. Thanks!! That fixed everything. :slight_smile: