Special Hidden Markov Model - Inference not possible?

HMM - Definition

This is my definition of a hidden Markov model.
It's the common example where you play a game of dice against a casino.
The casino uses either a fair or a loaded die; the loaded die rolls a 6 more often than the fair one.

import torch
import pyro
import pyro.infer

def transitionFair():
    # From the fair state: stay fair with probability 0.95, otherwise switch.
    temp = pyro.sample('tempFairTransition', pyro.distributions.Bernoulli(0.95))
    if temp.item() == 1.0:
        return 'fair'
    else:
        return 'unfair'

def transitionUnfair():
    # From the unfair state: stay unfair with probability 0.9, otherwise switch.
    temp = pyro.sample('tempUnfairTransition', pyro.distributions.Bernoulli(0.9))
    if temp.item() == 1.0:
        return 'unfair'
    else:
        return 'fair'

def observeFair():
    # Fair die: all six faces are equally likely. Categorical is zero-indexed,
    # so add 1 to get the face value.
    temp = pyro.sample('tempFairObservation', pyro.distributions.Categorical(torch.tensor([1/6, 1/6, 1/6, 1/6, 1/6, 1/6])))
    return temp.item() + 1

def observeUnfair():
    # Loaded die: the six comes up with probability 1/2.
    temp = pyro.sample('tempUnfairObservation', pyro.distributions.Categorical(torch.tensor([1/10, 1/10, 1/10, 1/10, 1/10, 1/2])))
    return temp.item() + 1

def hmm(chainLength, startState):
    states = []
    observations = []
    currentState = startState

    # Walk the chain: sample the next hidden state, emit one observation
    # from the current state's die, then move to the new state.
    for x in range(chainLength - 1):
        if currentState == "fair":
            states.append(transitionFair())
            observations.append(observeFair())
            currentState = states[x]
        else:
            states.append(transitionUnfair())
            observations.append(observeUnfair())
            currentState = states[x]
    return states, observations
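
Sampling forward from this model works as expected, for example (the output is random, of course):

states, observations = hmm(10, 'fair')
print(states)        # e.g. ['fair', 'fair', 'unfair', 'unfair', ...]
print(observations)  # e.g. [3, 1, 6, 6, ...]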

Inference on the model
What I want to do is solve the likelihood problem and the decoding problem.

Likelihood: Given a sequence of observations, what is the probability that the sequence was emitted by the HMM?

Decoding: Given a sequence of observations, what is the most likely hidden state path that produced this observation sequence? (Usually solved with the Viterbi algorithm.)
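
In the usual HMM notation, with model parameters $\lambda$, observation sequence $O$, and hidden state path $Q$, these two problems are

$$P(O \mid \lambda) = \sum_{Q} P(O \mid Q, \lambda)\, P(Q \mid \lambda), \qquad Q^{*} = \arg\max_{Q} P(Q \mid O, \lambda).$$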

As I understand it, SVI is the most common inference algorithm for Pyro. I wanted to solve these problems with it, but I would be happy to solve them with Pyro at all, so if anyone has an idea that does not involve SVI, that would also be fine with me.

Reading about SVI, I understood that I need to rewrite the model and inline my helper functions; without this step I won't be able to write a proper guide function, right? This is what I did:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, infer_discrete, config_enumerate
from pyro.optim import Adam

pyro.enable_validation(True)
pyro.clear_param_store()

def hmm(chainLength, startState):
    states = []
    observations = []
    currentState = startState

    for x in range(chainLength - 1):
        if currentState == "fair":
            drawstate = pyro.sample('tempFairTransition' + str(x), pyro.distributions.Bernoulli(0.95))
            if drawstate.item() == 1.0:
                stateSampled = 'fair'
            else:
                stateSampled = 'unfair'
            states.append(stateSampled)
            # obs=torch.tensor(5) conditions this site on an observed six
            # (Categorical is zero-indexed).
            drawObservation = pyro.sample('tempFairObservation' + str(x), pyro.distributions.Categorical(torch.tensor([1/6, 1/6, 1/6, 1/6, 1/6, 1/6])), obs=torch.tensor(5))
            observations.append(drawObservation.item() + 1)
            currentState = states[x]
        else:
            drawstate = pyro.sample('tempUnfairTransition' + str(x), pyro.distributions.Bernoulli(0.9))
            if drawstate.item() == 1.0:
                stateSampled = 'unfair'
            else:
                stateSampled = 'fair'
            states.append(stateSampled)
            drawObservation = pyro.sample('tempUnfairObservation' + str(x), pyro.distributions.Categorical(torch.tensor([1/10, 1/10, 1/10, 1/10, 1/10, 1/2])), obs=torch.tensor(5))
            observations.append(drawObservation.item() + 1)
            currentState = states[x]
    return states, observations

def guide(chainLength, startstate):
    # What does my guide have to look like?
    # I guess I need params to train.
    pass

adam_params = {"lr": 0.0005}
optimizer = Adam(adam_params)
svi = SVI(hmm, guide, optimizer, loss=Trace_ELBO())

n_steps = 2501
for step in range(n_steps):
    svi.step(10, 'unfair')
    if step % 100 == 0:
        print(pyro.param("whatever"))

The main problem is that I have no idea how to define my guide function. In addition, I am not sure about certain effects of my model's definition: how do I deal with the fact that my guide function has to contain an if-else branch, and how do I deal with the current state depending on the previous one?

I hope someone has an idea.

Thanks in advance for reading it!

Greetings
CodeInvolved

Hi @CodeInvolved,

Your task looks possible in Pyro. I think you'll need to rewrite your observation statements to use the obs kwarg to pyro.sample("name", dist, obs=___). Take a look at the intro Pyro tutorials to see how that works.
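
For example, conditioning a single die-roll site on an observed six looks like this (Categorical is zero-indexed, so a six is the value 5):

pyro.sample("roll", pyro.distributions.Categorical(torch.tensor([1/6] * 6)),
            obs=torch.tensor(5))  # condition on having rolled a six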

I need to rewrite the model and inline my helper functions

You can safely keep your helper functions out-of-line; Pyro doesn't care. I often use helper functions myself.
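
One thing to watch out for: sample site names must be unique within a trace, so a helper that is called in a loop should take the loop index as an argument, e.g.

def transitionFair(x):
    # Index the site name so each loop iteration gets its own sample site.
    temp = pyro.sample('tempFairTransition' + str(x), pyro.distributions.Bernoulli(0.95))
    return 'fair' if temp.item() == 1.0 else 'unfair'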

SVI … solve this Problem with Pyro

Since this problem is entirely discrete, you can use infer_discrete to compute a Viterbi decoding and TraceEnum_ELBO.loss to compute the (negative log) likelihood. Using either version of your hmm model:

import math

from pyro.infer import config_enumerate, infer_discrete, TraceEnum_ELBO

# Compute the Viterbi decoding (temperature=0 gives MAP samples).
decoder = infer_discrete(config_enumerate(hmm), temperature=0)
viterbi_states, observations = decoder(chainLength, startstate)

# Compute the likelihood.
def guide(*args):
    pass  # a trivial guide

elbo = TraceEnum_ELBO()
neg_log_likelihood = elbo.loss(hmm, guide, chainLength, startstate)
likelihood = math.exp(-neg_log_likelihood)
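
One caveat: during enumeration Pyro passes batched tensors through the model, so a model that calls .item() on sampled values (or branches on Python strings) will break. A minimal enumeration-friendly sketch of your HMM, with the hidden states encoded as integers (0 = fair, 1 = unfair) and the signature changed to take the observed rolls directly, might look like this (a sketch, not the only way to do it):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate, infer_discrete

# Hidden states: 0 = fair, 1 = unfair.
transition = torch.tensor([[0.95, 0.05],   # from fair:   stay fair / switch
                           [0.10, 0.90]])  # from unfair: switch / stay unfair
emission = torch.tensor([[1/6] * 6,
                         [1/10, 1/10, 1/10, 1/10, 1/10, 1/2]])

@config_enumerate
def hmm(rolls, startState):
    state = torch.tensor(1 if startState == 'unfair' else 0)
    states = []
    for t, roll in enumerate(rolls):
        # Tensor indexing into the transition/emission matrices keeps the
        # model compatible with Pyro's parallel enumeration.
        state = pyro.sample('state_{}'.format(t), dist.Categorical(transition[state]))
        pyro.sample('obs_{}'.format(t), dist.Categorical(emission[state]), obs=roll)
        states.append(state)
    return states

# Viterbi decoding of a sequence of observed rolls (zero-indexed, so 5 means a six):
rolls = torch.tensor([5, 5, 1, 5])
decoder = infer_discrete(hmm, temperature=0, first_available_dim=-1)
map_states = decoder(rolls, 'unfair')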