## HMM definition
This is my definition of a hidden Markov model.
It's a common example in which you play a game of dice against a casino.
The casino uses either a fair or a loaded die; the loaded die rolls a 6 more often than the fair one (half of the time instead of one time in six).
import torch
import pyro
import pyro.infer
def transitionFair(t):
    # stay on the fair die with probability 0.95, otherwise switch to the loaded die
    # (the step index t keeps every sample-site name unique, which Pyro's inference requires)
    temp = pyro.sample('tempFairTransition' + str(t), pyro.distributions.Bernoulli(0.95))
    if temp.item() == 1.0:
        return 'fair'
    else:
        return 'unfair'

def transitionUnfair(t):
    # stay on the loaded die with probability 0.9, otherwise switch back to the fair die
    temp = pyro.sample('tempUnfairTransition' + str(t), pyro.distributions.Bernoulli(0.9))
    if temp.item() == 1.0:
        return 'unfair'
    else:
        return 'fair'

def observeFair(t):
    temp = pyro.sample('tempFairObservation' + str(t), pyro.distributions.Categorical(torch.tensor([1/6, 1/6, 1/6, 1/6, 1/6, 1/6])))
    return temp.item() + 1  # shift category 0..5 to die face 1..6

def observeUnfair(t):
    # the loaded die rolls a 6 with probability 1/2
    temp = pyro.sample('tempUnfairObservation' + str(t), pyro.distributions.Categorical(torch.tensor([1/10, 1/10, 1/10, 1/10, 1/10, 1/2])))
    return temp.item() + 1

def hmm(chainLength, startState):
    states = []
    observations = []
    currentState = startState
    for x in range(chainLength - 1):
        # roll with the current die, then draw the next state
        if currentState == "fair":
            states.append(transitionFair(x))
            observations.append(observeFair(x))
        else:
            states.append(transitionUnfair(x))
            observations.append(observeUnfair(x))
        currentState = states[x]
    return states, observations
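Written out as equations, the helper functions above define a two-state HMM. The symbols here are notation introduced for this summary, not taken from the code: state order (fair, unfair), faces $k \in \{1,\dots,6\}$, transition matrix $A$, emission probabilities $B$, and $z_0$ = `startState`. Each loop iteration draws the roll $x_t$ from $B_{z_t}$ and the next state $z_{t+1}$ from row $z_t$ of $A$, so with $T$ = `chainLength - 1` rolls the last drawn state emits nothing and sums out of the likelihood:

```latex
A = \begin{pmatrix} 0.95 & 0.05 \\ 0.10 & 0.90 \end{pmatrix}, \qquad
B_{\text{fair}}(k) = \tfrac{1}{6}, \qquad
B_{\text{unfair}}(k) = \begin{cases} \tfrac{1}{2} & k = 6 \\ \tfrac{1}{10} & k \in \{1,\dots,5\} \end{cases}

p(x_0,\dots,x_{T-1} \mid z_0) = \sum_{z_1,\dots,z_{T-1}} \prod_{t=0}^{T-1} B_{z_t}(x_t) \prod_{t=0}^{T-2} A_{z_t z_{t+1}}
```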
## Inference on the model
What I want to do is solve the likelihood problem and the decoding problem.
Likelihood: given a sequence of observations, what is the probability that the sequence was emitted by the HMM?
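The likelihood question has an exact answer without any sampling: the forward algorithm. Below is a minimal sketch in plain Python (not Pyro), assuming the transition and emission probabilities from the model above and a known start state. The 0 = fair / 1 = unfair encoding and the name `forward_log_likelihood` are hypothetical. Each step is rescaled so that long roll sequences do not underflow.

```python
import math

# Hypothetical encoding: state 0 = fair, 1 = unfair (loaded)
TRANS = [[0.95, 0.05],   # from fair: stay 0.95, switch 0.05
         [0.10, 0.90]]   # from unfair: switch 0.10, stay 0.90
EMIT = [[1 / 6] * 6,                       # fair die, faces 1..6
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.5]]    # loaded die favours a 6


def forward_log_likelihood(rolls, start):
    """Return log p(rolls | start) for die faces 1..6.

    As in the model above, the state at step t emits rolls[t] and then transitions.
    """
    # alpha[s] is proportional to p(rolls[0..t], z_t = s); the start state is known
    alpha = [0.0, 0.0]
    alpha[start] = EMIT[start][rolls[0] - 1]
    log_scale = 0.0
    for roll in rolls[1:]:
        # rescale to avoid underflow and remember the factor in log space
        norm = sum(alpha)
        log_scale += math.log(norm)
        alpha = [a / norm for a in alpha]
        # sum over the previous state, then emit the current roll
        alpha = [sum(alpha[p] * TRANS[p][s] for p in range(2)) * EMIT[s][roll - 1]
                 for s in range(2)]
    return log_scale + math.log(sum(alpha))
```

For example, `forward_log_likelihood([6, 6], 1)` equals `math.log(7 / 30)`: the loaded die rolls the first six with probability 1/2, and the second six has probability 0.10 * 1/6 + 0.90 * 1/2 = 7/15.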
Decoding: given a sequence of observations, what is the most likely hidden state path that produced this observation sequence? (Usually solved with the Viterbi algorithm.)
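The decoding question can be sketched with the Viterbi algorithm under the same assumptions (plain Python, known start state, hypothetical names `viterbi` and `STATES`). For each state it keeps the log-probability of the best path ending there plus a back-pointer, then backtracks from the best final state. The returned path begins with the start state and has one entry per roll.

```python
import math

STATES = ["fair", "unfair"]  # hypothetical encoding: index 0 = fair, 1 = unfair
TRANS = [[0.95, 0.05], [0.10, 0.90]]
EMIT = [[1 / 6] * 6, [0.1, 0.1, 0.1, 0.1, 0.1, 0.5]]


def viterbi(rolls, start):
    """Most likely state path for rolls (faces 1..6), given the start state index."""
    # delta[s]: best log-probability of any path that ends in state s at the current step
    delta = [-math.inf, -math.inf]
    delta[start] = math.log(EMIT[start][rolls[0] - 1])
    backptrs = []
    for roll in rolls[1:]:
        new_delta, ptr = [], []
        for s in range(2):
            # best predecessor of state s
            prev = max(range(2), key=lambda p: delta[p] + math.log(TRANS[p][s]))
            ptr.append(prev)
            new_delta.append(delta[prev] + math.log(TRANS[prev][s]) + math.log(EMIT[s][roll - 1]))
        backptrs.append(ptr)
        delta = new_delta
    # backtrack from the best final state
    state = max(range(2), key=lambda s: delta[s])
    path = [state]
    for ptr in reversed(backptrs):
        state = ptr[state]
        path.append(state)
    return [STATES[s] for s in reversed(path)]
```

For instance, `viterbi([1] + [6] * 10, 0)` returns `'fair'` followed by ten `'unfair'`: after the first six, paying the 0.05 switch once and then 0.9 * 0.5 per roll beats 0.95 * 1/6 per roll on the fair die.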
As I understand it, SVI is the most common inference algorithm in Pyro. I wanted to solve these problems with it, but I would be happy to solve them with Pyro at all, so if anyone has an idea that does not involve SVI, that would also be fine for me.
Reading about SVI, I understood that I need to rewrite the model and inline my helper functions, because without this step I will not be able to write a proper guide function, right? This is what I did:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, infer_discrete,config_enumerate
from pyro.optim import Adam
pyro.enable_validation(True)
pyro.clear_param_store()
def hmm(chainLength, startState):
    states = []
    observations = []
    currentState = startState
    for x in range(chainLength - 1):
        if currentState == "fair":
            stateSampled = ''
            # stay fair with probability 0.95
            drawstate = pyro.sample('tempFairTransition' + str(x), pyro.distributions.Bernoulli(0.95))
            if drawstate.item() == 1.0:
                stateSampled = 'fair'
            else:
                stateSampled = 'unfair'
            states.append(stateSampled)
            observationSampled = 0
            # obs=torch.tensor(5) conditions this roll on category 5, i.e. a six;
            # the step index keeps the site name unique across loop iterations
            drawObservation = pyro.sample('tempFairObservation' + str(x), pyro.distributions.Categorical(torch.tensor([1/6, 1/6, 1/6, 1/6, 1/6, 1/6])), obs=torch.tensor(5))
            observationSampled = drawObservation.item() + 1
            observations.append(observationSampled)
            currentState = states[x]
        else:
            stateSampled = ''
            # stay on the loaded die with probability 0.9
            drawstate = pyro.sample('tempUnfairTransition' + str(x), pyro.distributions.Bernoulli(0.9))
            if drawstate.item() == 1.0:
                stateSampled = 'unfair'
            else:
                stateSampled = 'fair'
            states.append(stateSampled)
            observationSampled = 0
            drawObservation = pyro.sample('tempUnfairObservation' + str(x), pyro.distributions.Categorical(torch.tensor([1/10, 1/10, 1/10, 1/10, 1/10, 1/2])), obs=torch.tensor(5))
            observationSampled = drawObservation.item() + 1
            observations.append(observationSampled)
            currentState = states[x]
    return states, observations
def guide(chainLength, startState):
    # What does my guide have to look like?
    # I guess I need a pyro.param to train.
    pass
adam_params = {"lr": 0.0005}
optimizer = Adam(adam_params)
svi = SVI(hmm, guide, optimizer, loss=Trace_ELBO())
n_steps = 2501
for step in range(n_steps):
    svi.step(10, 'unfair')
    if step % 100 == 0:
        print(pyro.param("whatever"))
The main problem is that I have no idea how to define my guide function. In addition, I am not sure about certain effects of the way I defined my model. How do I deal with my guide function having to contain an if-else branch? And how do I deal with the current state depending on the previous one?
I hope someone has an idea.
Thanks in advance for reading!
Greetings
CodeInvolved