NotImplementedError: Cannot transform _IntegerInterval constraints on a very simple HMM model

I just implemented a very simple toy HMM model, but when I run svi.step, I get a NotImplementedError. I checked other posts with the same problem, but their solutions do not seem to apply to mine. Does anyone have a clue what causes this error? Many thanks!

I think I may not be using the guide or SVI correctly.
Update:
I found that if I wrap the model with block(), the problem disappears. But this raises another question about how to use block(). Will the model be learned differently with the two lines below?

guide = AutoDelta(poutine.block(hmm.model, expose_fn=lambda msg: msg["name"].startswith("probs_")))

guide = AutoDelta(poutine.block(hmm.model))

If they make the model behave differently, when should I use expose_fn and when should I not?
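To see why the two lines differ, here is a plain-Python sketch (not Pyro code) that mimics the filtering rule of poutine.block as I understand it: with no arguments, every site is hidden; with expose_fn, a site is hidden exactly when the predicate returns False. The site list is hypothetical and reuses the names from the model below (including its spelling "probs_transtion"). Whatever the wrapper hides, AutoDelta never sees, so it creates no parameter for it.

```python
# Plain-Python mimic of poutine.block's hide/expose rule (an illustration,
# not Pyro's implementation). Each "msg" is a dict holding the site name.
SITES = ["probs_transtion", "probs_token_emission", "z_0", "obs_x_0"]  # hypothetical subset


def visible_sites(sites, expose_fn=None):
    """Return the site names a guide built on the wrapped model would see.

    expose_fn=None mimics block(fn) with no arguments: everything is hidden.
    Otherwise a site stays visible exactly when expose_fn(msg) is True.
    """
    if expose_fn is None:
        return []
    return [name for name in sites if expose_fn({"name": name})]


only_probs = visible_sites(SITES, lambda msg: msg["name"].startswith("probs_"))
nothing = visible_sites(SITES)

print(only_probs)  # ['probs_transtion', 'probs_token_emission']
print(nothing)     # []
```

So, under this reading, the two lines are not equivalent. With expose_fn, AutoDelta fits point estimates for the two Dirichlet sites; with a bare block(), the guide contains no sample sites at all, so every latent site (including the discrete z_t) is left unguided. That is why the error disappears, but it also means nothing is fitted for the Dirichlet sites by the guide.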

Original Code:

import torch
import torch.nn as nn

import pyro
from pyro import poutine
import pyro.distributions as dist
from pyro.contrib.autoguide import AutoDelta
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

nlabels = 3
vocabulary_size = 6

token_mini_batch = torch.IntTensor([[0,1,2,3,4,5],[5,4,3,1,2,0]])

class Transition(nn.Module):
    def __init__(self, nlabels):
        super(Transition, self).__init__()
        self.transition = nn.Parameter(torch.zeros(nlabels, nlabels))
        self.softmax = nn.Softmax(dim=1)

        self.transition_0 = nn.Parameter(torch.zeros(nlabels))
        self.softmax_0 = nn.Softmax(dim=0)

    def forward(self):
        return self.softmax_0(self.transition_0), self.softmax(self.transition)

class TokenEmission(nn.Module):
    def __init__(self, nlabels, vocabulary_size):
        super(TokenEmission, self).__init__()
        self.emission = nn.Parameter(torch.zeros(nlabels, vocabulary_size))

        self.softmax = nn.Softmax(dim=1)

    def forward(self):
        return self.softmax(self.emission)

class HMM(nn.Module):
    def __init__(self, nlabels, vocabulary_size, use_cuda = False):
        super(HMM, self).__init__()
        self.transition = Transition(nlabels)
        self.token_emission = TokenEmission(nlabels, vocabulary_size)

        self.nlabels = nlabels
        self.vocabulary_size = vocabulary_size

        if use_cuda:
            self.cuda()

    def model(self, token_mini_batch):
        T_max = token_mini_batch.size(1)
        pyro.module("hmm", self)

        transition_0, transition = self.transition()  # nlabels * nlabels
        token_emission = self.token_emission()  # nlabels * vocabulary_size

        sampled_transition = pyro.sample("probs_transtion",
                                 dist.Dirichlet(transition).to_event(1))
        sampled_token_emission = pyro.sample("probs_token_emission",
                                     dist.Dirichlet(token_emission).to_event(1))

        z_prev = 0
        with pyro.plate("z_minibatch", len(token_mini_batch), dim=-1):
            for t in range(0, T_max):
                # sample z for current time step
                # z_t = pyro.sample("z_%d"%t,
                #                   dist.Categorical(transition_0.expand(token_mini_batch.size(0),self.nlabels)))

                z_t = pyro.sample("z_%d"%t,
                                  dist.Categorical(sampled_transition[z_prev]))

                pyro.sample("obs_x_%d" % t,
                            dist.Categorical(sampled_token_emission[z_t]),
                            obs=token_mini_batch[:,t])

                z_prev = z_t

hmm = HMM(nlabels, vocabulary_size)

hmm.model(token_mini_batch)

guide = AutoDelta(hmm.model)  # fails in svi.step: AutoDelta also tries to create params for the discrete z_t sites
elbo = TraceEnum_ELBO(max_plate_nesting=1)
optim = Adam({'lr':0.001})
svi = SVI(hmm.model, guide, optim, elbo)

loss = svi.step(token_mini_batch)

Error:

Traceback (most recent call last):
  File "/home/m/anaconda3/lib/python3.7/site-packages/torch/distributions/constraint_registry.py", line 139, in __call__
    factory = self._registry[type(constraint)]
KeyError: <class 'torch.distributions.constraints._IntegerInterval'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
    return fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
    return fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
    return fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/contrib/autoguide/__init__.py", line 287, in __call__
    constraint=site["fn"].support)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/primitives.py", line 46, in param
    return _param(name, *args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 259, in _fn
    apply_stack(msg)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 195, in apply_stack
    default_process_message(msg)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 156, in default_process_message
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 204, in get_param
    return self.setdefault(name, init_tensor, constraint)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 160, in setdefault
    self[name] = init_constrained_value
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 122, in __setitem__
    unconstrained_value = transform_to(constraint).inv(new_constrained_value)
  File "/home/m/anaconda3/lib/python3.7/site-packages/torch/distributions/constraint_registry.py", line 142, in __call__
    'Cannot transform {} constraints'.format(type(constraint).__name__))
NotImplementedError: Cannot transform _IntegerInterval constraints

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/handlers.py", line 466, in _fn
    return ftr(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 153, in __call__
    traceback)
  File "/home/m/anaconda3/lib/python3.7/site-packages/six.py", line 692, in reraise
    raise value.with_traceback(tb)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
    return fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
    return fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
    return fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/contrib/autoguide/__init__.py", line 287, in __call__
    constraint=site["fn"].support)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/primitives.py", line 46, in param
    return _param(name, *args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 259, in _fn
    apply_stack(msg)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 195, in apply_stack
    default_process_message(msg)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 156, in default_process_message
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 204, in get_param
    return self.setdefault(name, init_tensor, constraint)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 160, in setdefault
    self[name] = init_constrained_value
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 122, in __setitem__
    unconstrained_value = transform_to(constraint).inv(new_constrained_value)
  File "/home/m/anaconda3/lib/python3.7/site-packages/torch/distributions/constraint_registry.py", line 142, in __call__
    'Cannot transform {} constraints'.format(type(constraint).__name__))
NotImplementedError: Cannot transform _IntegerInterval constraints
      Trace Shapes:        
       Param Sites:        
     auto_transtion 3 3    
auto_token_emission 3 6    
      Sample Sites:        
   z_minibatch dist   |    
              value 2 |    
     transtion dist   | 3 3
              value   | 3 3
token_emission dist   | 3 6
              value   | 3 6

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/m/GitLab/code/demo_with_toy_data.py", line 87, in <module>
    loss = svi.step(token_mini_batch)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/infer/svi.py", line 99, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py", line 365, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py", line 308, in _get_traces
    yield self._get_trace(model, guide, *args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py", line 262, in _get_trace
    "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/infer/enum.py", line 42, in get_importance_trace
    guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 169, in get_trace
    self(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 153, in __call__
    traceback)
  File "/home/m/anaconda3/lib/python3.7/site-packages/six.py", line 692, in reraise
    raise value.with_traceback(tb)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/handlers.py", line 466, in _fn
    return ftr(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 153, in __call__
    traceback)
  File "/home/m/anaconda3/lib/python3.7/site-packages/six.py", line 692, in reraise
    raise value.with_traceback(tb)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
    return fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
    return fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
    return fn(*args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/contrib/autoguide/__init__.py", line 287, in __call__
    constraint=site["fn"].support)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/primitives.py", line 46, in param
    return _param(name, *args, **kwargs)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 259, in _fn
    apply_stack(msg)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 195, in apply_stack
    default_process_message(msg)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 156, in default_process_message
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 204, in get_param
    return self.setdefault(name, init_tensor, constraint)
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 160, in setdefault
    self[name] = init_constrained_value
  File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 122, in __setitem__
    unconstrained_value = transform_to(constraint).inv(new_constrained_value)
  File "/home/m/anaconda3/lib/python3.7/site-packages/torch/distributions/constraint_registry.py", line 142, in __call__
    'Cannot transform {} constraints'.format(type(constraint).__name__))
NotImplementedError: Cannot transform _IntegerInterval constraints
      Trace Shapes:        
       Param Sites:        
     auto_transtion 3 3    
auto_token_emission 3 6    
      Sample Sites:        
   z_minibatch dist   |    
              value 2 |    
     transtion dist   | 3 3
              value   | 3 3
token_emission dist   | 3 6
              value   | 3 6

Process finished with exit code 1

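That error occurs because your model has Categorical sample sites (z_t), which have discrete support. AutoDelta tries to create a parameter for every latent sample site, constrained to that site's support, and an integer-valued support cannot be transformed to an unconstrained space. I suggest enumerating out the discrete variables in your model.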
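A plain-Python illustration of what "enumerating out" z_t means for this HMM: the discrete states are summed over, so the resulting likelihood depends only on the transition and emission probabilities, and no guide parameter is needed for any z_t. The sketch assumes fixed, hypothetical probability tables and mirrors the model's convention that z_prev starts at state 0. TraceEnum_ELBO performs an equivalent sum with tensor operations; this code is not how Pyro implements it.

```python
import itertools
import math


def brute_force_likelihood(trans, emit, xs, z_init=0):
    """p(xs) by summing over every hidden path z_0..z_{T-1} explicitly.

    Mirrors the model above: z_prev starts at z_init, then
    z_t ~ Categorical(trans[z_prev]) and x_t ~ Categorical(emit[z_t]).
    """
    n = len(trans)
    total = 0.0
    for path in itertools.product(range(n), repeat=len(xs)):
        p, prev = 1.0, z_init
        for z, x in zip(path, xs):
            p *= trans[prev][z] * emit[z][x]
            prev = z
        total += p
    return total


def forward_likelihood(trans, emit, xs, z_init=0):
    """Same quantity via the forward algorithm: O(T * n^2) instead of O(n^T)."""
    n = len(trans)
    # alpha[z] = p(x_0..x_t, z_t = z)
    alpha = [trans[z_init][z] * emit[z][xs[0]] for z in range(n)]
    for x in xs[1:]:
        alpha = [sum(alpha[i] * trans[i][j] for i in range(n)) * emit[j][x]
                 for j in range(n)]
    return sum(alpha)


# Hypothetical 3-state, 6-symbol tables (each row sums to 1).
T = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
E = [[0.5, 0.1, 0.1, 0.1, 0.1, 0.1],
     [0.1, 0.5, 0.1, 0.1, 0.1, 0.1],
     [0.1, 0.1, 0.1, 0.1, 0.1, 0.5]]
seq = [0, 1, 2, 3, 4, 5]
print(math.isclose(brute_force_likelihood(T, E, seq), forward_likelihood(T, E, seq)))  # True
```

Both functions return the same number; once z has been summed out like this, only continuous quantities remain for the guide to handle.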

Thank you very much for your help! But after reading the tutorial, I am still a bit confused.

For example, in the Gaussian mixture model below, I noticed that the sites “obs” and “assignment” are hidden: as shown in the code, poutine.block exposes only “weights”, “locs” and “scale”.

So does that mean the hidden ones are the discrete variables being enumerated out? (I don’t think my understanding is correct, but I cannot find anything else that explains this code or how it enumerates the discrete variables. Also, the variable “obs” is not discrete.)


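import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.contrib.autoguide import AutoDelta
from pyro.infer import config_enumerate

K = 2  # Fixed number of components.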

@config_enumerate
def model(data):
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(0., 10.))

    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

# Guide only the global sites: 'assignment' is enumerated out and 'obs' is observed.
global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scale']))
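For intuition, here is a plain-Python sketch of the quantity that remains once "assignment" is enumerated out of this mixture: for each data point, the K component densities are summed, weighted by "weights". The result depends only on the global sites "weights", "locs" and "scale", which is why those are the only sites exposed to AutoDelta. The data and parameter values are hypothetical, and this is not Pyro's implementation.

```python
import math


def normal_logpdf(x, loc, scale):
    """log N(x | loc, scale) for a scalar x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def logsumexp(vals):
    """Numerically stable log(sum(exp(v) for v in vals))."""
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def marginal_log_likelihood(data, weights, locs, scale):
    """log p(data | weights, locs, scale) with the discrete 'assignment' summed out.

    For each point: log sum_k weights[k] * Normal(x | locs[k], scale).
    """
    return sum(
        logsumexp([math.log(w) + normal_logpdf(x, loc, scale)
                   for w, loc in zip(weights, locs)])
        for x in data
    )


data = [0.1, -0.2, 5.3, 4.8]  # hypothetical observations
print(marginal_log_likelihood(data, weights=[0.5, 0.5], locs=[0.0, 5.0], scale=1.0))
```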

Nope. The config_enumerate decorator, used in conjunction with TraceEnum_ELBO, enumerates out the discrete variables. AutoGuide, however, has no knowledge of which variables you’ve enumerated out, so you use expose or hide to select the variables you want a guide distribution for. In your snippet above, assignment will be enumerated out, so it is hidden from AutoGuide via expose (obs is observed, so it never needs a guide distribution).

Hi, thank you so much for your nice reply. Now I think I understand. So expose or hide is used to indicate which variables I want guide distributions for. That is why, if I implement my own guide, I will not need to use expose or hide.

Thus, in my HMM case, I should put @config_enumerate before the model function to enumerate out the latent states z_t, and expose the transition and emission probabilities if I use an AutoGuide (i.e. instead of implementing my own guide).

I would like to ask one more question.

Is it compulsory to use @config_enumerate here? In other words, is this setting a default setting? Because I did not find it in the hmm tutorial code. I am guessing config_enumerate might be the default?

So expose or hide is used to indicate which variables I want guide distributions for. That is why, if I implement my own guide, I will not need to use expose or hide.

correct

In other words, is this setting a default setting

no, you must add the decorator for each model you want to enumerate.

Because I did not find it in the hmm tutorial code

the tutorial has manual annotations infer={"enumerate": "parallel"} at sample sites that require enumeration. config_enumerate simply does that automatically for all discrete variables in the wrapped model.
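As a rough mental model of the rule described here, the plain-Python sketch below walks a list of hypothetical site records and adds infer={"enumerate": "parallel"} to every unobserved site with a discrete (enumerable) distribution, leaving any "enumerate" setting already present untouched, which is my understanding of how config_enumerate treats existing annotations. It only mimics the rule on dicts; Pyro applies it to real sample sites through its effect handlers, and mimic_config_enumerate is a hypothetical name.

```python
def mimic_config_enumerate(sites, default="parallel"):
    """Return copies of the site records with an 'enumerate' annotation added
    where the config_enumerate rule would add one (toy stand-in, not Pyro code)."""
    out = []
    for site in sites:
        infer = dict(site.get("infer", {}))
        if site["discrete"] and not site["observed"]:
            # Keep a manual annotation if the site already has one.
            infer.setdefault("enumerate", default)
        out.append({**site, "infer": infer})
    return out


hmm_sites = [  # hypothetical records for one time step of the HMM above
    {"name": "probs_transtion", "discrete": False, "observed": False},
    {"name": "z_0", "discrete": True, "observed": False},
    {"name": "obs_x_0", "discrete": True, "observed": True},
]
for site in mimic_config_enumerate(hmm_sites):
    print(site["name"], site["infer"])
# probs_transtion {}
# z_0 {'enumerate': 'parallel'}
# obs_x_0 {}
```

Only z_0 gets annotated: the Dirichlet site is continuous and obs_x_0 is observed, which matches the manual infer={"enumerate": "parallel"} annotations the tutorial places on its latent state sites.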

Hi jpchen, thank you so much! Now it is very clear!