Usage of deterministic

Hi, I am new to Pyro, and I am currently trying to work through the examples in this repo using Pyro.

While working on the example “An algorithm for human deceit”, I set up the following model:

```python
import pyro, torch, pyro.distributions as dist
N = 100
X = 35
def model(X):
    p = pyro.sample('freq_cheating', dist.Uniform(0, 1))
    with pyro.plate("truth_ans", N):
        true_answers = pyro.sample('truths', dist.Bernoulli(p))
        first_coin_flips = pyro.sample('first_flips', dist.Bernoulli(0.5))
        second_coin_flips = pyro.sample("second_flips", dist.Bernoulli(0.5))
        val = pyro.deterministic('val', first_coin_flips*true_answers + (1 - first_coin_flips)*second_coin_flips)
    observed_proportion = pyro.deterministic("observed_proportion", torch.sum(val) / N)
    observations = pyro.sample("obs", dist.Binomial(total_count=N, probs=observed_proportion), obs=X)
pyro.render_model(model, model_args=(torch.tensor(X), ))
```

This gives a graph that looks weird to me, because val and observed_proportion appear to be disconnected from the other variables.

I then ran the following sampling code:

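```python
import matplotlib.pyplot as plt
import torch
from IPython.core.pylabtools import figsize
from pyro.infer import MCMC, NUTS

kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True, max_tree_depth=3)
posterior = MCMC(kernel, num_samples=25000, warmup_steps=15000)
posterior.run(torch.tensor(X))  # the observed count is passed as the model argument
hmc_samples = {k: v.detach().cpu().numpy() for k, v in posterior.get_samples().items()}
figsize(12.5, 3)
# MCMC already drops the warmup samples; this slice discards 15000 more
p_trace = hmc_samples["freq_cheating"][15000:]
plt.hist(p_trace, histtype="stepfilled", density=True, alpha=0.85, bins=30,
         label="posterior distribution", color="#348ABD")
plt.vlines([.05, .35], [0, 0], [5, 5], alpha=0.3)
plt.xlim(0, 1)
```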

I found that the posterior distribution of freq_cheating is still uniform, meaning the model does not seem to learn anything.
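
For comparison, a quick grid approximation of the posterior under a simplified model suggests that 35 yes answers out of 100 should move freq_cheating well away from uniform. This model is an assumption swapped in for illustration, not the model above: it treats each answer as independently yes with probability 0.5·p + 0.25 (what remains after summing out the true answer and both coin flips), takes X = 35 as the count of yes answers, and keeps the Uniform(0, 1) prior.

```python
from math import comb

N, X = 100, 35
grid = [i / 1000 for i in range(1001)]  # candidate values of freq_cheating

def likelihood(p):
    # P(X yes answers | p) when each answer is yes with probability 0.5 * p + 0.25
    q = 0.5 * p + 0.25
    return comb(N, X) * q**X * (1 - q) ** (N - X)

# With a Uniform(0, 1) prior the posterior is proportional to the likelihood alone.
unnorm = [likelihood(p) for p in grid]
total = sum(unnorm)
posterior = [u / total for u in unnorm]

mode = grid[max(range(len(grid)), key=posterior.__getitem__)]
mass_low = sum(w for p, w in zip(grid, posterior) if p <= 0.5)
print(mode)      # 0.2, the p at which 0.5 * p + 0.25 equals X / N = 0.35
print(mass_low)  # almost all of the posterior mass sits at or below 0.5
```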

I am wondering whether I did something wrong with pyro.deterministic that somehow broke my computation graph. Does anyone have any idea?

Thank you!

NUTS does not work directly on models with discrete latent variables. Pyro first sums out the discrete latent variables and then runs NUTS on the remaining continuous latent variables. To do this, you need to write your model in a parallelizable fashion; for example, I think you may need torch.sum(val, dim=-1). See e.g. here for details.
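
A small NumPy sketch of what summing out means for this model (an illustration of the idea, not Pyro's actual implementation): each binary variable gets its own leading axis of size 2, roughly as in parallel enumeration, the val expression is evaluated for all 2 × 2 × 2 combinations at once by broadcasting, and the combinations are then weighted by their probabilities and summed. The value p = 0.3 and the plate size of 3 are arbitrary assumptions. The last two prints show why a reduction over the plate needs an explicit dimension: a sum with no dimension also collapses the enumeration axes.

```python
import numpy as np

p = 0.3         # assumed value of freq_cheating, for illustration only
plate_size = 3  # hypothetical small plate, for illustration only

# Each discrete variable gets its own leading "enumeration" axis of size 2;
# the rightmost axis plays the role of the plate.
truths = np.array([0.0, 1.0]).reshape(2, 1, 1, 1)
first_flips = np.array([0.0, 1.0]).reshape(1, 2, 1, 1)
second_flips = np.array([0.0, 1.0]).reshape(1, 1, 2, 1)

# Same expression as val in the model; broadcasting covers all 8 combinations.
val = first_flips * truths + (1 - first_flips) * second_flips
val = np.broadcast_to(val, (2, 2, 2, plate_size))

# Probability of each combination: Bernoulli(p) x Bernoulli(0.5) x Bernoulli(0.5).
w_truth = np.array([1 - p, p]).reshape(2, 1, 1, 1)
w_flip = np.array([0.5, 0.5])
weights = w_truth * w_flip.reshape(1, 2, 1, 1) * w_flip.reshape(1, 1, 2, 1)

# Summing out the three enumeration axes gives P(val = 1) for each plate element.
p_yes = np.sum(weights * val, axis=(0, 1, 2))
print(p_yes)  # each entry is 0.5 * p + 0.25 (up to rounding)

# Reducing over the plate axis only keeps the enumeration axes intact...
print(np.sum(val, axis=-1).shape)  # (2, 2, 2)
# ...whereas a reduction with no axis collapses every branch into one number.
print(np.shape(np.sum(val)))  # ()
```

Each entry of p_yes is 0.5·p + 0.25: half the time the first flip reports the true answer (yes with probability p), and otherwise the second flip says yes with probability 0.5.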

You could also try the MixedHMC implementation in NumPyro instead.