Efficient computation of the "top-k" assignments of a categorical sample site

I want to compute the top-k for my model. In particular, I have an inference model, say:

def predict(obs, cluster_centers, cluster_stds):
    """Sample a class index, then condition on ``obs`` under that class's Normal.

    Returns the sampled class index ``cls``.
    """
    # prior_weights is not a parameter; it must be defined elsewhere.
    cls = pyro.sample("class_idx", Categorical(prior_weights))
    # obs= is a keyword of pyro.sample (it marks the site as observed),
    # not an argument of the Normal distribution.
    pyro.sample("obs", Normal(cluster_centers[cls], cluster_stds[cls]), obs=obs)
    return cls

One can trivially obtain the most likely assignment (the MAP estimate) via Viterbi, using infer_discrete(predict, temperature=0, first_available_dim=-1), in O(TN^2). Suppose, however, that we want to get the top k most likely class assignments. In Pyro, one could sample m times from infer_discrete(predict, temperature=1, first_available_dim=-1) to get an estimate of the different class likelihoods in O(mTN^2). Alternatively, one could condition Viterbi on all possible observations and extract log_prob_sum from the trace in (at least) O(kTN^2).

Is there any way to directly extract the smoothed probabilities from the forward-backward algorithm in Pyro without sampling from it (ideally in O(TN^2))? Thanks so much!

Hi @npw, two options are to

  1. Try to use TraceEnum_ELBO.compute_marginals()
  2. Train a guide to predict the category probabilities, as described in the Gaussian Mixture Model tutorial.
1 Like

Thanks so much for your response, @fritzo. Oddly, I seem to be getting a KeyError in some internal code when calling compute_marginals(). Any idea whether this is expected?

def predict(obs, cluster_centers, cluster_stds):
    """Per-observation mixture model: sample a class for each element of ``obs``.

    Inside the ``num_obs`` plate, each observation gets its own class index,
    and the observation is conditioned on that class's Normal component.
    Returns the sampled class indices ``cls``.
    """
    with pyro.plate("num_obs", len(obs)):
        # prior_weights is not a parameter; it must be defined elsewhere.
        cls = pyro.sample("class_idx", Categorical(prior_weights))
        # obs= belongs to pyro.sample, not to the Normal distribution.
        pyro.sample("obs", Normal(cluster_centers[cls], cluster_stds[cls]), obs=obs)
    return cls
# this works (loss takes the model, then the guide, then the model/guide args)
losser.loss(predict, guide_for_inference, x, **params)
# this throws an exception
losser.compute_marginals(predict, guide_for_inference, x, **params)

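It throws the following exception (this traceback was captured from the analogous sample_posterior call on my full model, multi_choice_single_sign_model):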


KeyError                                  Traceback (most recent call last)
<ipython-input-15-ae6e2d72ce18> in <module>
     14     losser.loss(multi_choice_single_sign_model, guide_for_inference,x.unsqueeze(0), **params, **model_hyperparameters_int)
     15     # this throws an exception
---> 16     losser.sample_posterior(multi_choice_single_sign_model, guide_for_inference,x.unsqueeze(0), **params, **model_hyperparameters_int)

/anaconda3/lib/python3.8/site-packages/pyro/infer/traceenum_elbo.py in sample_posterior(self, model, guide, *args, **kwargs)
    455         # TODO replace BackwardSample with torch_sample backend to ubersum
    456         with BackwardSampleMessenger(model_trace, guide_trace):
--> 457             return poutine.replay(model, trace=guide_trace)(*args, **kwargs)
    458 
    459 

/anaconda3/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/anaconda3/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

<ipython-input-2-09a2e5e01f78> in multi_choice_single_sign_model(data, transition_matrix, emission_logits, process_hands, hand_strategy, **kwargs)
     77 
     78     with pyro.plate("datapoints", num_datapoints):
---> 79         sign_idx = pyro.sample("sign",
     80                                distributions.Categorical(maybe_cuda(torch.tensor([1 / num_signs for _ in range(num_signs)]))))
     81         state = maybe_cuda(torch.zeros(num_datapoints).long())

/anaconda3/lib/python3.8/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
    111             msg["is_observed"] = True
    112         # apply the stack and return its return value
--> 113         apply_stack(msg)
    114         return msg["value"]
    115 

/anaconda3/lib/python3.8/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
    191         pointer = pointer + 1
    192 
--> 193         frame._process_message(msg)
    194 
    195         if msg["stop"]:

/anaconda3/lib/python3.8/site-packages/pyro/poutine/messenger.py in _process_message(self, msg)
    137         method = getattr(self, "_pyro_{}".format(msg["type"]), None)
    138         if method is not None:
--> 139             return method(msg)
    140         return None
    141 

/anaconda3/lib/python3.8/site-packages/pyro/infer/traceenum_elbo.py in _pyro_sample(self, msg)
    245         with shared_intermediates(self.cache):
    246             ordinal = _find_ordinal(self.enum_trace, msg)
--> 247             logits = contract_to_tensor(self.log_factors, self.sum_dims,
    248                                         target_ordinal=ordinal, target_dims={enum_symbol},
    249                                         cache=self.cache)

/anaconda3/lib/python3.8/site-packages/pyro/ops/contract.py in contract_to_tensor(tensor_tree, sum_dims, target_ordinal, target_dims, cache, ring)
    259     term = ring.sumproduct(contracted_terms, set())
    260     assert sum_dims.intersection(term._pyro_dims) <= target_dims
--> 261     return ring.broadcast(term, target_ordinal)
    262 
    263 

/anaconda3/lib/python3.8/site-packages/pyro/ops/rings.py in broadcast(self, term, ordinal)
     78             else:
     79                 missing_shape = tuple(self._dim_to_size[dim] for dim in missing_dims)
---> 80                 term = term.expand(missing_shape + term.shape)
     81                 dims = missing_dims + dims
     82                 self._cache[key] = term

/anaconda3/lib/python3.8/site-packages/pyro/ops/rings.py in <genexpr>(.0)
     78             else:
     79                 missing_shape = tuple(self._dim_to_size[dim] for dim in missing_dims)
---> 80                 term = term.expand(missing_shape + term.shape)
     81                 dims = missing_dims + dims
     82                 self._cache[key] = term

KeyError: 'a'

This only occurs when I’m trying to infer on a single datapoint (i.e. with pyro.plate("num_obs", 1)), so my current workaround is to drop the plate when there’s only one observation.

Hi @npw that looks like a real bug, feel free to file an issue :grinning_face_with_smiling_eyes: