[MCMC] [Discrete RV] [Parallelization] Is there anyway to stop pyro automatically vectorizing tensors

Hi guys, I was trying to adapt this program into pyro. But I got the following error from the following code.

Seems like my code is incompatible with Pyro’s auto parallelization for discrete random variables.
It’s there any way to make pyro stop automatically vectorizing tensors in my code? Or is this a bug of Pyro?

Thanks in advance.

~/miniconda3/envs/pyro-ppl-1.8.0/bin/python ~/Projects/sprinkler/scripts/training.py --lr 0.005 --num_hidden_units 64 --num_inference_samples 100 --training_batch_size 16 --n_steps 0 --guide guide_fidelia --validation_batch_size 10
Warmup:   0%|          | 0/2000 [00:00, ?it/s]Traceback (most recent call last):
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "~/Projects/sprinkler/scripts/model.py", line 13, in model
    sprinkler = P.sample("sprinkler", PD.Bernoulli(sprinkler_get_params(cloudy)),obs=observations["sprinkler"])
  File "~/Projects/sprinkler/scripts/priors.py", line 5, in sprinkler_get_params
    return 0.1 if cloudy else 0.5
RuntimeError: Boolean value of Tensor with more than one value is ambiguous

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "~/Projects/sprinkler/scripts/training.py", line 124, in <module>
    mcmc.run(observations=observations)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 563, in run
    for x, chain_id in self.sampler.run(*args, **kwargs):
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 223, in run
    for sample in _gen_samples(
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 144, in _gen_samples
    kernel.setup(warmup_steps, *args, **kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/infer/mcmc/hmc.py", line 325, in setup
    self._initialize_model_properties(args, kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/infer/mcmc/hmc.py", line 259, in _initialize_model_properties
    init_params, potential_fn, transforms, trace = initialize_model(
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 434, in initialize_model
    model_trace = prototype_model.get_trace(*model_args, **model_kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "~/miniconda3/envs/pyro-ppl-1.8.0/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "~/Projects/sprinkler/scripts/model.py", line 13, in model
    sprinkler = P.sample("sprinkler", PD.Bernoulli(sprinkler_get_params(cloudy)),obs=observations["sprinkler"])
  File "~/Projects/sprinkler/scripts/priors.py", line 5, in sprinkler_get_params
    return 0.1 if cloudy else 0.5
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
Trace Shapes:    
 Param Sites:    
Sample Sites:    
  cloudy dist   |
        value 2 |

Process finished with exit code 1
import pyro as P
import pyro.distributions as PD
import torch as T
from priors import *
from pyro.infer import MCMC, NUTS

def cloudy_get_params():
    return 0.5

def sprinkler_get_params(cloudy):
    return 0.1 if cloudy else 0.5

def rain_get_params(cloudy):
    return 0.8 if cloudy else 0.2

def wetgrass_get_params(springkler, rain):
    if springkler and rain:
        return 0.99
    elif springkler and not rain:
        return 0.9
    elif not springkler and rain:
        return 0.9
    else:
        return 0.0

def model(
        observations={
            "sprinkler" : T.tensor([]),
            "wetgrass"  : T.tensor([])
        }
):
    cloudy    = P.sample("cloudy", PD.Bernoulli(cloudy_get_params()))
    sprinkler = P.sample("sprinkler", PD.Bernoulli(sprinkler_get_params(cloudy)),obs=observations["sprinkler"])
    rain      = P.sample("rain", PD.Bernoulli(rain_get_params(cloudy)))
    wetgrass  = P.sample("wetgrass", PD.Bernoulli(wetgrass_get_params(sprinkler,rain)),obs=observations["wetgrass"])

    return rain

if __name__ == '__main__':
    observations = {"sprinkler" : T.tensor(1.0), "wetgrass" : T.tensor(1.0)}

    nuts_kernel = NUTS(model.model, jit_compile=False)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=1000,
        warmup_steps=1000,
        num_chains=1
    )
    mcmc.run(observations=observations)
    mcmc.summary(prob=0.5)
    samples = mcmc.get_samples()

Your model appears to be entirely discrete, and therefore is not amenable to NUTS or HMC inference. Pyro does not provide prepackaged MCMC algorithms for entirely discrete models.

Among Pyro’s prepackaged inference algorithms that might work for your model (TraceEnum_ELBO, SMC), all require vectorization. I would recommend either rewriting your model to handle vectorization or using a different library that is less focused on vectorized inference.