HMM with simple data error: Expected all enumerated sample sites to share a common poutine.scale

Hi!
I'm continuing to build my HMM model.
There is a problem even with the simplest test data.

Code
import numpy as np

# from modules.stuff import *   # local helper module, commented out to keep the example self-contained

import os
import torch

import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, JitTraceEnum_ELBO
import pyro.distributions as dist

from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex

smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 100
n_steps_print = 10 ** ( int( np.log10( n_steps ) ) - 1 )

assert pyro.__version__.startswith('1.3.0')

pyro.set_rng_seed(101)

pyro.enable_validation(True)

# use CUDA tensors by default when a GPU is available
if torch.cuda.is_available() :
    torch.set_default_tensor_type( 'torch.cuda.FloatTensor' )

pyro.clear_param_store()

# toy data: a 2 x 235 tensor of zeros
data_3 = [ [ 0 ] * 235 ] * 2
data_3 = torch.tensor( data_3, dtype=torch.float64 )

args_3 = ( data_3, )

def model_3( data ) :
    hmmdim = 2   # number of hidden states
    datadim = 1  # number of observation categories
    # transition probabilities: row x is the distribution over the next state
    x_q = pyro.param( "x_q"
                     , torch.eye( hmmdim ) * .1 + .9
                     , constraint=constraints.simplex )
    # emission probabilities: row x is the distribution over observations
    y_q = pyro.param( "y_q"
                     , torch.ones( [ hmmdim, datadim ] )
                     , constraint=constraints.simplex )
    x = 0
    plate_seq = pyro.plate( "seq", len( data[0] ) )
    for j in pyro.markov( range( len( data ) ) ) :
        # discrete hidden state at step j, enumerated in parallel
        x = pyro.sample(  "x_{}".format( j )
                        , dist.Categorical( Vindex( x_q )[..., x, :] )
                        , infer={ "enumerate": "parallel" } )
        with plate_seq:
            # observations at step j, conditioned on the current hidden state
            pyro.sample( "obs_{}".format(j)
                       , dist.Categorical( Vindex( y_q )[..., x, :] )
                       , obs=data[j] )

model_number = 3

args = globals()["args_" + str( model_number )]
model = globals()["model_" + str( model_number )]

guide = AutoDelta( poutine.block( model, expose_fn=lambda msg: msg["name"].endswith( "_q" ) ) )

adam_params = { "lr": 0.05, "betas": ( 0.90, 0.999 ) }
optimizer = Adam( adam_params )

svi = SVI( model, guide, optimizer, loss=JitTraceEnum_ELBO() )

for step in range( n_steps ) :
    loss = svi.step( *args )
    if step % n_steps_print == 0:
        print(' . %.3f' % loss, end='')
print()

for k, v in sorted( pyro.get_param_store().items() ) :
    b = v.cpu().detach().numpy()
    print( k )
    print( b.shape )
    print( b.round( 2 ) )
Trace
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-3-1cf6e6a33278> in <module>
     69 
     70 for step in range( n_steps ) :
---> 71     loss = svi.step( *args )
     72     if step % n_steps_print == 0:
     73         print(' . %.3f' % loss, end='')

~/src/envpgran/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
    126         # get loss and compute gradients
    127         with poutine.trace(param_only=True) as param_capture:
--> 128             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    129 
    130         params = set(site["value"].unconstrained()

~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    497 
    498     def loss_and_grads(self, model, guide, *args, **kwargs):
--> 499         differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
    500         differentiable_loss.backward()  # this line triggers jit compilation
    501         loss = differentiable_loss.item()

~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in differentiable_loss(self, model, guide, *args, **kwargs)
    494             self._differentiable_loss = differentiable_loss
    495 
--> 496         return self._differentiable_loss(*args, **kwargs)
    497 
    498     def loss_and_grads(self, model, guide, *args, **kwargs):

~/src/envpgran/lib/python3.6/site-packages/pyro/ops/jit.py in __call__(self, *args, **kwargs)
     93                 time_compilation = self.jit_options.pop("time_compilation", False)
     94                 with optional(timed(), time_compilation) as t:
---> 95                     self.compiled[key] = torch.jit.trace(compiled, params_and_args, **self.jit_options)
     96                 if time_compilation:
     97                     self.compile_time = t.elapsed

~/src/envpgran/lib/python3.6/site-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace, _module_class, _compilation_unit)
    904     traced = torch._C._create_function_from_trace(name, func, example_inputs,
    905                                                   var_lookup_fn,
--> 906                                                   _force_outplace)
    907 
    908     # Check the trace against new traces created from user-specified inputs

~/src/envpgran/lib/python3.6/site-packages/pyro/ops/jit.py in compiled(*params_and_args)
     86                     assert constrained_param.unconstrained() is unconstrained_param
     87                     constrained_params[name] = constrained_param
---> 88                 return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs)
     89 
     90             if self.ignore_warnings:

~/src/envpgran/lib/python3.6/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in differentiable_loss(*args, **kwargs)
    489                 elbo = 0.0
    490                 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
--> 491                     elbo = elbo + _compute_dice_elbo(model_trace, guide_trace)
    492                 return elbo * (-1.0 / self.num_particles)
    493 

~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in _compute_dice_elbo(model_trace, guide_trace)
    149     # Accumulate marginal model costs.
    150     marginal_costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
--> 151             model_trace, guide_trace)
    152     if log_factors:
    153         dim_to_size = {}

~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in _compute_model_factors(model_trace, guide_trace)
    142                     log_factors.setdefault(t, []).append(logprob)
    143                     scales.append(site["scale"])
--> 144         scale = _get_common_scale(scales)
    145     return marginal_costs, log_factors, ordering, enum_dims, scale
    146 

/usr/lib/python3.6/contextlib.py in inner(*args, **kwds)
     50         def inner(*args, **kwds):
     51             with self._recreate_cm():
---> 52                 return func(*args, **kwds)
     53         return inner
     54 

~/src/envpgran/lib/python3.6/site-packages/pyro/infer/traceenum_elbo.py in _get_common_scale(scales)
     38     if len(scales_set) != 1:
     39         raise ValueError("Expected all enumerated sample sites to share a common poutine.scale, "
---> 40                          "but found {} different scales.".format(len(scales_set)))
     41     return scales[0]
     42 

ValueError: Expected all enumerated sample sites to share a common poutine.scale, but found 2 different scales.

I can't figure out what is wrong, neither from the trace nor from the docs. Any suggestions, please?

This is a minimal example.
Changing hmmdim to 1 fixes the error, as shown below.
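To be explicit, the only difference from the script above is this one-line diff (not a real fix, of course, since a single hidden state makes the HMM degenerate):

-    hmmdim = 2
+    hmmdim = 1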

Hi, what version of PyTorch are you using? I’m having trouble reproducing your error.

Hi,
here is a list of the relevant installed packages:

numpyro==0.2.3
pyro-api==0.1.1
pyro-ppl==1.3.0
torch==1.4.0
torchfile==0.1.0
torchvision==0.5.0

And you?

I updated to 1.3.1 and switched from JitTraceEnum_ELBO to TraceEnum_ELBO, and the script now works properly.
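For anyone who hits the same error: the fix is a one-line change to the loss (a sketch, assuming the rest of the script above stays unchanged; TraceEnum_ELBO is the non-JIT counterpart of JitTraceEnum_ELBO):

# non-JIT enumeration ELBO, instead of loss=JitTraceEnum_ELBO()
svi = SVI( model, guide, optimizer, loss=TraceEnum_ELBO() )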

But what is the reason?