infer_discrete with multiple discrete sites and enumeration

I’m having trouble with a hierarchical mixture model whose posterior inference is performed with an autoguide and naive enumeration (I’m not confident that I’m enumerating properly). Below is a toy example that should be reproducible and raises the same error.

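import torch
import pyro
from pyro import poutine
from pyro.distributions import Categorical, Normal, constraints
from pyro.infer import SVI, infer_discrete
from pyro.optim import ClippedAdam
from tqdm import tqdm

# N (number of observations, e.g. N = len(data)) is a global; it also normalizes the loss below
def toy_model(data=None, n1=3, n2=5):
    # mixture weights, constrained to the simplex
    p1 = pyro.param("p1", torch.randn(n1).exp(), constraint=constraints.simplex)
    p2 = pyro.param("p2", torch.randn(n2).exp(), constraint=constraints.simplex)

    with pyro.plate("level1", n1):
        mu1 = pyro.sample("mu1", Normal(0, 1))

    # each mu2 is centered on the mu1 selected by z1
    with pyro.plate("level2", n2):
        z1 = pyro.sample("z1", Categorical(p1), infer={'enumerate': 'parallel'})
        mu2 = pyro.sample("mu2", Normal(mu1[z1], 0.1))

    # each observation is centered on the mu2 selected by z2
    with pyro.plate("data", N):
        z2 = pyro.sample("z2", Categorical(p2), infer={'enumerate': 'parallel'})
        pyro.sample("obs", Normal(mu2[z2], 0.01), obs=data)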

model = toy_model

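And the graph illustrates the hierarchical structure: mu1 is drawn once per level1 component, each mu2 is centered on the mu1 chosen by z1, and each observation is centered on the mu2 chosen by z2. I then train it with SVI: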

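optim = ClippedAdam({'lr': 0.1, 'betas': [0.9, 0.999], 'lrd': 0.999})
elbo = pyro.infer.TraceEnum_ELBO(num_particles=4, max_plate_nesting=1)

def train(num_iterations, losses=None):
    losses = [] if losses is None else losses
    for _ in tqdm(range(num_iterations)):
        losses.append(svi.step(data) / N)  # per-datum loss
    return losses

def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()  # otherwise every seed reuses the same stored params
    # point-mass (MAP) guide for the continuous sites; z1 and z2 are enumerated in the model
    global_guide = pyro.infer.autoguide.AutoDelta(
        poutine.block(model, hide=['z1', 'z2']))
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data) / N

# pick the seed with the lowest initial loss, then re-initialize with that seed
loss, seed = min((initialize(seed), seed) for seed in tqdm(range(100)))
initialize(seed)
print('seed = {}, initial_loss = {}'.format(seed, loss))

losses = []
losses = train(200, losses=losses)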

The ELBO loss converges.

Here comes the problem. I directly followed the patterns here to perform MAP inference for the discrete sites:

guide_trace = poutine.trace(global_guide).get_trace(data)  # record the globals
trained_model = poutine.replay(model, trace=guide_trace)  # replay the globals

def classifier(data, temperature=0):  # temperature=0: MAP assignment; temperature=1: sample from the posterior
    inferred_model = infer_discrete(trained_model, temperature=temperature,
                                    first_available_dim=-2)  # avoid conflict with data plate
    trace = poutine.trace(inferred_model).get_trace(data)
    return trace.nodes["z2"]["value"]

I got this error:

ValueError: Error while packing tensors at site 'mu2':
  Invalid tensor shape.
  Allowed dims: -1
  Actual shape: (3, 5)
  Try adding shape assertions for your model's sample values and distribution parameters.
Trace Shapes:      
 Param Sites:      
           p1     3
           p2     5
Sample Sites:      
     mu1 dist   3 |
        value   3 |
     log_prob   3 |
      z1 dist   5 |
        value 3 1 |
     log_prob 3 5 |
     mu2 dist 3 5 |
        value   5 |
     log_prob 3 5 |
Trace Shapes:
 Param Sites:
Sample Sites:

infer_discrete and the classifier worked well for models with only one discrete site, and after checking the shapes, I think the error is related to enumeration. How should I deal with this issue?
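For reference, the temperature argument of infer_discrete is documented as 0 for MAP (Viterbi-like) assignments and 1 for sampling from the posterior. With a single discrete site and a table of unnormalized log-probabilities, that reduces to argmax versus sampling in proportion to exp(log_weight). The helper choose_assignment below is a hypothetical pure-Python toy of that one-site case only; it is not Pyro's implementation, which runs forward-filter backward-sample (or its max-product analogue) over all enumerated sites jointly.

```python
import math
import random


def choose_assignment(log_weights, temperature, rng=random):
    """Hypothetical one-site analogue of infer_discrete's temperature argument.

    temperature=0 -> argmax of the log-weights (MAP assignment)
    temperature=1 -> sample an index with probability proportional to exp(log_weight)
    """
    if temperature == 0:
        return max(range(len(log_weights)), key=lambda k: log_weights[k])
    if temperature != 1:
        raise ValueError("only temperature 0 or 1 is supported")
    # subtract the max before exponentiating, for numerical stability
    m = max(log_weights)
    weights = [math.exp(w - m) for w in log_weights]
    return rng.choices(range(len(log_weights)), weights=weights, k=1)[0]


log_w = [-2.0, 0.5, -1.0]
print(choose_assignment(log_w, temperature=0))  # 1
rng = random.Random(0)
print([choose_assignment(log_w, temperature=1, rng=rng) for _ in range(5)])
```

With temperature=0 the answer is deterministic, which is why repeated calls to the classifier should agree; with temperature=1 each call draws a fresh assignment.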

Any tips on this kind of indexing? I seem to get into trouble with enumeration quite often :sob:
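On where the extra dims come from: with max_plate_nesting=1, enumeration starts at first_available_dim = -1 - max_plate_nesting = -2 (the value passed to infer_discrete above), and each further parallel-enumerated site takes the next dim to the left, with singleton dims to its right. The hypothetical helpers below reproduce only that shape bookkeeping in pure Python, assuming no pyro.markov dim recycling. They match the "z1 value 3 1" row in the trace, and z2 would then get shape (5, 1, 1).

```python
def enum_value_shape(cardinality, dim):
    """Shape of an enumerated site's value when its support lives at negative `dim`."""
    assert dim < 0
    # the support runs along `dim`; every dim to its right is a singleton
    return (cardinality,) + (1,) * (-dim - 1)


def allocate_enum_dims(cardinalities, max_plate_nesting):
    """Hypothetical: assign enum dims leftward from -1 - max_plate_nesting (no dim reuse)."""
    first_available_dim = -1 - max_plate_nesting
    shapes = {}
    for i, (name, k) in enumerate(cardinalities):
        dim = first_available_dim - i
        shapes[name] = (dim, enum_value_shape(k, dim))
    return shapes


print(allocate_enum_dims([("z1", 3), ("z2", 5)], max_plate_nesting=1))
# {'z1': (-2, (3, 1)), 'z2': (-3, (5, 1, 1))}
```

So anything computed from z1 or z2 (such as mu1[z1] or mu2[z2]) carries those extra left-hand dims, and indexing code has to broadcast from the right. When the indexed tensor has several dims, Pyro's enumeration tutorials recommend pyro.ops.indexing.Vindex for this; for the 1-D mu1 here, plain indexing gives the same shape.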

The full traceback shows that the error is raised while packing tensors:

e:\00NSFC-categorization\00_experiment-Categorization\B02_REFRESH.ipynb Cell 21 in classifier(data, temperature)
      5 def classifier(data, temperature=0): #random assignment if temperature=1
      6     inferred_model = infer_discrete(trained_model, temperature=temperature,
      7                                     first_available_dim=-2)  # avoid conflict with data plate
----> 8     trace = poutine.trace(inferred_model).get_trace(data)
      9     return trace.nodes["z"]["value"]

File c:\Users\19046\.conda\envs\pytorch\lib\site-packages\pyro\poutine\, in TraceHandler.get_trace(self, *args, **kwargs)
    190 def get_trace(self, *args, **kwargs):
    191     """
    192     :returns: data structure
    193     :rtype: pyro.poutine.Trace
    196     Calls this poutine and returns its trace instead of the function's return value.
    197     """
--> 198     self(*args, **kwargs)
    199     return self.msngr.get_trace()

File c:\Users\19046\.conda\envs\pytorch\lib\site-packages\pyro\poutine\, in TraceHandler.__call__(self, *args, **kwargs)
    178         exc = exc_type("{}\n{}".format(exc_value, shapes))
    179         exc = exc.with_traceback(traceback)
--> 180         raise exc from e
    181     self.msngr.trace.add_node(
    182         "_RETURN", name="_RETURN", type="return", value=ret
    183     )
    184 return ret

File c:\Users\19046\.conda\envs\pytorch\lib\site-packages\pyro\poutine\, in TraceHandler.__call__(self, *args, **kwargs)
    170 self.msngr.trace.add_node(
    171     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
    172 )
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:
    176     exc_type, exc_value, traceback = sys.exc_info()

File c:\Users\19046\.conda\envs\pytorch\lib\site-packages\pyro\infer\, in _sample_posterior(model, first_available_dim, temperature, strict_enumeration_warning, *args, **kwargs)
     49 enum_trace = prune_subsample_sites(enum_trace)
     50 enum_trace.compute_log_prob()
---> 51 enum_trace.pack_tensors()
     53 return _sample_posterior_from_trace(
     54     model, enum_trace, temperature, strict_enumeration_warning, *args, **kwargs
     55 )

File c:\Users\19046\.conda\envs\pytorch\lib\site-packages\pyro\poutine\, in Trace.pack_tensors(self, plate_to_symbol)
    427 _, exc_value, traceback = sys.exc_info()
    428 shapes = self.format_shapes(last_site=site["name"])
--> 429 raise ValueError(
    430     "Error while packing tensors at site '{}':\n  {}\n{}".format(
    431         site["name"], exc_value, shapes
    432     )
    433 ).with_traceback(traceback) from e

The most confusing part is this: enumerating the z1 site itself seems fine, but the shape of the downstream site mu2 varies with z1’s enumeration (its dist and log_prob are (3, 5) while its value is (5,)), and I think this is the cause. Sadly, I have no idea how to fix it. Could anyone help me, please? I’ve been stuck on this for weeks :broken_heart:
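The "mu2 dist 3 5" row in the trace is ordinary broadcasting. Indexing the 1-D mu1 with z1 of shape (3, 1) gives a location of shape (3, 1), and the level2 plate of size 5 at dim -1 broadcasts that to (3, 5). The replayed mu2 value from the AutoDelta guide, by contrast, is only (5,). Reading the error message, the packer allowed only dim -1 for mu2, so the extra enumeration dim -2 inherited from z1 is what it rejects. broadcast_shape below is a hypothetical pure-Python re-implementation of the NumPy/PyTorch broadcasting rule, used only to reproduce that arithmetic; with PyTorch installed, torch.broadcast_shapes should give the same result.

```python
def broadcast_shape(*shapes):
    """Right-align the shapes and combine them with NumPy/PyTorch broadcasting rules."""
    ndim = max(len(s) for s in shapes)
    result = []
    for i in range(1, ndim + 1):  # walk dims from the right: -1, -2, ...
        size = 1
        for s in shapes:
            d = s[-i] if i <= len(s) else 1  # missing leading dims count as 1
            if d != 1 and size != 1 and d != size:
                raise ValueError("cannot broadcast shapes {}".format(shapes))
            if d != 1:
                size = d
        result.append(size)
    return tuple(reversed(result))


z1_shape = (3, 1)        # z1 enumerated at dim -2
loc_shape = z1_shape     # mu1[z1] on a 1-D mu1 takes the index tensor's shape
level2_plate = (5,)      # plate "level2" at dim -1
print(broadcast_shape(loc_shape, level2_plate))  # (3, 5)
```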

I hit the same bug a couple of years ago :wink:

There is a solution I tried that worked for me. Maybe it will work for you as well :slight_smile:


Oh, I see! :hushed: I’ll try it out. Thanks a lot! :laughing: