I am trying to create a simple program representing the factor graph below:

Where `s = Cat(D)` (so `s` is sampled from a vector `D`, which is a standard simplex; in the above graph it represents prior beliefs about the value of `s`), and `o = As` (where `A` is a matrix mapping latent values of `s` to observation space `o`). `o` and `s` will be one-hot vectors to represent discrete states and observations.
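
As a quick sanity check on the setup (a minimal sketch, separate from the Pyro code below): since `s` is one-hot, marginalizing it over the prior gives the prior predictive over `o` as just `A @ D`.

```
import torch

# Likelihood matrix: column j is p(o | s = j)
A = torch.tensor([[0.9, 0.1],
                  [0.1, 0.9]])
# Prior over s (a point on the standard simplex)
D = torch.tensor([0.5, 0.5])

# E[o] = E[A s] = A E[s] = A D, since s is one-hot with s ~ Cat(D)
print(A @ D)  # tensor([0.5000, 0.5000])
```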

Below is my attempt to create this in Pyro with a simple example, using two possible values of `o` and `s` (which means `A` is 2x2). Of course, in this simple model `s` could be solved analytically, but I wanted to experiment to see if I could get it working in Pyro. My confusion is with regard to constructing the guide in this case. I have tried it with and without the simplex constraint, but adding the simplex constraint to the `sd` parameter doesn't seem to work (the final value for `sd` comes out as `[1, 1]`).

Even without the simplex constraint (using `dist.constraints.positive` in the example below), `sd` approaches the correct value (which, analytically, I believe should be `[.96, .04]` for the given values of `D`, `A`, and `ob`, reflecting the posterior beliefs about the value of `s`), but it is still quite inconsistent across runs, and even the final value is not as close as I would like. Furthermore, the plot of losses fluctuates quite a bit, which I believe is indicative of the underlying issue.

My hunch is that it has something to do with my choice of distribution in my `guide`, but I'm not sure exactly what to use otherwise. Any help is appreciated!

```
import pyro
import torch
from torch import tensor
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import matplotlib.pyplot as plt
from torch.nn import Softmax

# Likelihood matrix mapping latent s to observation space o
A = tensor([
    [0.9, 0.1],
    [0.1, 0.9]
])
# Prior over s (a point on the simplex)
D = tensor([
    [0.5],
    [0.5]
])
Ns = D.shape[0]

def one_hot(num, sz):
    # Column one-hot vector of size sz with a 1 at index num
    mat = torch.zeros([1, sz])
    mat[0][num] = 1.
    return mat.T

def model(o):
    s = pyro.sample("s", dist.Categorical(D.T))
    s = one_hot(s, Ns)
    As = torch.matmul(A, s)
    pyro.sample("o", dist.Categorical(As.T), obs=o)

def guide(o):
    # Variational parameter over s; positive constraint (simplex also tried)
    sd = pyro.param("sd", tensor([
        [0.5],
        [0.5]
    ]), constraint=dist.constraints.positive)
    s = pyro.sample("s", dist.Categorical(sd.T))

adam = Adam({"lr": 0.1, "betas": (0.90, 0.999)})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

param_vals = []
losses = []
ob = tensor([0])
for _ in range(1000):
    losses.append(svi.step(ob))
    param_vals.append({k: pyro.param(k) for k in ["sd"]})

plt.plot(losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss")
plt.show()

fm = Softmax(dim=0)
print(fm(param_vals[-1]["sd"]))
```