Adapting an HMM from Bernoulli to categorical emissions

Hi!

I am trying to adapt the HMM example to a categorical distribution. It works well with a Bernoulli distribution, or with a categorical distribution with two states, but not with more.

Here is my code:

import math
import os
import torch

torch.set_default_tensor_type( 'torch.cuda.FloatTensor' )

import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO
import pyro.distributions as dist

from pyro import poutine
from pyro.infer.autoguide import AutoDelta

smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 1000
# Print the loss roughly ten times per run.  ``math`` (already imported) is
# used because this snippet never imports numpy, so the former
# ``np.int(np.log10(...))`` raised NameError; ``np.int`` is also gone in
# NumPy >= 1.24.  ``max(1, ...)`` keeps the interval a positive integer: for
# n_steps < 10 the bare formula yields 0.1, and a float modulus such as
# ``step % 0.1`` is only reliably 0 at step 0.
n_steps_print = max( 1, 10 ** ( int( math.log10( n_steps ) ) - 1 ) )

assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True)

pyro.clear_param_store()

# Three observation patterns; model_3 consumes row j of data_3 as time step j.
A = [ 0,1,2,2,2,2,2,2,2 ]
B = [ 0,1,2,0,1,2,0,1,2 ]
C = [ 0,0,0,0,0,0,0,1,2 ]
data_3 = [ A,C,A,C,A,C,A,C, B,B,B,B,B ]
# Categorical observations are category indices, so store them as integers
# rather than float64.
data_3 = torch.tensor( data_3, dtype=torch.long )

def model_3( data ) :
    """HMM with categorical emissions; each row of ``data`` is one time step.

    At step ``j`` a hidden state ``x_j`` is drawn from the transition row of
    the previous state (the chain is seeded with state 0).  Every entry of
    ``data[j]`` is then an independent categorical draw from the emission
    row of ``x_j``.

    Args:
        data: 2-D tensor of category indices in ``range(datadim)``; all rows
            are assumed to have length ``len(data[0])``.
    """
    # Local import: this snippet's header does not import Vindex.
    from pyro.ops.indexing import Vindex

    hmmdim = 3
    datadim = 3
    # Row k of x_q is the transition distribution out of hidden state k.
    x_q = pyro.sample( "x_q", dist.Dirichlet( torch.ones( [ hmmdim, hmmdim ] ) ).to_event( 1 ) )
    # Row k of y_q is the emission distribution of hidden state k.
    y_q = pyro.sample( "y_q", dist.Dirichlet( torch.ones( [ hmmdim, datadim ] ) ).to_event( 1 ) )
    # NOTE(review): these symmetric priors make all hidden states
    # interchangeable; a symmetric starting point for AutoDelta may leave all
    # rows equal.  Consider an asymmetric initialisation if that happens.
    x = 0  # fixed state whose transition row generates x_0
    plate_seq = pyro.plate( "seq", len( data[0] ) )
    for j in pyro.markov( range( len( data ) ) ) :
        # Vindex indexes correctly even when x is a parallel-enumerated tensor.
        x = pyro.sample(  "x_{}".format( j )
                        , dist.Categorical( Vindex( x_q )[..., x, :] )
                        , infer={ "enumerate": "parallel" } )
        # Vectorised plate: one observation site per time step instead of a
        # Python loop creating one site per element of the row.
        with plate_seq :
            pyro.sample(  "obs_{}".format( j )
                        , dist.Categorical( Vindex( y_q )[..., x, :] )
                        , obs=data[j] )

model_number = 3

# Pick the dataset/model pair whose names end with ``model_number``.
data = globals()["data_" + str( model_number )]
model = globals()["model_" + str( model_number )]


def _exposed( msg ) :
    # Only the "*_q" sites are visible to AutoDelta; everything else is blocked.
    return msg["name"].endswith( "_q" )


guide = AutoDelta( poutine.block( model, expose_fn=_exposed ) )

adam_params = dict( lr=0.05, betas=( 0.90, 0.999 ) )
optimizer = Adam( adam_params )
svi = SVI( model, guide, optimizer, loss=TraceEnum_ELBO() )

# Run n_steps SVI updates, printing the loss every n_steps_print steps.
for step in range( n_steps ) :
    loss = svi.step( data )
    if not step % n_steps_print :
        print( ' . %.3f' % loss, end='' )
print()

# Dump each learned parameter: its name, its shape, then values to 2 dp.
for k, v in sorted( pyro.get_param_store().items() ) :
    b = v.cpu().detach().numpy()
    print( k, b.shape, b.round( 2 ), sep='\n' )

I have absolutely no idea what’s wrong with it.
Any suggestions?

