Hi!
I try to adapt HMM example to categorical distribution. It works well with Bernoulli distribution or with categorical with two states but no more.
Here is my code
import math
import os
import torch
torch.set_default_tensor_type( 'torch.cuda.FloatTensor' )
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 1000
n_steps_print = 10 ** ( np.int( np.log10( n_steps ) ) - 1 )
assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True)
pyro.clear_param_store()
A = [ 0,1,2,2,2,2,2,2,2 ]
B = [ 0,1,2,0,1,2,0,1,2 ]
C = [ 0,0,0,0,0,0,0,1,2 ]
data_3 = [ A,C,A,C,A,C,A,C, B,B,B,B,B ]
data_3 = torch.tensor( data_3, dtype=torch.float64 )
def model_3( data ) :
hmmdim = 3
datadim = 3
x_q = pyro.sample( "x_q", dist.Dirichlet( torch.ones( [ hmmdim, hmmdim ] ) ).to_event( 1 ) )
y_q = pyro.sample( "y_q", dist.Dirichlet( torch.ones( [ hmmdim, datadim ] ) ).to_event( 1 ) )
x = 0
plate_seq = pyro.plate( "seq", len( data[0] ) )
for j in pyro.markov( range( len( data ) ) ) :
seq = data[j]
x = pyro.sample( "x_{}".format( j )
, dist.Categorical( x_q[x] )
, infer={ "enumerate": "parallel" } )
for i in plate_seq :
pyro.sample( "obs_{}_{}".format( j, i )
, dist.Categorical( y_q[x] )
, obs=seq[i] )
model_number = 3
data = globals()["data_" + str( model_number )]
model = globals()["model_" + str( model_number )]
guide = AutoDelta( poutine.block( model, expose_fn=lambda msg: msg["name"].endswith( "_q" ) ) )
adam_params = { "lr": 0.05, "betas": ( 0.90, 0.999 ) }
optimizer = Adam( adam_params )
svi = SVI( model, guide, optimizer, loss=TraceEnum_ELBO() )
for step in range( n_steps ) :
loss = svi.step( data )
if step % n_steps_print == 0:
print(' . %.3f' % loss, end='')
print()
for k, v in sorted( pyro.get_param_store().items() ) :
b = v.cpu().detach().numpy()
print( k )
print( b.shape )
print( b.round( 2 ) )
Absolutely no idea what’s wrong with it.
Any suggestions?
Typical solution is
(3, 3)
AutoDelta.x_q
[[0.33 0.33 0.33]
[0.33 0.33 0.33]
[0.33 0.33 0.33]]
(3, 3)
AutoDelta.y_q
[[0.4 0.2 0.4]
[0.4 0.2 0.4]
[0.4 0.2 0.4]]
Expected is like
(3, 3)
AutoDelta.x_q
[[0. 1. 0.]
[1. 0. 0.]
[0. 0. 1.]]
(3, 3)
AutoDelta.y_q
[[0.1 0.1 0.8]
[0.8 0.1 0.1]
[0.3 0.3 0.3]]