If your predictions are float-valued, there is almost zero probability of ever seeing the same value twice, so the mode of the raw samples won’t be useful to you. You probably need to trace your model and look at the actual likelihood scores.
I’ll note that attempting to do this the naive way, using Predictive, doesn’t appear to work. There are some ongoing issues with Predictive that might be sorted eventually; it’s an experimental class anyway. You can “do this yourself”:
# example model
import pyro
import pyro.distributions as dist

def model(data, size=1, verbose=False):
    loc = pyro.sample('loc', dist.Normal(0.0, 1.0))
    scale = pyro.sample('scale', dist.LogNormal(0.0, 1.0))
    if verbose:
        print(f"loc = {loc}, scale = {scale}")
    with pyro.plate('plate', size=size):
        obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data)
    return obs

# with data=None the 'obs' site is sampled, so this generates synthetic data
data = model(None, size=10, verbose=True)
output$ loc = 1.2769572734832764, scale = 4.934558391571045
nuts = pyro.infer.NUTS(model)
mcmc = pyro.infer.MCMC(nuts, 1000)
mcmc.run(data, size=10)
samples = mcmc.get_samples()
# we actually did inference...
for site, values in samples.items():
    print(f"{site}: mean = {values.detach().mean()}, std = {values.detach().std()}")
output$
loc: mean = 0.8348490595817566, std = 0.908659040927887
scale: mean = 6.1217875480651855, std = 1.4680243730545044
predictive = pyro.infer.Predictive(model, posterior_samples=samples)
# one would expect this to work...
traced_predictive = pyro.poutine.trace(predictive).get_trace(data, size=10)
output$
<very long traceback>
RuntimeError: Multiple sample sites named 'loc'
Trace Shapes:
Param Sites:
Sample Sites:
Trace Shapes:
Param Sites:
Sample Sites:
_num_predictive_samples dist |
value 1000 |
<more dimensions>
# ouch -- dimensionality error inside Predictive
# do it yourself using condition + trace
# this is what Predictive does anyway, but it tries to vectorize computations
n_samples = samples['loc'].shape[0]  # number of posterior draws
for n in range(n_samples):
    # fix the latent sites to the n-th posterior draw
    conditioned_model = pyro.poutine.condition(
        model,
        data={
            "loc": samples['loc'][n],
            "scale": samples['scale'][n],
        }
    )
    tr_cond = pyro.poutine.trace(conditioned_model).get_trace(data, size=10)
    tr_cond.compute_log_prob()
    # get probs out of trace ds and store however you want...
    < your code here ... >
This is kind of slow, but it does what you want. You can also do it the simpler but “inexact” way: get your samples from Predictive, make a histogram / KDE, and just get the mode that way. If I were trying to do this for some reason, that’s what I’d do, but that’s just opinion.
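For what it’s worth, here is a rough sketch of that “inexact” route, written with only the standard library so it runs without Pyro. The helper name kde_mode, the grid size, and the normal-reference bandwidth rule are all my own assumptions, not anything Pyro provides, and the draws list is a synthetic stand-in for whatever you get by flattening the Predictive output for your site of interest.

```python
import math
import random


def kde_mode(samples, grid_size=512, bandwidth=None):
    """Estimate the mode of 1-D samples as the argmax of a Gaussian KDE on a grid.

    Hypothetical helper for illustration; not part of Pyro.
    """
    n = len(samples)
    if n < 2:
        raise ValueError("need at least two samples")
    lo, hi = min(samples), max(samples)
    if lo == hi:
        # All draws identical: the mode is that value.
        return lo
    mean = sum(samples) / n
    std = math.sqrt(sum((x - mean) ** 2 for x in samples) / (n - 1))
    if bandwidth is None:
        # Normal-reference rule of thumb (an assumption; tune for your data).
        bandwidth = 1.06 * std * n ** (-1 / 5)
    step = (hi - lo) / (grid_size - 1)
    best_x, best_density = lo, -1.0
    for i in range(grid_size):
        x = lo + i * step
        # Unnormalized density is enough, since we only need the argmax.
        density = sum(math.exp(-0.5 * ((x - s) / bandwidth) ** 2) for s in samples)
        if density > best_density:
            best_x, best_density = x, density
    return best_x


random.seed(0)
# Synthetic stand-in for predictive draws of a single float-valued site.
draws = [random.gauss(2.0, 1.5) for _ in range(2000)]
mode_estimate = kde_mode(draws)
print(mode_estimate)
```

The answer depends on the bandwidth, which is why I call this inexact: a wider bandwidth smooths away nearby peaks, a narrower one chases noise. For a multi-dimensional obs you would want to do this per coordinate or use a proper multivariate KDE.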