A typical result is:

(3, 3)
AutoDelta.x_q
[[0.33 0.33 0.33]
 [0.33 0.33 0.33]
 [0.33 0.33 0.33]]
(3, 3)
AutoDelta.y_q
[[0.4 0.2 0.4]
 [0.4 0.2 0.4]
 [0.4 0.2 0.4]]

What I expect is something like:

(3, 3)
AutoDelta.x_q
[[0.   1.   0.]
 [1.   0.   0.]
 [0.   0.   1.]]
(3, 3)
AutoDelta.y_q
[[0.1 0.1 0.8]
 [0.8 0.1 0.1]
 [0.3 0.3 0.3]]

Hi @Vanka,

I suspect the issue is with the indexing y_q[x]. You might try using the Vindex helper as described in the enumeration tutorial. Also, ideally you could vectorize that “seq” plate, maybe something like:

# Fragment intended as the body of model_3 after x_q and y_q are sampled;
# assumes x_q, y_q, data and Vindex are already in scope.
x = 0  # fixed state whose transition row generates x_0
plate_seq = pyro.plate("seq", len(data[0]))
for j in pyro.markov(range(len(data))):
    # Vindex(...)[..., x, :] stays correct when x is enumerated in parallel.
    x = pyro.sample("x_{}".format(j),
                    dist.Categorical(Vindex(x_q)[..., x, :]),
                    infer={"enumerate": "parallel"})
    # Vectorised over the entries of row j instead of a Python loop.
    with plate_seq:
        pyro.sample("obs_{}".format(j),
                    dist.Categorical(Vindex(y_q)[..., x, :]),
                    obs=data[j])

Let me know if you get something like that working!

Hi. I might be wrong, but from your data it seems that datadim is 9 (the length of A, B, and C), that there are three hidden states (A, B, C), and that the batch size is 1 (one sequence).

import math
import os
import torch
import numpy as np

torch.set_default_tensor_type( 'torch.cuda.FloatTensor' )

import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO
import pyro.distributions as dist

from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex

smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 500
# Print the loss roughly ten times per run.  Builtin ``int`` replaces
# ``np.int``, which was removed in NumPy >= 1.24.  ``max(1, ...)`` keeps the
# interval a positive integer: for n_steps < 10 the bare formula yields 0.1,
# and a float modulus such as ``step % 0.1`` is only reliably 0 at step 0.
n_steps_print = max( 1, 10 ** ( int( np.log10( n_steps ) ) - 1 ) )

assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True)

pyro.clear_param_store()

# Three observation patterns; model_3 consumes row j of data_3 as time step j.
A = [ 0,1,2,2,2,2,2,2,2 ]
B = [ 0,1,2,0,1,2,0,1,2 ]
C = [ 0,0,0,0,0,0,0,1,2 ]
data_3 = [ A,C,A,C,A,C,A,C, B,B,B,B,B ]
# Categorical observations are category indices, so store them as integers
# rather than float64.
data_3 = torch.tensor( data_3, dtype=torch.long )

def model_3( data ) :
    """HMM whose emission table depends on both hidden state and position.

    ``emit[k, i]`` is a categorical distribution over ``n_cat`` symbols for
    hidden state ``k`` at position ``i`` of a row.  Row ``t`` of ``data`` is
    time step ``t``; its hidden state is drawn from the transition row of
    the previous state (seeded with state 0).
    """
    n_states = 3
    n_pos = 9
    n_cat = 3
    # Dirichlet concentration 0.4 on the diagonal and 0.3 elsewhere.
    trans_conc = torch.eye( n_states, n_states ) * 0.1 + 0.3
    trans = pyro.sample( "x_q", dist.Dirichlet( trans_conc ).to_event( 1 ) )
    emit_conc = torch.ones( [ n_states, n_pos, n_cat ] )
    emit = pyro.sample( "y_q", dist.Dirichlet( emit_conc ).to_event( 2 ) )
    state = 0
    seq_plate = pyro.plate( "seq", len( data[0] ) )
    for t in pyro.markov( range( len( data ) ) ) :
        row = data[t]
        state = pyro.sample( "x_{}".format( t ),
                             dist.Categorical( Vindex( trans )[ ..., state, : ] ),
                             infer={ "enumerate": "parallel" } )
        # The enumerated state carries a trailing singleton dim for the "seq"
        # plate; squeezing it lets emit's position axis line up with the plate.
        with seq_plate :
            pyro.sample( "obs_{}".format( t ),
                         dist.Categorical( Vindex( emit )[ ..., state.squeeze( -1 ), :, : ] ),
                         obs=row )

model_number = 3

