Hi!
I continue to build HMM model.
There is a problem with simplest test data.
Code
import numpy as np
from modules.stuff import *
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 100
n_steps_print = 10 ** ( np.int( np.log10( n_steps ) ) - 1 )
assert pyro.__version__.startswith('1.3.0')
pyro.set_rng_seed(101)
pyro.enable_validation(True)
torch.set_default_tensor_type( 'torch.cuda.FloatTensor' )
pyro.clear_param_store()
data_3 = [ [ 0 ] * 235 ] * 2
data_3 = torch.tensor( data_3, dtype=torch.float64 )
args_3 = ( data_3, )
def model_3( data ) :
hmmdim = 2
datadim = 1
x_q = pyro.param( "x_q"
, torch.eye( hmmdim ) * .1 + .9
, constraint=constraints.simplex )
y_q = pyro.param( "y_q"
, torch.ones( [ hmmdim, datadim ] )
, constraint=constraints.simplex )
x = 0
plate_seq = pyro.plate( "seq", len( data[0] ) )
for j in pyro.markov( range( len( data ) ) ) :
x = pyro.sample( "x_{}".format( j )
, dist.Categorical( Vindex( x_q )[..., x, :] )
, infer={ "enumerate": "parallel" } )
with plate_seq:
z = pyro.sample( "obs_{}".format(j)
, dist.Categorical( Vindex( y_q )[..., x, :] )
, obs=data[j] )
model_number = 3
args = globals()["args_" + str( model_number )]
model = globals()["model_" + str( model_number )]
guide = AutoDelta( poutine.block( model, expose_fn=lambda msg: msg["name"].endswith( "_q" ) ) )
adam_params = { "lr": 0.05, "betas": ( 0.90, 0.999 ) }
optimizer = Adam( adam_params )
svi = SVI( model, guide, optimizer, loss=JitTraceEnum_ELBO() )
for step in range( n_steps ) :
loss = svi.step( *args )
if step % n_steps_print == 0:
print(' . %.3f' % loss, end='')
print()
for k, v in sorted( pyro.get_param_store().items() ) :
b = v.cpu().detach().numpy()
print( k )
print( b.shape )
print( b.round( 2 ) )
Trace
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-3-1cf6e6a33278> in <module>
69
70 for step in range( n_steps ) :
---> 71 loss = svi.step( *args )
72 if step % n_steps_print == 0:
73 print(' . %.3f' % loss, end='')
~/src/envpgran/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
126 # get loss and compute gradients
127 with poutine.trace(param_only=True) as param_capture:
--> 128 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
129
130 params = set(site["value"].unconstrained()
~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
497
498 def loss_and_grads(self, model, guide, *args, **kwargs):
--> 499 differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
500 differentiable_loss.backward() # this line triggers jit compilation
501 loss = differentiable_loss.item()
~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in differentiable_loss(self, model, guide, *args, **kwargs)
494 self._differentiable_loss = differentiable_loss
495
--> 496 return self._differentiable_loss(*args, **kwargs)
497
498 def loss_and_grads(self, model, guide, *args, **kwargs):
~/src/envpgran/lib/python3.6/site-packages/pyro/ops/jit.py in __call__(self, *args, **kwargs)
93 time_compilation = self.jit_options.pop("time_compilation", False)
94 with optional(timed(), time_compilation) as t:
---> 95 self.compiled[key] = torch.jit.trace(compiled, params_and_args, **self.jit_options)
96 if time_compilation:
97 self.compile_time = t.elapsed
~/src/envpgran/lib/python3.6/site-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace, _module_class, _compilation_unit)
904 traced = torch._C._create_function_from_trace(name, func, example_inputs,
905 var_lookup_fn,
--> 906 _force_outplace)
907
908 # Check the trace against new traces created from user-specified inputs
~/src/envpgran/lib/python3.6/site-packages/pyro/ops/jit.py in compiled(*params_and_args)
86 assert constrained_param.unconstrained() is unconstrained_param
87 constrained_params[name] = constrained_param
---> 88 return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs)
89
90 if self.ignore_warnings:
~/src/envpgran/lib/python3.6/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in differentiable_loss(*args, **kwargs)
489 elbo = 0.0
490 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
--> 491 elbo = elbo + _compute_dice_elbo(model_trace, guide_trace)
492 return elbo * (-1.0 / self.num_particles)
493
~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in _compute_dice_elbo(model_trace, guide_trace)
149 # Accumulate marginal model costs.
150 marginal_costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
--> 151 model_trace, guide_trace)
152 if log_factors:
153 dim_to_size = {}
~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in _compute_model_factors(model_trace, guide_trace)
142 log_factors.setdefault(t, []).append(logprob)
143 scales.append(site["scale"])
--> 144 scale = _get_common_scale(scales)
145 return marginal_costs, log_factors, ordering, enum_dims, scale
146
/usr/lib/python3.6/contextlib.py in inner(*args, **kwds)
50 def inner(*args, **kwds):
51 with self._recreate_cm():
---> 52 return func(*args, **kwds)
53 return inner
54
~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in _get_common_scale(scales)
38 if len(scales_set) != 1:
39 raise ValueError("Expected all enumerated sample sites to share a common poutine.scale, "
---> 40 "but found {} different scales.".format(len(scales_set)))
41 return scales[0]
42
ValueError: Expected all enumerated sample sites to share a common poutine.scale, but found 2 different scales.
I have no chance to figure out from the trace, neither from docs, what is wrong. Any suggestions, please?
It is “minimal example”.
The changing hmmdim
to 1 fix the error.