Basic Categorical Distribution Question

I am trying to create a simple program representing the factor graph below:

[factor graph image: static_perception]

Where s ~ Cat(D) (so s is sampled from a Categorical distribution whose parameter vector D lies on the standard simplex; in the graph above it represents the prior beliefs about the value of s), and o = As (where A is a matrix mapping latent values of s to the observation space o). o and s are one-hot vectors representing discrete states and observations.

Below is my attempt to build this in Pyro with a simple example, using two possible values of o and s (which means A is 2x2). Of course, in this simple model s could be solved for analytically, but I wanted to experiment to see if I could get it working in Pyro. My confusion is about constructing the guide in this case. I have tried it with and without the simplex constraint, but adding the simplex constraint to the sd parameter doesn’t seem to work (the final value for sd ends up as [1, 1]).

Even without the simplex constraint (using dist.constraints.positive in the example below), sd approaches the correct value (which, analytically, I believe should be [0.96, 0.04] for the given values of D, A, and ob, reflecting the posterior beliefs about the value of s), but it is still quite inconsistent across runs, and even the value it finishes on is not as close as I would like. Furthermore, the plot of losses fluctuates quite a bit, which I believe points to the underlying issue.

My hunch is that it has something to do with my choice of distribution in my guide, but I’m not sure what to use instead. Any help is appreciated!

import pyro
import torch
from torch import tensor

import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

import matplotlib.pyplot as plt
import numpy as np

from torch.nn import Softmax

A = tensor([
    [0.9, 0.1],
    [0.1, 0.9]
])

D = tensor([
    [0.5],
    [0.5]
])

Ns = D.shape[0]

def one_hot(num, sz):
    mat = torch.zeros([1, sz])
    mat[0][num] = 1.
    return mat.T

def model(o):
    # Prior over the latent state: s ~ Cat(D)
    s = pyro.sample("s", 
        dist.Categorical(D.T)
    )

    # Map the sampled state through A to get the observation probabilities
    s = one_hot(s, Ns)
    As = torch.matmul(A, s)

    # Likelihood: o ~ Cat(As), conditioned on the observed value
    pyro.sample("o", 
        dist.Categorical( As.T ),
        obs = o
    )

def guide(o):
    # Variational posterior over s, parameterized by sd
    sd = pyro.param("sd", tensor([
        [0.5],
        [0.5]
    ]), constraint = dist.constraints.positive)

    s = pyro.sample("s", dist.Categorical(sd.T))

adam = Adam({"lr": 0.1, "betas": (0.90, 0.999)})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

param_vals = []
losses = []

ob = tensor([0])

for _ in range(1000):
    losses.append(svi.step(ob))
    param_vals.append({k: pyro.param(k) for k in ["sd"]})

plt.plot(losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss")
plt.show()

fm = Softmax(dim=0)

print(fm(param_vals[-1]["sd"]))

Update:

Solved several issues on my own.

First, my analytical computation was off: it should be [0.9, 0.1], which makes Pyro’s estimates much more palatable.
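(A quick Bayes-rule check of that number, using the same A, D, and ob = 0 as in the code:)

import torch
from torch import tensor

A = tensor([[0.9, 0.1],
            [0.1, 0.9]])    # A[o, s] = p(o | s)
prior = tensor([0.5, 0.5])  # p(s), i.e. D flattened

ob = 0
likelihood = A[ob]          # p(o = 0 | s) for each value of s
posterior = likelihood * prior / (likelihood * prior).sum()
print(posterior)            # tensor([0.9000, 0.1000])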

Second, I changed the learning rate to 0.01 in Adam, which helps considerably.

Finally, I also rearranged how the “sd” parameter in the guide is shaped (a flat vector rather than a column vector, which the simplex constraint seems to require. Is there a way to use different indexing for the simplex constraint? See the note after the updated code.) Updated code is found below:

import pyro
import torch
from torch import tensor

import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

import matplotlib.pyplot as plt
import numpy as np

from torch.nn import Softmax

A = tensor([
    [0.9, 0.1],
    [0.1, 0.9]
])

D = tensor([
    [0.5],
    [0.5]
])

Ns = D.shape[0]

def one_hot(num, sz):
    mat = torch.zeros([1, sz])
    mat[0][num] = 1.
    return mat.T

def model(o):
    s = pyro.sample("s", 
        dist.Categorical(D.T)
    )

    s = one_hot(s, Ns)
    As = torch.matmul(A, s)

    pyro.sample("o", 
        dist.Categorical( As.T ),
        obs = o
    )

def guide(o):
    sd = pyro.param("sd", tensor(
        [0.5, 0.5]), 
        constraint = dist.constraints.simplex
    )

    s = pyro.sample("s", dist.Categorical(sd))

adam = Adam({"lr": 0.01, "betas": (0.90, 0.999)})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

param_vals = []
losses = []

ob = tensor([0])

for _ in range(1000):
    losses.append(svi.step(ob))
    param_vals.append({k: pyro.param(k) for k in ["sd"]})

plt.plot(losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss")
plt.show()

print((param_vals[-1]["sd"]))
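Regarding my simplex-indexing question above: as far as I can tell, constraints.simplex normalizes over the rightmost dimension, so my original [2, 1]-shaped sd was treated as two one-element simplexes and always projected back to [1, 1]. A quick check with torch’s transform_to (which, I believe, is what Pyro uses under the hood for constrained params):

import torch
from torch.distributions import constraints, transform_to

t = transform_to(constraints.simplex)

col = torch.tensor([[0.5], [0.5]])   # shape [2, 1]: each row is its own 1-element simplex
flat = torch.tensor([0.5, 0.5])      # shape [2]: the whole vector is one simplex

print(t(t.inv(col)))    # tensor([[1.], [1.]]) -- collapses to [1, 1]
print(t(t.inv(flat)))   # tensor([0.5000, 0.5000]) -- round-trips as expected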

My only concern is that the loss plot still seems odd (seen below) and can still change wildly run-to-run, so I’m still not confident it’s working completely correctly. But the results are much better now.

Reply:

The score-function ELBO gradients that arise in the context of discrete latent variables can be high-variance. You may get more stable results if you use multiple particles to compute the gradients: Trace_ELBO(num_particles=16).
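For example, reusing the model, guide, and optimizer settings from the updated code above:

from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

# Average the score-function gradient over 16 samples per step to
# reduce its variance for the discrete latent s.
adam = Adam({"lr": 0.01, "betas": (0.90, 0.999)})
svi = SVI(model, guide, adam, loss=Trace_ELBO(num_particles=16))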