data = globals()["data_" + str( model_number )]
model = globals()["model_" + str( model_number )]

# AutoDelta learns point estimates only for sites whose names end in "_q";
# every other site is blocked from the guide.
guide = AutoDelta(
    poutine.block( model, expose_fn=lambda msg: msg["name"].endswith( "_q" ) )
)

adam_params = dict( lr=0.05, betas=( 0.90, 0.999 ) )
optimizer = Adam( adam_params )
svi = SVI( model, guide, optimizer, loss=TraceEnum_ELBO() )

for step in range( n_steps ) :
    loss = svi.step( data )
    if step % n_steps_print :
        continue
    print( ' . %.3f' % loss, end='' )
print()

# Name, shape, and 2-dp values of every parameter in the store.
for k, v in sorted( pyro.get_param_store().items() ) :
    b = v.cpu().detach().numpy()
    print( k, b.shape, b.round( 2 ), sep='\n' )

The output then shows that A mostly transitions to C, C mostly transitions to A, and B only transitions to B:

AutoDelta.x_q
(3, 3)
[[0.03 0.   0.97]
 [0.   1.   0.  ]
 [0.99 0.01 0.  ]]
AutoDelta.y_q
(3, 9, 3)
[[[1. 0. 0.]
  [0. 1. 0.]
  [0. 0. 1.]
  [0. 0. 1.]
  [0. 0. 1.]
  [0. 0. 1.]
  [0. 0. 1.]
  [0. 0. 1.]
  [0. 0. 1.]]

 [[1. 0. 0.]
  [0. 1. 0.]
  [0. 0. 1.]
  [1. 0. 0.]
  [0. 1. 0.]
  [0. 0. 1.]
  [1. 0. 0.]
  [0. 1. 0.]
  [0. 0. 1.]]

 [[1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [0. 1. 0.]
  [0. 0. 1.]]]

I thought Vindex was only an enhancement for already-working code, like vectorization. I missed this:

Due to advanced indexing semantics, the expression p[..., x, y, :] will work correctly without enumeration, but is incorrect when x or y is enumerated.

Thank you for pointing it out!

I now use Vindex for both x_q and y_q, but it still doesn't work, either with or without vectorizing plate_seq. Here is the code:

Code
import math
import os
import torch

torch.set_default_tensor_type( 'torch.cuda.FloatTensor' )

import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO
import pyro.distributions as dist

from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex

smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 500
# Print the loss roughly ten times per run.  ``math`` (already imported) is
# used because this snippet never imports numpy, so the former
# ``np.int(np.log10(...))`` raised NameError; ``np.int`` is also gone in
# NumPy >= 1.24.  ``max(1, ...)`` keeps the interval a positive integer: for
# n_steps < 10 the bare formula yields 0.1, and a float modulus such as
# ``step % 0.1`` is only reliably 0 at step 0.
n_steps_print = max( 1, 10 ** ( int( math.log10( n_steps ) ) - 1 ) )

assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True)

pyro.clear_param_store()

# Three observation patterns over 5 symbols; model_3 consumes row j of
# data_3 as time step j.
A = [ 0,1,2,3,4, 4,4,4,4,4, 4,4,4,4,4 ]
B = [ 0,1,2,3,4, 0,1,2,3,4, 0,1,2,3,4 ]
C = [ 0,0,0,0,0, 0,0,0,0,0, 0,1,2,3,4 ]
data_3 = [ A,C,A,C,A,C,A,C, B,B,B,B,B ]
# Categorical observations are category indices, so store them as integers
# rather than float64.
data_3 = torch.tensor( data_3, dtype=torch.long )

def model_3( data ) :
    """HMM with categorical emissions; each row of ``data`` is one time step.

    At step ``j`` a hidden state ``x_j`` is drawn from the transition row of
    the previous state (the chain is seeded with state 0).  Every entry of
    ``data[j]`` is then an independent draw from the emission row of ``x_j``.

    Args:
        data: 2-D tensor of category indices in ``range(datadim)``; all rows
            are assumed to have length ``len(data[0])``.
    """
    hmmdim = 3
    datadim = 5
    # Row k of x_q: transitions out of state k.  Row k of y_q: emissions of state k.
    x_q = pyro.sample( "x_q", dist.Dirichlet( torch.ones( [ hmmdim, hmmdim ] ) ).to_event( 1 ) )
    y_q = pyro.sample( "y_q", dist.Dirichlet( torch.ones( [ hmmdim, datadim ] ) ).to_event( 1 ) )
    # NOTE(review): these priors treat all hidden states symmetrically.  The
    # identical-row result reported for this code looks like the optimiser
    # staying at that symmetric point.  An asymmetric initialisation (e.g. a
    # non-uniform start for x_q) is the likely remedy.
    x = 0  # fixed state whose transition row generates x_0
    plate_seq = pyro.plate( "seq", len( data[0] ) )
    for j in pyro.markov( range( len( data ) ) ) :
        x = pyro.sample(  "x_{}".format( j )
                        , dist.Categorical( Vindex( x_q )[..., x, :] )
                        , infer={ "enumerate": "parallel" } )
        # The return value of an observed site is not needed, so it is not bound.
        with plate_seq:
            pyro.sample(  "obs_{}".format( j )
                        , dist.Categorical( Vindex( y_q )[..., x, :] )
                        , obs=data[j] )

