We are trying to do Bayesian rank aggregation based on this paper.
Here they have:
- R, the observed rankings, of shape N x B x T, where N is the number of data points, B is the number of base rankers and T is the number of elements in each list. The likelihood is given as an exponential function whose input requires R, alpha and rho.
- rho, the consensus ranking, of shape N x T, where N is the number of data points and T is the number of elements in each list. Its prior is a uniform Categorical over all permutations of `list(range(T))`.
- alpha, the scale parameters. Their prior is a truncated exponential.
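The post does not reproduce the paper's formula, so the following is only a sketch of one common form of such an exponential likelihood (a Mallows-type model). It assumes a distance d between rankings, here Spearman's footrule, and that the B base rankings of a data point are conditionally independent given alpha and rho:

$$
P(R_n \mid \alpha_n, \rho_n) = \prod_{b=1}^{B} \frac{1}{Z_T(\alpha_n)} \exp\!\left(-\frac{\alpha_n}{T}\, d(R_{nb}, \rho_n)\right),
\qquad
Z_T(\alpha) = \sum_{\sigma \in \mathcal{P}_T} \exp\!\left(-\frac{\alpha}{T}\, d(\sigma, \rho_n)\right)
$$

$$
d(\sigma, \rho) = \sum_{t=1}^{T} \lvert \sigma_t - \rho_t \rvert,
\qquad
\rho_n \sim \mathrm{Uniform}(\mathcal{P}_T)
$$

Here $\mathcal{P}_T$ is the set of all $T!$ permutations. The footrule is right-invariant, so $Z_T$ depends only on $\alpha$ and $T$, not on $\rho_n$. Written this way, it is explicit that evaluating the likelihood of R needs both alpha and rho.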
We are trying to aggregate the rankings given by the B base rankers for each of the N data points by inferring the unobserved consensus ranking rho. To implement each of these distributions, we found that we would have to write custom distributions inheriting from the `Distribution` class in torch, like:
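```python
import torch
from torch.distributions.distribution import Distribution

class Custom(Distribution):
    def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size()):
        super().__init__(batch_shape, event_shape)
    def sample(self, sample_shape=torch.Size()):
        raise NotImplementedError  # draw a sample here
    def log_prob(self, x):
        raise NotImplementedError  # return log p(x) here
```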
Here we need to implement the `sample` and `log_prob` methods.
The problem now is that, to implement the likelihood for R for example, we need to be able to sample R from a Categorical distribution.
This is because, as we understand it, under the hood `pyro.sample` does:
```python
x = dist.sample()
prob = dist.log_prob(x)
```
Two problems arise from this:
- we don't know how to sample R, which can take factorial(T)^B different values for each data point (one permutation of the T elements per base ranker);
- we can't calculate the `log_prob` without having alpha and rho as inputs.
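Regarding the first problem, here is a minimal sketch, assuming (as in Mallows-type models) that the B base rankings are conditionally independent given alpha and rho, and that the distance is the footrule. The likelihood then factorizes over base rankers, so R can be sampled with one Categorical over the T! permutations per base ranker instead of one Categorical over every joint value of R. `sample_R` is a hypothetical helper, and full enumeration is only practical for small T:

```python
from itertools import permutations

import torch


def sample_R(alpha, rho, num_rankers):
    """Hypothetical helper: draw R of shape (N, B, T) from an assumed Mallows-type model.

    alpha: (N,) positive scales; rho: (N, T) consensus rankings (permutations of 0..T-1).
    """
    T = rho.shape[-1]
    perms = torch.tensor(list(permutations(range(T))))  # (T!, T): every possible single ranking
    # Footrule distance from every permutation to each data point's rho: (N, T!)
    d = (perms.unsqueeze(0) - rho.unsqueeze(1)).abs().sum(-1)
    logits = -alpha.unsqueeze(-1) / T * d  # unnormalized log-probabilities, (N, T!)
    # One T!-way Categorical per data point; draw B independent indices from each
    idx = torch.distributions.Categorical(logits=logits).sample((num_rankers,))  # (B, N)
    return perms[idx.T]  # (N, B, T)


torch.manual_seed(0)
alpha = torch.tensor([1.0, 50.0])
rho = torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]])
R = sample_R(alpha, rho, num_rankers=5)
print(R.shape)  # torch.Size([2, 5, 4])
```

With a large alpha (the second data point), almost all sampled base rankings coincide with rho; with a small alpha they spread over many permutations.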
This is an example of how I imagine our code would look:
```python
import math
from itertools import permutations

import torch
import pyro
import pyro.distributions as dist

def model(R):
    N, B, T = R.shape
    Pn = torch.tensor(list(permutations(range(T))))  # all T! permutations, shape (T!, T)
    lambda_ = 0.1
    tfac = math.factorial(T)
    probs = torch.ones(tfac) / tfac  # uniform prior over the T! permutations
    with pyro.plate("query", N):
        alpha = pyro.sample("alpha", dist.Exponential(lambda_))  # scale
        rho_idx = pyro.sample("rho", dist.Categorical(probs=probs))  # index of rho
        rho = Pn[rho_idx]  # consensus ranking, shape (N, T)
        # `likelihood` is the custom distribution we still need to write
        R_obs = pyro.sample("R", likelihood(R, alpha, rho), obs=R)
    return R
```
To sum up: how do we sample R from this distribution, and how do we pass these extra arguments (alpha and rho) to the `log_prob` method?