Sigmoid Belief Network - How to handle Discrete Latent variables with Pyro

Tags: Discrete Latent Variables - Hybrid Latent Variables - Sigmoid Belief Network - Hybrid Guides

I am trying to construct a Sigmoid Belief Network (SBN) in Pyro. Two layers (plates) of latent variables, Z_1 and Z_2, explain an observed data set X. The weights have a N(0,1) prior, Z_1 follows a Bernoulli(0.5) prior, and Z_2 follows a Bernoulli(sigmoid(Z_1'W)) distribution (note that all elements of Z_2 share the same Z_1, but each has its own weight vector W). Finally, the elements of X follow a Bernoulli(sigmoid(Z_2'W)) distribution.
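A minimal sketch of the generative process just described, in plain PyTorch, using the layer sizes from the code below (2, 3 and 5) as assumed defaults. The helper name `sample_sbn` is hypothetical; it only draws forward samples and does no inference.

```python
import torch


def sample_sbn(n, top_width=2, bottom_width=3, data_size=5, seed=0):
    """Draw n forward samples (Z_1, Z_2, X) from a two-layer sigmoid belief net.

    Hypothetical helper for illustration: weights ~ N(0, 1), Z_1 ~ Bernoulli(0.5),
    Z_2 ~ Bernoulli(sigmoid(Z_1 W_top)), X ~ Bernoulli(sigmoid(Z_2 W_bottom)).
    """
    g = torch.Generator().manual_seed(seed)
    w_top = torch.randn(top_width, bottom_width, generator=g)     # one weight column per Z_2 unit
    w_bottom = torch.randn(bottom_width, data_size, generator=g)  # one weight column per X unit
    z_top = torch.bernoulli(torch.full((n, top_width), 0.5), generator=g)
    z_bottom = torch.bernoulli(torch.sigmoid(z_top @ w_top), generator=g)
    x = torch.bernoulli(torch.sigmoid(z_bottom @ w_bottom), generator=g)
    return z_top, z_bottom, x


z_top, z_bottom, x = sample_sbn(4)
print(z_top.shape, z_bottom.shape, x.shape)  # torch.Size([4, 2]) torch.Size([4, 3]) torch.Size([4, 5])
```

Every row of X depends on its own Z_1 and Z_2, while the two weight matrices are shared across rows, which is why the model below samples the weights outside the "data" plate.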


I'm trying to adapt the Sparse Gamma Deep Exponential Family example for this purpose, but according to an answer to yreddy's earlier post, Pyro cannot handle discrete latent variables in that setup.

However, it remains unclear to me what my code should look like. My current code is as follows:
EDIT2: Added main programme to code

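import numpy as np
import torch
from pathlib import Path

import pyro
import pyro.optim as optim
from pyro import poutine
from pyro.contrib.autoguide import AutoDiagonalNormal, AutoDiscreteParallel, AutoGuideList
from pyro.distributions import Bernoulli, Normal
from pyro.infer import SVI, TraceMeanField_ELBO


class SigmoidBeliefDEF(object):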
    def __init__(self):
        # define the sizes of the layers in the deep exponential family
        self.top_width = 2
        self.bottom_width = 3
        self.data_size = 5
        # define hyperparameters that control the prior
        self.p_z = torch.tensor(0.5)
        self.mu_w = torch.tensor(0.0)
        self.sigma_w = torch.tensor(1.0)

    # 1
    # define the model
    def model(self, x):
        x_size = x.size(0)
        # 1.1
        # sample the global weights
        with pyro.plate("w_top_plate", self.top_width * self.bottom_width):
            w_top = pyro.sample("w_top", Normal(self.mu_w, self.sigma_w))
        with pyro.plate("w_bottom_plate", self.bottom_width * self.data_size):
            w_bottom = pyro.sample("w_bottom", Normal(self.mu_w, self.sigma_w))

        # 1.2
        # sample the local latent random variables
        # (the plate encodes the fact that the z's for different datapoints are conditionally independent)
        with pyro.plate("data", x_size):
            z_top = pyro.sample("z_top", Bernoulli(self.p_z).expand([self.top_width]).to_event(1))
            # note that we need to use matmul (batch matrix multiplication) as well as appropriate reshaping
            # to make sure our code is fully vectorized
            w_top = w_top.reshape(self.top_width, self.bottom_width) if w_top.dim() == 1 else \
                w_top.reshape(-1, self.top_width, self.bottom_width)
            mean_bottom = torch.sigmoid(torch.matmul(z_top, w_top))
            z_bottom = pyro.sample("z_bottom", Bernoulli(mean_bottom).to_event(1))

            w_bottom = w_bottom.reshape(self.bottom_width, self.data_size) if w_bottom.dim() == 1 else \
                w_bottom.reshape(-1, self.bottom_width, self.data_size)
            mean_obs = torch.sigmoid(torch.matmul(z_bottom, w_bottom))

            # observe the data using a Bernoulli likelihood
            pyro.sample('obs', Bernoulli(mean_obs).to_event(1), obs=x)

def main(args):
    dataset_path = Path(r"C:\Users\posc8001\Documents\DEF\Data\Simulation_1")
    file_to_open = dataset_path / "small_data.csv"
    # np.loadtxt accepts the path directly, so no file handle is left open
    data = torch.tensor(np.loadtxt(file_to_open, delimiter=','), dtype=torch.float32)
    sigmoid_belief_def = SigmoidBeliefDEF()

    # Specify hyperparameters of optimization
    learning_rate = 0.2
    momentum = 0.05
    opt = optim.AdagradRMSProp({"eta": learning_rate, "t": momentum})

    # Specify the guide
    guide = AutoGuideList(sigmoid_belief_def.model)
    guide.add(AutoDiagonalNormal(poutine.block(sigmoid_belief_def.model,
                                               hide=["assignment"])))
    guide.add(AutoDiscreteParallel(poutine.block(sigmoid_belief_def.model,
                                                 expose=["assignment"])))
    guide = guide if args.auto_guide else sigmoid_belief_def.guide

    # Specify Stochastic Variational Inference
    svi = SVI(sigmoid_belief_def.model, guide, opt, loss=TraceMeanField_ELBO())

    # we use svi_eval during evaluation; since we took care to write down our model in
    # a fully vectorized way, this computation can be done efficiently with large tensor ops
    svi_eval = SVI(sigmoid_belief_def.model, guide, opt,
                   loss=TraceMeanField_ELBO(num_particles=args.eval_particles, vectorize_particles=True))

    # the training loop
    for k in range(args.num_epochs):
        loss = svi.step(data)

        if k % args.eval_frequency == 0 and k > 0 or k == args.num_epochs - 1:
            loss = svi_eval.evaluate_loss(data)
            print("[epoch %04d] training elbo: %.4g" % (k, -loss))

It was suggested to use poutine.block, but it is unclear to me where in my code I should use it. For instance, wrapping the pyro.sample("z_top",[args]) call in a poutine.block call didn't do the trick. Could anyone help me out, either by elaborating on the use of poutine.block or by proposing an alternative solution?

EDIT:
I added the error that is returned when I run the programme.

Traceback (most recent call last):
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\torch\distributions\constraint_registry.py", line 139, in __call__
    factory = self._registry[type(constraint)]
KeyError: <class 'torch.distributions.constraints._Boolean'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\poutine\trace_messenger.py", line 147, in __call__
    ret = self.fn(*args, **kwargs)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\contrib\autoguide\__init__.py", line 190, in __call__
    result.update(part(*args, **kwargs))
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\contrib\autoguide\__init__.py", line 377, in __call__
    self._setup_prototype(*args, **kwargs)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\contrib\autoguide\__init__.py", line 326, in _setup_prototype
    self._unconstrained_shapes[name] = biject_to(site["fn"].support).inv(site["value"]).shape
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\torch\distributions\constraint_registry.py", line 143, in __call__
    return factory(constraint)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\distributions\torch_distribution.py", line 226, in <lambda>
    biject_to.register(IndependentConstraint, lambda c: biject_to(c.base_constraint))
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\torch\distributions\constraint_registry.py", line 142, in __call__
    'Cannot transform {} constraints'.format(type(constraint).__name__))
NotImplementedError: Cannot transform _Boolean constraints

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:/Users/posc8001/Documents/DEF/Scipio_DEF/sigmoid_belief_network.py", line 216, in <module>
    model = main(args)
  File "C:/Users/posc8001/Documents/DEF/Scipio_DEF/sigmoid_belief_network.py", line 193, in main
    loss = svi.step(data)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\infer\svi.py", line 99, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\infer\trace_elbo.py", line 125, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\infer\elbo.py", line 163, in _get_traces
    yield self._get_trace(model, guide, *args, **kwargs)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\infer\trace_mean_field_elbo.py", line 75, in _get_trace
    model, guide, *args, **kwargs)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\infer\trace_elbo.py", line 52, in _get_trace
    "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\infer\enum.py", line 42, in get_importance_trace
    guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\poutine\trace_messenger.py", line 169, in get_trace
    self(*args, **kwargs)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\poutine\trace_messenger.py", line 153, in __call__
    traceback)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\six.py", line 692, in reraise
    raise value.with_traceback(tb)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\poutine\trace_messenger.py", line 147, in __call__
    ret = self.fn(*args, **kwargs)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\contrib\autoguide\__init__.py", line 190, in __call__
    result.update(part(*args, **kwargs))
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\contrib\autoguide\__init__.py", line 377, in __call__
    self._setup_prototype(*args, **kwargs)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\contrib\autoguide\__init__.py", line 326, in _setup_prototype
    self._unconstrained_shapes[name] = biject_to(site["fn"].support).inv(site["value"]).shape
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\torch\distributions\constraint_registry.py", line 143, in __call__
    return factory(constraint)
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\distributions\torch_distribution.py", line 226, in <lambda>
    biject_to.register(IndependentConstraint, lambda c: biject_to(c.base_constraint))
  File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\torch\distributions\constraint_registry.py", line 142, in __call__
    'Cannot transform {} constraints'.format(type(constraint).__name__))
NotImplementedError: Cannot transform _Boolean constraints
      Trace Shapes:        
       Param Sites:        
      Sample Sites:        
          data dist       |
              value 20000 |
w_bottom_plate dist       |
              value    15 |
   w_top_plate dist       |
              value     6 |

Process finished with exit code 1

EDIT3: Added tags

It's hard to say for sure what's going wrong since the code that actually produced the error isn't included in your post, but judging from the stack trace, it looks like you're trying to use SVI with a continuous autoguide. The error is telling you that you can't use a continuous variational distribution for a discrete latent variable. You'll need to write your own guide or use a combination of a continuous and discrete autoguide (pyro.contrib.autoguide.AutoDiscreteParallel) via pyro.contrib.autoguide.AutoGuideList.

Note that inference in models like this with lots of discrete latent variables is difficult, especially variational inference, and if you want to work with a larger version of the model you may need to write your own guide and use pyro.infer.TraceGraph_ELBO with neural baselines to reduce gradient variance. See the SVI gradient estimator tutorial and one of the references there on amortized variational inference in SBNs for more background.

Hi @eb8680_2,

Thanks for your fast reply. I updated my question to include the main programme.

Thank you as well for pointing out the AutoGuideList option. I forgot to mention it in my initial question, but I attempted to implement it. It still returns the aforementioned error though. Did I use a wrong implementation? (see my implementation below)

# Specify the guide
guide = AutoGuideList(sigmoid_belief_def.model)
guide.add(AutoDiagonalNormal(poutine.block(sigmoid_belief_def.model,
                                           hide=["assignment"])))
guide.add(AutoDiscreteParallel(poutine.block(sigmoid_belief_def.model,
                                             expose=["assignment"])))
guide = guide if args.auto_guide else sigmoid_belief_def.guide

As for the other pointers, thank you very much. I'll look into them and get back to you once I have found the time!

Best.