I just implemented a very simple and toy HMM model, but when I run svi.step, there was an error of NotIMplementedError. I checked the other posts that had the same problem, but the solutions to those posts seem did not suitable for my problem. Does anyone have the clue the reason for this error? Many thanks!
I think I may not use the guide or svi correctly.
Update:
I found if I warp the model with block() the problem will be disappeared. But it raises another question about how to use block(). Will the model be learned differently between the two codes below?
guide = AutoDelta(poutine.block(hmm.model, expose_fn=lambda msg: msg["name"].startswith("probs_")))
guide = AutoDelta(poutine.block(hmm.model))
If they make the model behaves differently, when should I use the expose_fn and when should not?
Original Code:
import torch
import torch.nn as nn
import pyro
from pyro import poutine
import pyro.distributions as dist
from pyro.contrib.autoguide import AutoDelta
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam
nlabels = 3
vocabulary_size = 6
token_mini_batch = torch.IntTensor([[0,1,2,3,4,5],[5,4,3,1,2,0]])
class Transition(nn.Module):
def __init__(self, nlabels):
super(Transition, self).__init__()
self.transition = nn.Parameter(torch.zeros(nlabels, nlabels))
self.softmax = nn.Softmax(dim=1)
self.transition_0 = nn.Parameter(torch.zeros(nlabels))
self.softmax_0 = nn.Softmax(dim=0)
def forward(self):
return self.softmax_0(self.transition_0), self.softmax(self.transition)
class TokenEmission(nn.Module):
def __init__(self, nlabels, vocabulary_size):
super(TokenEmission, self).__init__()
self.emission = nn.Parameter(torch.zeros(nlabels, vocabulary_size))
self.softmax = nn.Softmax(dim=1)
def forward(self):
return self.softmax(self.emission)
class HMM(nn.Module):
def __init__(self, nlabels, vocabulary_size, use_cuda = False):
super(HMM, self).__init__()
self.transition = Transition(nlabels)
self.token_emission = TokenEmission(nlabels, vocabulary_size)
self.nlabels = nlabels
self.vocabulary_size = vocabulary_size
if use_cuda:
self.cuda()
def model(self, token_mini_batch):
T_max = token_mini_batch.size(1)
pyro.module("hmm", self)
transition_0, transition = self.transition() # nlabels * nlabels
token_emission = self.token_emission() # nlabels * vocabulary_size
sampled_transition = pyro.sample("probs_transtion",
dist.Dirichlet(transition).to_event(1))
sampled_token_emission = pyro.sample("probs_token_emission",
dist.Dirichlet(token_emission).to_event(1))
z_prev = 0
with pyro.plate("z_minibatch", len(token_mini_batch), dim=-1):
for t in range(0, T_max):
# sample z for current time step
# z_t = pyro.sample("z_%d"%t,
# dist.Categorical(transition_0.expand(token_mini_batch.size(0),self.nlabels)))
z_t = pyro.sample("z_%d"%t,
dist.Categorical(sampled_transition[z_prev]))
pyro.sample("obs_x_%d" % t,
dist.Categorical(sampled_token_emission[z_t]),
obs=token_mini_batch[:,t])
z_prev = z_t
hmm = HMM(nlabels, vocabulary_size)
hmm.model(token_mini_batch)
guide = AutoDelta(hmm.model)
elbo = TraceEnum_ELBO(max_plate_nesting=1)
optim = Adam({'lr':0.001})
svi = SVI(hmm.model, guide, optim, elbo)
loss = svi.step(token_mini_batch)
Error:
Traceback (most recent call last):
File "/home/m/anaconda3/lib/python3.7/site-packages/torch/distributions/constraint_registry.py", line 139, in __call__
factory = self._registry[type(constraint)]
KeyError: <class 'torch.distributions.constraints._IntegerInterval'>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
ret = self.fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
return fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
return fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
return fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/contrib/autoguide/__init__.py", line 287, in __call__
constraint=site["fn"].support)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/primitives.py", line 46, in param
return _param(name, *args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 259, in _fn
apply_stack(msg)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 195, in apply_stack
default_process_message(msg)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 156, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 204, in get_param
return self.setdefault(name, init_tensor, constraint)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 160, in setdefault
self[name] = init_constrained_value
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 122, in __setitem__
unconstrained_value = transform_to(constraint).inv(new_constrained_value)
File "/home/m/anaconda3/lib/python3.7/site-packages/torch/distributions/constraint_registry.py", line 142, in __call__
'Cannot transform {} constraints'.format(type(constraint).__name__))
NotImplementedError: Cannot transform _IntegerInterval constraints
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
ret = self.fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/handlers.py", line 466, in _fn
return ftr(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 153, in __call__
traceback)
File "/home/m/anaconda3/lib/python3.7/site-packages/six.py", line 692, in reraise
raise value.with_traceback(tb)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
ret = self.fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
return fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
return fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
return fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/contrib/autoguide/__init__.py", line 287, in __call__
constraint=site["fn"].support)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/primitives.py", line 46, in param
return _param(name, *args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 259, in _fn
apply_stack(msg)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 195, in apply_stack
default_process_message(msg)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 156, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 204, in get_param
return self.setdefault(name, init_tensor, constraint)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 160, in setdefault
self[name] = init_constrained_value
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 122, in __setitem__
unconstrained_value = transform_to(constraint).inv(new_constrained_value)
File "/home/m/anaconda3/lib/python3.7/site-packages/torch/distributions/constraint_registry.py", line 142, in __call__
'Cannot transform {} constraints'.format(type(constraint).__name__))
NotImplementedError: Cannot transform _IntegerInterval constraints
Trace Shapes:
Param Sites:
auto_transtion 3 3
auto_token_emission 3 6
Sample Sites:
z_minibatch dist |
value 2 |
transtion dist | 3 3
value | 3 3
token_emission dist | 3 6
value | 3 6
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/m/GitLab/code/demo_with_toy_data.py", line 87, in <module>
loss = svi.step(token_mini_batch)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/infer/svi.py", line 99, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py", line 365, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py", line 308, in _get_traces
yield self._get_trace(model, guide, *args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py", line 262, in _get_trace
"flat", self.max_plate_nesting, model, guide, *args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/infer/enum.py", line 42, in get_importance_trace
guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 169, in get_trace
self(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 153, in __call__
traceback)
File "/home/m/anaconda3/lib/python3.7/site-packages/six.py", line 692, in reraise
raise value.with_traceback(tb)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
ret = self.fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/handlers.py", line 466, in _fn
return ftr(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 153, in __call__
traceback)
File "/home/m/anaconda3/lib/python3.7/site-packages/six.py", line 692, in reraise
raise value.with_traceback(tb)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
ret = self.fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
return fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
return fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 27, in _wraps
return fn(*args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/contrib/autoguide/__init__.py", line 287, in __call__
constraint=site["fn"].support)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/primitives.py", line 46, in param
return _param(name, *args, **kwargs)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 259, in _fn
apply_stack(msg)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 195, in apply_stack
default_process_message(msg)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 156, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 204, in get_param
return self.setdefault(name, init_tensor, constraint)
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 160, in setdefault
self[name] = init_constrained_value
File "/home/m/anaconda3/lib/python3.7/site-packages/pyro/params/param_store.py", line 122, in __setitem__
unconstrained_value = transform_to(constraint).inv(new_constrained_value)
File "/home/m/anaconda3/lib/python3.7/site-packages/torch/distributions/constraint_registry.py", line 142, in __call__
'Cannot transform {} constraints'.format(type(constraint).__name__))
NotImplementedError: Cannot transform _IntegerInterval constraints
Trace Shapes:
Param Sites:
auto_transtion 3 3
auto_token_emission 3 6
Sample Sites:
z_minibatch dist |
value 2 |
transtion dist | 3 3
value | 3 3
token_emission dist | 3 6
value | 3 6
Trace Shapes:
Param Sites:
auto_transtion 3 3
auto_token_emission 3 6
Sample Sites:
z_minibatch dist |
value 2 |
transtion dist | 3 3
value | 3 3
token_emission dist | 3 6
value | 3 6
Process finished with exit code 1