Categorical Embedding/Mapping

Hello,

I am relatively new to this language and concept, so forgive me for my naivete.

I am attempting to use the inference framework provided by Pyro, and I am finding PyTorch's parameterized distributions relatively intuitive. The model/guide framework runs without errors so far, but I suspect there is a “kink”, so to speak, in the loss back-propagation, because my model does not converge to the expected output.

In my model/guide I have a mapping a -> a'. In effect, this map can be represented as a square 2D matrix with a in the rows and a' in the columns. Each entry is the probability of a being observed as a', so the matrix encodes P(a'|a) and each row sums to 1. I am hoping to use this matrix and its easily computable cousin P(a|a') (obtained via Bayes' rule, given a prior over a) to infer the input of my model from an output measurement further down the line.
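The inversion from P(a'|a) to P(a|a') can be sketched with Bayes' rule, P(a|a') = P(a'|a) P(a) / sum over a of P(a'|a) P(a). This is a minimal sketch: the 3x3 matrix, the uniform prior, and the helper name `invert` are hypothetical stand-ins, not values from the original post.

```python
import torch

# Hypothetical 3-state example: row i is P(a' | a = i), so every row sums to 1.
p_obs_given_true = torch.tensor([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.2, 0.7],
])
# Assumed prior P(a) over the true values (uniform here, purely for illustration).
prior = torch.full((3,), 1.0 / 3)

def invert(p_obs_given_true, prior):
    """Hypothetical helper: return P(a | a') via Bayes' rule, one row per value of a'."""
    joint = prior[:, None] * p_obs_given_true   # P(a, a'), shape (A, A')
    evidence = joint.sum(dim=0, keepdim=True)   # P(a'), shape (1, A')
    posterior = joint / evidence                # P(a | a'), each column sums to 1
    return posterior.T                          # transpose so rows are indexed by a'

p_true_given_obs = invert(p_obs_given_true, prior)
print(p_true_given_obs[1])  # posterior over a after observing a' = 1
```

Note that P(a|a') depends on the prior: with a non-uniform P(a) the same confusion matrix gives a different posterior, so "easily computable" assumes you have a sensible prior.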

In my search, I have come across PyTorch's TransformedDistribution, but I am unsure whether it is useful in this situation. I have also looked into Categorical distributions.
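As far as I know, TransformedDistribution wraps a base distribution in deterministic, invertible transforms (for example, a log-normal as an exp-transformed normal), so it does not describe a noisy a -> a' channel. A row-stochastic matrix maps more directly onto a batch of Categorical distributions, one per row. A minimal sketch with a hypothetical 3x3 matrix `trans` (pyro.distributions.Categorical wraps the same torch class and should behave the same way):

```python
import torch
from torch.distributions import Categorical

# Hypothetical row-stochastic matrix: row i is P(a' | a = i).
trans = torch.tensor([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.2, 0.7],
])

# A 2D probs tensor yields a batch of Categoricals, one per row (one per value of a).
d = Categorical(probs=trans)
print(d.batch_shape, d.event_shape)  # torch.Size([3]) torch.Size([])

# To push a batch of true values a through the channel, index the rows first.
a = torch.tensor([0, 2, 2, 1])                      # hypothetical inputs
a_prime = Categorical(probs=trans[a]).sample()      # one noisy observation per input
log_p = Categorical(probs=trans[a]).log_prob(a_prime)  # log P(a' | a) for each pair
```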

Any tips/tricks/help would be very much appreciated!

Hi @getzinmw, this sounds like a Categorical distribution to me. I would start with a model like

import torch
import pyro.distributions as dist
import pyro
from torch.distributions import constraints

# Set these to the actual cardinalities of a and a':
A_SIZE = 100
A_PRIME_SIZE = 50

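def model(a_input, a_prime_output):
    # a_input, a_prime_output: LongTensors of indices, one entry per observation
    assert len(a_input) == len(a_prime_output)
    # Row i of trans_matrix is P(a' | a = i); the simplex constraint keeps each row
    # nonnegative and summing to 1. Every row starts out uniform.
    trans_matrix = pyro.param("trans_matrix",
                              torch.full((A_SIZE, A_PRIME_SIZE), 1.0 / A_PRIME_SIZE),
                              constraint=constraints.simplex)
    with pyro.plate("data", len(a_input)):
        p = trans_matrix[a_input]  # shape (N, A_PRIME_SIZE): one row per observation
        pyro.sample("obs", dist.Categorical(probs=p), obs=a_prime_output)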

Since there are no latent variables, you can train it with an empty guide; this amounts to maximum-likelihood estimation of trans_matrix:

from pyro.infer import Trace_ELBO, SVI
from pyro.optim import Adam

def guide(*args):
    pass

pyro.clear_param_store()  # start from a fresh trans_matrix
svi = SVI(model, guide, Adam({"lr": 0.1}), loss=Trace_ELBO())
for step in range(1000):
    svi.step(A_INPUT, A_PRIME_OUTPUT)  # feed in your data here (LongTensors of equal length)