Getting the predictions with the highest probability

Hi,

I am learning to use Predictive. I would like to get the predictions with the highest probability.

After running MCMC inference with my model (y = model(x)) on x_train and y_train, like this:

mcmc.run(x_train, y_train)
posterior_samples = mcmc.get_samples()

I use the posterior samples like this:
p = Predictive(model, posterior_samples)

Then I sample from the posterior predictive distribution, using some test data x_test as input to the model:

predictions = p(x_test)['obs']

This gives me a tensor of shape [nb_posterior_samples, nb_x_test_inputs]. As I understand it, the first dimension indexes the posterior samples, so for each test input I get a set of sampled outputs of my model.

Finally, I compute the min, max and mean of the predictions for each test input:

max_predictions = torch.max(predictions, dim=0).values
min_predictions = torch.min(predictions, dim=0).values
mean_predictions = predictions.mean(dim=0)

I would also like to get the prediction with the highest probability for each test input. I have tried to use the mode:
mode_predictions = torch.mode(predictions, 0).values

but it does not yield the expected results.

What is the appropriate way of doing this?

Thanks

If your predictions are float-valued, there is essentially zero probability of ever seeing the same value twice, so the mode won’t be useful for you. You probably need to trace your model and look at the actual likelihood scores (log-probabilities) instead.
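A quick way to see this, as a minimal NumPy sketch: the draws below are a hypothetical stand-in for the samples of one test input (the location, scale and sample count are arbitrary choices). Every value occurs exactly once, so any "mode" picked from them is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical stand-in for the samples of one test input: 1000 continuous draws
draws = rng.normal(loc=2.0, scale=1.0, size=1000)

# count how often each distinct value occurs
values, counts = np.unique(draws, return_counts=True)
print(len(values), counts.max())  # 1000 1: no value repeats
```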

I’ll note that attempting to do this the naive way, using Predictive, doesn’t appear to work. There are some ongoing issues with Predictive that might be sorted eventually; it’s an experimental class anyway. You can “do this yourself”:

import pyro
import pyro.distributions as dist

# example model
def model(data, size=1, verbose=False):
    loc = pyro.sample('loc', dist.Normal(0.0, 1.0))
    scale = pyro.sample('scale', dist.LogNormal(0.0, 1.0))
    if verbose:
        print(f"loc = {loc}, scale = {scale}")
    with pyro.plate('plate', size=size):
        obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data)
    return obs

data = model(None, size=10, verbose=True)
output$ loc = 1.2769572734832764, scale = 4.934558391571045

nuts = pyro.infer.NUTS(model)
mcmc = pyro.infer.MCMC(nuts, 1000)
mcmc.run(data, size=10)

samples = mcmc.get_samples()
# we actually did inference...
for site, values in samples.items():
    print(f"{site}: mean = {values.detach().mean()}, std = {values.detach().std()}")

output$ 
loc: mean = 0.8348490595817566, std = 0.908659040927887
scale: mean = 6.1217875480651855, std = 1.4680243730545044

predictive = pyro.infer.Predictive(model, posterior_samples=samples)
# one would expect this to work...
traced_predictive = pyro.poutine.trace(predictive).get_trace(data, size=10)
output$ 
<very long traceback> 
RuntimeError: Multiple sample sites named 'loc'
Trace Shapes:
 Param Sites:
Sample Sites:
               Trace Shapes:       
                Param Sites:       
               Sample Sites:       
_num_predictive_samples dist      |
                       value 1000 |
<more dimensions>

# ouch -- Predictive fails when wrapped in an outer trace
# do it yourself using condition + trace
# this is what Predictive does anyway, but it tries to vectorize computations
n_samples = samples['loc'].shape[0]
obs_log_probs = []
for n in range(n_samples):
    conditioned_model = pyro.poutine.condition(
        model,
        data={
            "loc": samples['loc'][n],
            "scale": samples['scale'][n],
        }
    )
    tr_cond = pyro.poutine.trace(conditioned_model).get_trace(data, size=10)
    tr_cond.compute_log_prob()
    # get probs out of the trace data structure and store however you want, e.g.
    # the per-point log-likelihoods of 'obs' under this posterior sample:
    obs_log_probs.append(tr_cond.nodes['obs']['log_prob'])

This is kind of slow, but it does what you want. You can also do it the simpler but “inexact” way: get your samples from Predictive, build a histogram or KDE for each test input, and take the mode from that. If I were trying to do this for some reason, that’s what I’d do, but that’s just my opinion.
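A minimal sketch of the histogram / KDE route, assuming the predictions have been converted to a NumPy array of shape [nb_posterior_samples, nb_x_test_inputs] (e.g. with predictions.detach().numpy()). The function names, the default bin count, the grid size and the Silverman rule-of-thumb bandwidth are illustrative choices of mine, not part of Pyro.

```python
import numpy as np


def histogram_mode(samples, bins=50):
    """Approximate mode per column: centre of the fullest histogram bin.

    samples: array of shape [num_posterior_samples, num_test_inputs].
    """
    samples = np.asarray(samples, dtype=float)
    modes = np.empty(samples.shape[1])
    for j in range(samples.shape[1]):
        counts, edges = np.histogram(samples[:, j], bins=bins)
        k = np.argmax(counts)
        modes[j] = 0.5 * (edges[k] + edges[k + 1])
    return modes


def kde_mode(samples, grid_size=512, bandwidth=None):
    """Approximate mode per column: argmax of a Gaussian KDE evaluated on a grid."""
    samples = np.asarray(samples, dtype=float)
    n = samples.shape[0]
    modes = np.empty(samples.shape[1])
    for j in range(samples.shape[1]):
        x = samples[:, j]
        # Silverman's rule-of-thumb bandwidth unless one is supplied
        h = bandwidth if bandwidth is not None else 1.06 * x.std() * n ** (-0.2)
        if h <= 0:  # all samples identical: that value is the mode
            modes[j] = x[0]
            continue
        grid = np.linspace(x.min(), x.max(), grid_size)
        # unnormalised density: sum of Gaussian bumps centred on the samples
        density = np.exp(-0.5 * ((grid[:, None] - x[None, :]) / h) ** 2).sum(axis=1)
        modes[j] = grid[np.argmax(density)]
    return modes
```

Usage would be something like mode_predictions = kde_mode(predictions.detach().numpy()). The histogram version is cheaper but its answer moves with the bin width; the KDE is smoother, at the cost of evaluating every sample on every grid point.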

Hi,

Thanks for your detailed answer. I think I will go with the histogram approach to get the mode, even if it is inexact.

Separately, for plotting purposes, and since the data are unimodal, I have used ArviZ with a very small HDI, like this:

az.plot_hdi(xdata, pyro_data.posterior_predictive['obs'], hdi_prob=0.01)
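For a unimodal distribution a very narrow HDI sits around the mode, so its midpoint can serve as a point estimate. A minimal NumPy sketch of that idea, assuming the same [nb_posterior_samples, nb_x_test_inputs] array layout: sort each column and find the narrowest window containing a fraction prob of the samples. This follows the usual sorted-samples way of computing an HDI (my understanding is that ArviZ does essentially this, but treat that as an assumption); hdi_midpoint itself is a hypothetical helper, not an ArviZ function.

```python
import numpy as np


def hdi_midpoint(samples, prob=0.01):
    """Midpoint of the narrowest interval holding a fraction `prob` of each column.

    samples: array of shape [num_posterior_samples, num_test_inputs].
    """
    x = np.sort(np.asarray(samples, dtype=float), axis=0)
    n = x.shape[0]
    # number of sorted steps spanned by the window; at least 1, at most n - 1
    k = min(max(int(np.floor(prob * n)), 1), n - 1)
    # width of every window [x[i], x[i + k]], per column
    widths = x[k:] - x[:-k]
    i = np.argmin(widths, axis=0)
    cols = np.arange(x.shape[1])
    return 0.5 * (x[i, cols] + x[i + k, cols])
```

A very small prob makes the window hold only a handful of samples, so the estimate gets noisy; with a few thousand posterior samples a somewhat larger prob may be steadier.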

A more accurate way to estimate a MAP point would be to run SVI separately with an AutoDelta guide. Roughly:

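from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta

model = ...  # your model
guide = AutoDelta(model)
optim = ...  # a Pyro optimizer, e.g. pyro.optim.Adam
svi = SVI(model, guide, optim, Trace_ELBO())
num_steps = ...
for step in range(num_steps):
    svi.step(x_train, y_train)  # pass the same arguments your model takes

map_estimate = guide(x_train, y_train)  # a dict from sample site name to value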
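To make concrete what the AutoDelta fit is maximising: with a Delta guide the ELBO reduces, as I understand it, to the log joint density log p(data, latents) evaluated at the point estimate. For the toy loc/scale model earlier in this thread that objective is small enough to maximise by hand. The NumPy sketch below does so by coordinate ascent; it only illustrates the objective under that assumption and is not what Pyro does internally (SVI takes gradient steps with the optimizer you pass it). map_loc_scale is a hypothetical name.

```python
import numpy as np


def map_loc_scale(x, num_iters=200, tol=1e-10):
    """Coordinate-ascent MAP for the toy model: loc ~ Normal(0, 1),
    scale ~ LogNormal(0, 1), x_i ~ Normal(loc, scale).

    Maximises log p(x, loc, scale) with densities taken w.r.t. loc and scale
    themselves (no change-of-variables Jacobian).
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    loc, u = 0.0, 0.0  # u = log(scale), used only to keep scale positive
    for _ in range(num_iters):
        # setting d/d loc = 0 gives loc in closed form given scale
        new_loc = x.sum() / (np.exp(2 * u) + n)
        # setting d/d u = 0 gives S * exp(-2u) - u - (n + 1) = 0, strictly
        # decreasing in u, so solve it with Newton's method
        S = np.sum((x - new_loc) ** 2)
        new_u = u
        for _newton in range(50):
            g = S * np.exp(-2 * new_u) - new_u - (n + 1)
            dg = -2 * S * np.exp(-2 * new_u) - 1
            step = g / dg
            new_u -= step
            if abs(step) < tol:
                break
        converged = abs(new_loc - loc) < tol and abs(new_u - u) < tol
        loc, u = new_loc, new_u
        if converged:
            break
    return loc, np.exp(u)
```

Each coordinate update exactly maximises a concave one-dimensional slice of the log joint, so the objective never decreases between iterations.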

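If you have pyro.deterministic sites in the model, you can additionally record them by tracing the model conditioned on the MAP estimate; each site’s value is then in trace.nodes[name]["value"]:

from pyro import poutine
trace = poutine.trace(poutine.condition(model, data=map_estimate)).get_trace(x_train, y_train)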