Dear awesome forum,

Here is my problem: I am sampling from a BetaBinomial multiple times and would like to transform the sampled values into counts, similar to a single draw from a Multinomial. For instance:

```
total_count = 6
n_samples = 4
values = dist.BetaBinomial(torch.tensor([1.]), torch.tensor([2.]), total_count = total_count).sample(sample_shape = (n_samples,)).view(-1)
```

then transform:

```
values = [2, 2, 0, 4] ## transform this
counts = [1, 0, 2, 0, 1, 0, 0]
## of size total_count + 1, where the bin at index 0 corresponds to the count of value 0, index 1 to the count of value 1, and so on
```

This should then be used such that

```
W = pyro.sample('W', dist.TransformedDistribution(dist.BetaBinomial(torch.tensor([1.,....]), torch.tensor([2.,....]), total_count = torch.tensor([vocab_size-1])), [CountTransform(vocab_size)]), obs=docs)
```

I would like to have a transform for this, and I tried the following code, which however does not work since I don't know how to define `log_abs_det_jacobian`.

```
class CountTransform(Transform):
    """Transform a batch of sampled integer values into count (histogram) vectors.

    A batch of draws ``[2, 2, 0, 4]`` with ``total_count = 6`` becomes the
    count vector ``[1, 0, 2, 0, 1, 0, 0]``: ``total_count + 1`` bins, where
    bin ``i`` holds the number of occurrences of value ``i``.

    NOTE(review): binning loses the order of the draws, so ``_inverse`` can
    only recover the multiset of values (returned in ascending order) — the
    mapping is bijective on count vectors, not on ordered sample vectors.
    Because the underlying values are discrete, there is no change-of-variables
    volume correction, hence ``log_abs_det_jacobian`` is identically zero.
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    def __init__(self, total_count: int = 6, normalized: bool = False):
        """
        Args:
            total_count: maximum value a draw can take; the output histogram
                has ``total_count + 1`` bins (values ``0..total_count``).
            normalized: if True, L1-normalize each count vector so it sums
                to 1 (relative frequencies instead of integer counts).
        """
        super().__init__(cache_size=1)
        self.total_count = total_count
        self.normalized = normalized

    def __eq__(self, other):
        # BUG FIX: the original compared ``self.p``, an attribute that was
        # never defined (AttributeError).  Compare the parameters that
        # actually define this transform.
        return (
            type(self) is type(other)
            and self.total_count == other.total_count
            and self.normalized == other.normalized
        )

    # Defining __eq__ would otherwise implicitly set __hash__ to None;
    # keep the base class's identity-based hash.
    __hash__ = Transform.__hash__

    def _call(self, x):
        # BUG FIX: the original called ``to_batched_count_vector`` as a free
        # function (NameError at call time) and passed a ``total_count``
        # keyword that the method does not accept.  Values range over
        # ``0..total_count``, so ``total_count + 1`` bins are required
        # (this matches the worked example: 7 bins for total_count = 6).
        return self.to_batched_count_vector(
            x, vocab_size=self.total_count + 1, normalized=self.normalized
        )

    def _inverse(self, y):
        # BUG FIX: the original called ``reverse_count_vector`` as a free
        # function (NameError at call time).
        return self.reverse_count_vector(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # BUG FIX: the original returned ``x``.  The transform merely
        # rearranges discrete values, so log|det J| = 0; return one zero per
        # batch element (the event dimension is dropped, as required for a
        # transform with event_dim 1).
        return x.new_zeros(x.shape[:-1])

    def to_batched_count_vector(self, values, vocab_size=6, normalized=True):
        """Histogram each row of ``values`` into ``vocab_size`` bins.

        Args:
            values: integer-valued tensor of shape ``(n,)`` or ``(batch, n)``;
                entries must lie in ``[0, vocab_size)`` for ``scatter_add_``.
            vocab_size: number of bins in the output histogram.
            normalized: if True, return L1-normalized float frequencies
                instead of int64 counts.

        Returns:
            Tensor of shape ``(batch, vocab_size)``.
        """
        if values.dim() == 1:
            values = values.view(1, -1)
        assert values.dim() == 2
        ones = torch.ones_like(values, dtype=torch.int64)
        base = torch.zeros(values.shape[0], vocab_size, dtype=torch.int64)
        count_tensor = base.scatter_add_(1, values.long(), ones)
        if normalized:
            count_tensor = torch.nn.functional.normalize(
                count_tensor.float(), p=1, dim=1
            )
        return count_tensor

    def reverse_count_vector(self, count_vector):
        """Expand count vectors back into (sorted) value vectors.

        Args:
            count_vector: tensor of shape ``(vocab,)`` or ``(batch, vocab)``
                of non-negative integer counts.  Rows are assumed to sum to
                the same total (true for a fixed ``total_count``), so the
                expanded rows stack into a rectangular tensor.

        Returns:
            int64 tensor of shape ``(batch, total)`` where each row lists
            value ``i`` repeated ``count_vector[row, i]`` times, ascending.

        BUG FIX: the original returned a tensor only when the batch had a
        single row and a plain list of lists otherwise; a tensor is now
        always returned.
        """
        count_vector = count_vector.long()
        if count_vector.dim() == 1:
            count_vector = count_vector.view(1, -1)
        assert count_vector.dim() == 2
        bin_values = torch.arange(count_vector.shape[1])
        rows = [torch.repeat_interleave(bin_values, row) for row in count_vector]
        return torch.stack(rows)
```