model_number = 3

data = globals()["data_" + str( model_number )]
model = globals()["model_" + str( model_number )]

# Point-estimate guide over the "*_q" sites only.
guide = AutoDelta(
    poutine.block( model, expose_fn=lambda msg: msg["name"].endswith( "_q" ) )
)

# Shape-debugging recipe (disabled): trace the guide, replay the enumerated
# model against it, then print the site shapes.
# first_available_dim = -2
# guide_trace = poutine.trace(guide).get_trace( data )
# model_trace = poutine.trace(
#     poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace( data )
# print(model_trace.format_shapes())

adam_params = dict( lr=0.05, betas=( 0.90, 0.999 ) )
optimizer = Adam( adam_params )
svi = SVI( model, guide, optimizer, loss=TraceEnum_ELBO() )

for step in range( n_steps ) :
    loss = svi.step( data )
    if step % n_steps_print == 0 :
        print( ' . %.3f' % loss, end='' )
print()

# Name, shape, and 2-dp values of every parameter in the store.
for k, v in sorted( pyro.get_param_store().items() ) :
    b = v.cpu().detach().numpy()
    print( k, b.shape, b.round( 2 ), sep='\n' )

I changed the data to datadim = 5 to avoid confusing the dimensions.

As a result, I consistently get:

Result
AutoDelta.x_q
(3, 3)
[[0.33 0.33 0.33]
 [0.33 0.33 0.33]
 [0.33 0.33 0.33]]
AutoDelta.y_q
(3, 5)
[[0.32 0.12 0.12 0.12 0.32]
 [0.32 0.12 0.12 0.12 0.32]
 [0.32 0.12 0.12 0.12 0.32]]

I thought the reason was a dimension inconsistency, so I checked, but I can’t find anything wrong.

Thank you for the reply, fritzo!

Your model works perfectly! And it is much finer than mine. I just have three dice (== hmmdim) with three sides (== datadim, or five sides; see the post below). The model with 3 states matching different sequences of 9 dice, each with 3 sides, is good. However, I plan to use a similar model with about 50 states, sequences of 15000 dice, and 4000 sides per die, and that won’t fit in 10 GB.

Please take a look at my post below.

Thank you for the reply!

V.

The initial parameters were wrong.

Code
def model_3( data ) :
    """HMM with categorical emissions using point parameters (no priors).

    ``x_q`` (transitions) and ``y_q`` (emissions) are constrained parameters
    whose rows are probability vectors.  Row ``j`` of ``data`` is time step
    ``j``; ``x_j`` is drawn from the transition row of the previous state
    (seeded with state 0), and every entry of ``data[j]`` is an independent
    draw from the emission row of ``x_j``.

    Args:
        data: 2-D tensor of category indices in ``range(datadim)``.
    """
    hmmdim = 3
    datadim = 5
    # Asymmetric start (0.4 on the diagonal, 0.3 elsewhere; rows sum to 1)
    # breaks the symmetry between hidden states.
    x_q = pyro.param( "x_q"
                     , torch.eye( hmmdim ) * .1 + .3
                     , constraint=constraints.simplex )
    # Uniform start, normalised so the initial value really lies on the
    # simplex (rows of torch.ones sum to datadim, not 1).
    y_q = pyro.param(  "y_q"
                     , torch.ones( [ hmmdim, datadim ] ) / datadim
                     , constraint=constraints.simplex )
    x = 0  # fixed state whose transition row generates x_0
    plate_seq = pyro.plate( "seq", len( data[0] ), dim=-1 )
    for j in pyro.markov( range( len( data ) ) ) :
        x = pyro.sample(  "x_{}".format( j )
                        , dist.Categorical( Vindex( x_q )[..., x, :] )
                        , infer={ "enumerate": "parallel" } )
        with plate_seq:
            pyro.sample(  "obs_{}".format( j )
                        , dist.Categorical( Vindex( y_q )[..., x, :] )
                        , obs=data[j] )