Inference with DiscreteHMM

  1. Does Pyro have a function to get the most likely sequence of hidden states (Viterbi) for a model containing DiscreteHMM?

  2. Is there a way to implement a Discrete HMM model in Pyro where multiple random variables depend on the same discrete hidden state?

Thanks

Updates:

I was able to do Viterbi-like MAP inference. I fit the model using the DiscreteHMM distribution and then apply infer_discrete with temperature=0 to an identical model written using the pyro.markov primitive. Is this the intended way of doing discrete inference?
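
For reference, here is a minimal self-contained sketch of the infer_discrete decoding step on a toy pyro.markov chain (all names, shapes, and numbers are illustrative, not from my actual code):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate, infer_discrete

@config_enumerate
def markov_model(obs):
    # toy 2-state chain; emissions are Normal with state-dependent means
    init = torch.tensor([0.9, 0.1])
    trans = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
    locs = torch.tensor([-1.0, 1.0])
    states = []
    z = 0
    for t in pyro.markov(range(len(obs))):
        z = pyro.sample(f"z_{t}", dist.Categorical(trans[z] if t else init))
        pyro.sample(f"x_{t}", dist.Normal(locs[z], 1.0), obs=obs[t])
        states.append(z)
    return states

obs = torch.randn(10)
# temperature=0 gives MAP (Viterbi-like) assignments of the discrete states
viterbi = infer_discrete(markov_model, temperature=0, first_available_dim=-1)
map_states = torch.stack(viterbi(obs))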


Hi @ordabayev, yes, I like your solution of training with a DiscreteHMM then writing an equivalent model and using infer_discrete for Viterbi inference.

Is there a way to implement a Discrete HMM model in Pyro where multiple random variables depend on the same discrete hidden state?

Yes, there are a few options:

  1. You can use DiscreteHMM when the downstream random variables can be batched into a single random variable. For example, if two downstream Normal variables depend on the latent state, you can concatenate them into a single distribution (a fuller runnable sketch follows this list):
    obs_dist = Normal(torch.stack([loc1, loc2], dim=-1),
                      torch.stack([scale1, scale2], dim=-1)).to_event(1)
    DiscreteHMM(..., obs_dist)
    
  2. If your distributions cannot be concatenated, you can use pyro.markov and a non-vectorized encoding of random variables in Pyro:
    for t in pyro.markov(range(T)):
        z = pyro.sample("z_{}".format(t), ...,
                        infer={"enumerate": "parallel"})
        pyro.sample("x1_{}".format(t), dist.Normal(locs[z], 1),
                    obs=x1[t])
        pyro.sample("x1_{}".format(t), dist.Bernoulli(probs[z], 1),
                    obs=x2[t])
    
  3. If you really want to use DiscreteHMM for heterogeneous observations, you could manually construct a Categorical observation distribution whose logits are the sum of log-likelihoods from other distributions. It will be difficult to sample from this (e.g. for forecasting), but if all you want is MAP latent state (e.g. for segmentation), then this will be fast and general.
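
To make option 1 concrete, here is a minimal runnable sketch with two Normal emissions per hidden state (toy shapes and parameters, purely illustrative):

import torch
import pyro.distributions as dist

num_states, T = 3, 10
init_logits = torch.zeros(num_states)                # uniform initial state
trans_logits = torch.zeros(num_states, num_states)   # uniform transitions

# per-state parameters of the two downstream Normal variables
loc1, scale1 = torch.randn(num_states), torch.ones(num_states)
loc2, scale2 = torch.randn(num_states), torch.ones(num_states)

# stack the two Normals into a single 2-dimensional observation distribution
obs_dist = dist.Normal(torch.stack([loc1, loc2], dim=-1),
                       torch.stack([scale1, scale2], dim=-1)).to_event(1)

hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist, duration=T)
data = torch.randn(T, 2)      # (time, stacked event dim)
print(hmm.log_prob(data))     # log-likelihood with hidden states marginalized out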

Good luck, and don’t hesitate to paste successful code snippets here; I’m sure others would benefit!


Hi @fritzo. Thank you for your reply. I have a model written using pyro.markov. It works, but it is very slow, so I use it only for Viterbi decoding. The solution that seems to work is to stack distributions as you described in option 1. To make it work for heterogeneous distributions, I created a StackDistributions class.

import torch
from pyro.distributions import TorchDistribution


class StackDistributions(TorchDistribution):
    """
    Stack multiple heterogeneous distributions.
    This is useful when multiple heterogeneous distributions
    depend on the same hidden state in DiscreteHMM.

    Example::

        d1 = dist.Normal(torch.zeros(batch_shape), 1.)
        d2 = dist.Gamma(torch.ones(batch_shape), 1.)
        d = StackDistributions(d1, d2)

    :param dists: sequence of pyro.distributions.TorchDistribution objects.
    """
    arg_constraints = {}  # nothing to be constrained

    def __init__(self, *dists):
        self.dists = dists
        batch_shape = self.dists[0].batch_shape
        event_shape = self.dists[0].event_shape + (len(self.dists),)
        super().__init__(batch_shape, event_shape)

    @property
    def has_rsample(self):
        return all(dist.has_rsample for dist in self.dists)

    def expand(self, batch_shape):
        dists = (dist.expand(batch_shape) for dist in self.dists)
        return type(self)(*dists)

    def sample(self, sample_shape=torch.Size()):
        result = tuple(dist.sample(sample_shape) for dist in self.dists)
        return torch.stack(result, dim=-1)

    def rsample(self, sample_shape=torch.Size()):
        result = tuple(dist.rsample(sample_shape) for dist in self.dists)
        return torch.stack(result, dim=-1)

    def log_prob(self, value):
        values = torch.unbind(value, dim=-1)
        log_probs = tuple(dist.log_prob(value) for dist, value in zip(self.dists, values))
        result = torch.sum(torch.stack(log_probs, -1), -1)
        return result
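
For example, it can be plugged into DiscreteHMM like this (toy shapes and parameters, purely illustrative; my actual model is further below):

import torch
import pyro.distributions as dist

num_states, T = 3, 20
init_logits = torch.zeros(num_states)
trans_logits = torch.zeros(num_states, num_states)

d1 = dist.Normal(torch.zeros(num_states), 1.)   # one Normal emission per state
d2 = dist.Gamma(torch.ones(num_states), 1.)     # a heterogeneous Gamma emission
obs_dist = StackDistributions(d1, d2)           # event_shape == (2,)

hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist, duration=T)
data = torch.stack([torch.randn(T), torch.rand(T) + 0.1], dim=-1)
print(hmm.log_prob(data))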

Briefly, I’m working on a program that detects spots in single-molecule fluorescence microscopy images and determines binding/dissociation rates using an HMM. This is the model written using pyro.markov:

@config_enumerate
def viterbi_model(self, data):
    K_plate = pyro.plate("K_plate", self.K, dim=-2)
    N_plate = pyro.plate("N_plate", data.N, dim=-1)

    init = pi_theta_calc(param("pi"), self.K, self.S)  #  self.S*self.K+1
    trans = theta_trans_calc(param("A"), self.K, self.S) #  self.S*self.K+1, self.S*self.K+1
    pi_m = pi_m_calc(param("lamda"), self.S) #  self.S+1, self.S+1

    with N_plate as batch_idx:
        thetas = []
        theta = pyro.sample("theta", dist.Categorical(init))

        for f in pyro.markov(range(data.F)):
            background = pyro.sample(
                f"background_{f}", dist.Gamma(
                    param(f"d/background_loc")[batch_idx, 0]
                    * param("background_beta"), param("background_beta")))

            theta = pyro.sample(
                f"theta_{f}", dist.Categorical(Vindex(trans)[theta, :]))

            theta_mask = Vindex(self.theta_matrix)[..., 0, theta]
            m_mask = Vindex(self.m_matrix)[..., 0, theta]

            with K_plate:
                m = pyro.sample(f"m_{f}", dist.Categorical(Vindex(pi_m)[m_mask]))
                height = pyro.sample(
                    f"height_{f}", dist.Gamma(
                        param("height_loc")[m] * param("height_beta")[m],
                        param("height_beta")[m]))
                width = pyro.sample(
                    f"width_{f}", ScaledBeta(
                        param("width_mode"),
                        param("width_size"), 0.5, 2.5))
                x = pyro.sample(
                    f"x_{f}", ScaledBeta(
                        0, self.size[theta_mask], -(data.D+1)/2, data.D+1))
                y = pyro.sample(
                    f"y_{f}", ScaledBeta(
                        0, self.size[theta_mask], -(data.D+1)/2, data.D+1))

            width = width * 2.5 + 0.5
            x = x * (data.D+1) - (data.D+1)/2
            y = y * (data.D+1) - (data.D+1)/2

            # calculate the shape of the 2-D Gaussian spot based on sampled parameters
            locs = data.loc(height, width, x, y, background, batch_idx, None, f)
            pyro.sample(
                f"data_{f}", self.CameraUnit(
                    locs, param("gain"), param("offset")).to_event(2),
                obs=data[batch_idx, f])
            thetas.append(theta)
    return thetas

And here is the model written using DiscreteHMM:

def discretehmm_model(self, data, prefix):
    K_plate = pyro.plate("K_plate", self.K, dim=-2)
    N_plate = pyro.plate("N_plate", data.N, dim=-1)

    with N_plate as batch_idx:
        background = pyro.sample(
            "background", dist.Gamma(
                param(f"{prefix}/background_loc")[batch_idx]
                * param("background_beta"), param("background_beta")).expand([len(batch_idx), data.F]).to_event(1))

        with K_plate:
            pi_m = pi_m_calc(param("lamda"), self.S)
            m_logits = Vindex(pi_m)[self.m_matrix].log()
            h_dist = EnumDistribution(dist.Gamma(
                    param("height_loc") * param("height_beta"),
                    param("height_beta")), m_logits)
            x_dist = ScaledBeta(
                    0, self.size[self.theta_matrix], -(data.D+1)/2, data.D+1)
            y_dist = ScaledBeta(
                    0, self.size[self.theta_matrix], -(data.D+1)/2, data.D+1)
            hxy_dist = StackDistributions(h_dist, x_dist, y_dist)

            init = pi_theta_calc(param("pi"), self.K, self.S).log()  # state_dim
            trans = theta_trans_calc(param("A"), self.K, self.S).log() # state_dim, state_dim
            hmm_dist = dist.DiscreteHMM(init, trans, hxy_dist, duration=data.F)

            hxy = pyro.sample("hxy", hmm_dist)
            height, x, y = torch.unbind(hxy, dim=-1)

            width = pyro.sample(
                "width", ScaledBeta(
                    param("width_mode"),
                    param("width_size"), 0.5, 2.5).expand([data.F]).to_event(1))

        width = width * 2.5 + 0.5
        x = x * (data.D+1) - (data.D+1)/2
        y = y * (data.D+1) - (data.D+1)/2

        # calculate the shape of the 2-D Gaussian spot based on sampled parameters
        locs = data.loc(height, width, x, y, background, batch_idx)
        pyro.sample(
            "data", self.CameraUnit(
                locs, param("gain"), param("offset")).to_event(3),
            obs=data[batch_idx])

I fit the model to the data using the second (DiscreteHMM) version and then use the first (pyro.markov) version to do inference:

# sample posterior values of the continuous latents from the trained guide
guide_trace = poutine.trace(self.viterbi_guide).get_trace(self.data)
# replay them into the enumerated pyro.markov model
trained_model = poutine.replay(
    poutine.enum(self.viterbi_model, first_available_dim=-4), trace=guide_trace)
# temperature=0 gives MAP (Viterbi-like) assignments of theta
thetas = infer_discrete(
    trained_model, temperature=0, first_available_dim=-4)(data=self.data)
thetas = torch.stack(thetas, dim=-1)
self.predictions["z"] = (thetas > 0).cpu().data