Error with NUTS when random variable is used in control flow

I’ve implemented a model (shown below) where the control flow through the program depends on a draw from a Bernoulli random variable. For some reason, when I attempt to do inference using MCMC with a NUTS kernel I get the following runtime error.

RuntimeError: bool value of Tensor with more than one value is ambiguous
Trace Shapes:         
 Param Sites:         
Sample Sites:         
    edge dist        |
        value 2    1 |
   slope dist        |
        value        |
       n dist        |
        value   1000 |

Here is the code that generated the error.

import pyro
import pyro.distributions as dist

from pyro import sample
from pyro.infer.mcmc import MCMC, NUTS

def model(n):
    theta_edge = 0.5
    X_mean = 0
    X_var = 1
    Y_mean = 0
    Y_var = 1
    
    # edge == 1 means Y is a linear function of X with additive gaussian noise.
    
    edge = sample('edge', dist.Bernoulli(theta_edge))
    slope = sample('slope', dist.Normal(0,1))

    with pyro.plate('n', n):
        if edge:
            X = sample('X', dist.Normal(X_mean, X_var))
            Y = sample('Y', dist.Normal(Y_mean + X * slope, Y_var))
        else:
            Y = sample('Y', dist.Normal(Y_mean, Y_var))
            X = sample('X', dist.Normal(X_mean + Y * slope, X_var))
    
    return X, Y, edge

n = 1000
n_steps = 5000
X, Y, edge = model(n)

conditioned_model = pyro.condition(model, data={'X':X, 'Y':Y})
nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)

posterior = MCMC(nuts_kernel, num_samples=n_steps, warmup_steps=200).run(n)

The inference works fine when I use ImportanceSampling. I suspect this has something to do with the fact that MCMC represents edge as a tensor, while importance sampling runs the prior program forward without modifying the underlying data type of its samples. Is there a way to get NUTS (or any other MCMC method) to work with this kind of program?

I’m happy to provide additional detail from the error message if that would be helpful.

Thanks!

NUTS cannot be used directly for inference in models with discrete latent variables, but what we can do is integrate the discrete variables out, and Pyro does that behind the scenes when it sees a discrete variable in the model. As such, you should try to eliminate the if.. else.. condition. It won’t work because, during this enumeration, the edge variable holds both values 0 and 1 at once (hence the shape (2, 1) reported for edge in your trace) rather than a single value, so Python cannot decide which branch to take.
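To make this concrete, here is a small torch-only sketch (no Pyro involved) that mimics the enumerated edge: a (2, 1) tensor holding both 0 and 1, matching the "value 2 1" line in the trace. Python's if fails on it, while computing both branches and selecting with torch.where broadcasts against a plate of values. The names loc_if_edge and loc_if_not are hypothetical stand-ins for the two branches, and n = 5 is an arbitrary plate size.

```python
import torch

# Stand-in for the enumerated `edge` site: both support values at once,
# with shape (2, 1) as in the "value 2 1" line of the trace.
edge = torch.tensor([[0.], [1.]])

def branch_with_if(edge):
    # Python's `if` calls bool(edge), which is ambiguous for a 2-element tensor.
    return "edge == 1 branch" if edge else "edge == 0 branch"

try:
    branch_with_if(edge)
    raised = False
except RuntimeError:
    raised = True

# Vectorized alternative: compute both branches and select with torch.where.
n = 5                              # stand-in plate size
Y = torch.randn(n)
slope = torch.tensor(0.3)
loc_if_edge = torch.zeros(n)       # hypothetical loc of X when edge == 1
loc_if_not = Y * slope             # hypothetical loc of X when edge == 0
loc = torch.where(edge.bool(), loc_if_edge, loc_if_not)  # shape (2, n)
```

Row 0 of loc corresponds to edge == 0 and row 1 to edge == 1, so both branches stay available to the enumeration instead of one being chosen by Python.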

I’m not sure about the correctness of the model, and this may fail some other shape checks, but the following code might give you an idea of how to proceed (I think it is equivalent to your conditional statement):

X = sample('X', dist.Normal(X_mean + Y * slope * (1 - edge), X_var))
Y = sample('Y', dist.Normal(Y_mean + X * slope * edge, Y_var))

Thanks for the note about NUTS. As an aside, is there a quick-and-easy way to use proposals sampled from the prior instead of using the NUTS or HMC kernels?

Unfortunately, your code snippet fails because the variable Y is referenced before assignment. I don’t see a way around this, but I’m certainly appreciative of the suggestion.

As an aside, is there a quick-and-easy way to use proposals sampled from the prior instead of using the NUTS or HMC kernels?

The simplest way would be to just run your model as follows, which gives you traces from the prior. Each trace stores the sampled values of all your sample sites.

from pyro import poutine

traces = []
for _ in range(100):
    traces.append(poutine.trace(model).get_trace(n))

Or you could define your own “prior kernel” like in this test, though I am unsure how useful that would be. Note that this wouldn’t replicate the way discrete sites are handled inside HMC/NUTS and your sample shapes from the trace will be different.
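If you want something closer to a Markov chain than a set of independent prior traces, one hand-rolled option is independence Metropolis-Hastings with the prior as the proposal. Because the proposal equals the prior, the prior terms cancel and a proposal is accepted with probability min(1, likelihood ratio). This is a plain torch sketch for the model in the first post with means 0 and unit scales; it does not use Pyro's MCMC kernel interface, and log_lik, sample_prior and independence_mh are hypothetical helpers.

```python
import torch
from torch.distributions import Bernoulli, Normal

def log_lik(X, Y, edge, slope):
    # log p(X, Y | edge, slope) for the model in the first post (means 0, scales 1).
    if edge == 1:
        lp = Normal(0., 1.).log_prob(X) + Normal(X * slope, 1.).log_prob(Y)
    else:
        lp = Normal(0., 1.).log_prob(Y) + Normal(Y * slope, 1.).log_prob(X)
    return lp.sum().item()

def sample_prior(theta_edge=0.5):
    # One joint draw of the global latents (edge, slope) from their priors.
    edge = int(Bernoulli(torch.tensor(theta_edge)).sample().item())
    slope = Normal(0., 1.).sample().item()
    return edge, slope

def independence_mh(X, Y, num_steps, theta_edge=0.5):
    state = sample_prior(theta_edge)
    state_ll = log_lik(X, Y, *state)
    samples = []
    for _ in range(num_steps):
        proposal = sample_prior(theta_edge)
        proposal_ll = log_lik(X, Y, *proposal)
        # Prior proposal: prior terms cancel, accept with prob min(1, L'/L).
        if torch.rand(1).log().item() < proposal_ll - state_ll:
            state, state_ll = proposal, proposal_ll
        samples.append(state)
    return samples
```

Expect many rejections once n is large, since most prior proposals land far from where the likelihood is concentrated.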

Unfortunately, your code snippet fails because the variable Y is referenced before assignment. I don’t see a way around this, but I’m certainly appreciative of the suggestion.

Sorry, you are right about that! Could you introduce an auxiliary variable as follows (I think that shouldn’t change the result of inference)?

Y_ = sample('Y_', dist.Normal(Y_mean, Y_var))
X = sample('X', dist.Normal(X_mean + Y_ * slope * (1 - edge), X_var))
Y = sample('Y', dist.Normal(Y_mean + X * slope * edge, Y_var))

Neat! I had thought that it was not doable. Another way might be to interpolate between two forms in which the sequential relation between X and Y is removed (each variable gets its marginal distribution instead).

# edge = 1
X = Normal(X_mean, X_std)
Y = Normal(Y_mean + X_mean * slope, sqrt(Y_var + X_var * slope^2))
# edge = 0
X = Normal(X_mean + Y_mean * slope, sqrt(X_var + Y_var * slope^2))
Y = Normal(Y_mean, Y_std)

which implies

X = Normal(X_mean + Y_mean * slope * (1 - edge),
    X_std * edge + sqrt(X_var + Y_var * slope^2) * (1 - edge))
Y = Normal(Y_mean + X_mean * slope * edge,
    Y_std * (1 - edge) + sqrt(Y_var + X_var * slope^2) * edge)

Thanks for the additional help!

One thing I notice about your auxiliary variable suggestion is that if edge == 0 and we return X, Y, edge, then we’ll essentially be returning samples from the correct marginal distributions over X and Y, but these samples will incorrectly be independent. Instead, we’d need to return X, Y, Y_, and then determine later which of Y or Y_ should be used in the conditioning statement based on the sampled value of edge.

A related concern is that we have to condition on Y and Y_ simultaneously, as our program now considers them as two separate random variables. Besides aesthetics and perhaps some additional computation, my intuition is that these shouldn’t be major issues.

I implemented these changes, and NUTS now works! Unfortunately, the inference using NUTS is now performing significantly worse than my original model using ImportanceSampling. In fact, it’s placing all of the posterior mass on the incorrect assignment of edge, whereas the ImportanceSampling solution was always correct. Does anything stand out as strange with this alternative implementation based on your suggestions?

def alt_model(n):
    theta_edge = 0.5
    X_mean = 0
    X_var = 1
    Y_mean = 0
    Y_var = 1
    
    # edge == 1 means X causes Y.
    
    edge = sample('edge', dist.Bernoulli(theta_edge))
    slope = sample('slope', dist.Normal(0,1))
    with pyro.plate('n', n):
        Y_ = sample('Y_', dist.Normal(Y_mean, Y_var))
        X = sample('X', dist.Normal(X_mean + Y_ * slope * (1 - edge), X_var))
        Y = sample('Y', dist.Normal(Y_mean + X * slope * edge, Y_var))
    
    return X, Y, Y_, edge
    
n = 1000
n_steps = 1000
X, Y, Y_, edge = alt_model(n)

if edge:
    conditioned_model = pyro.condition(alt_model, data={'X':X, 'Y':Y, 'Y_':Y})
else:
    conditioned_model = pyro.condition(alt_model, data={'X':X, 'Y':Y_, 'Y_':Y_})

nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)

posterior = MCMC(nuts_kernel, num_samples=n_steps, warmup_steps=200).run(n=n)

I like the idea, but it looks like this solution would produce variables X and Y that are independent, while having the correct marginal distributions. Unless I’m understanding incorrectly, an individual draw from Y doesn’t depend on the corresponding value of X in your proposed program.

Could you possibly clarify what the alternative program would look like using the full Pyro syntax?

Thanks!

@switty Yes, you are right that they are conditionally (on edge and slope) independent. I have removed the sequential relation so it is not exactly the same model as yours. :slight_smile:
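A quick Monte Carlo check of this point for the edge == 1 case with a fixed slope: both constructions give Y the same marginal variance (1 + slope**2 with unit scales), but only the sequential one gives Cov(X, Y) = slope. A torch-only sketch; the function names are illustrative and slope = 2 is an arbitrary choice.

```python
import torch

def sequential(n, slope):
    # edge == 1 branch of the original model: Y depends on the drawn X.
    X = torch.randn(n)
    Y = X * slope + torch.randn(n)
    return X, Y

def interpolated(n, slope):
    # edge == 1 branch of the interpolated form: Y uses the marginal
    # std sqrt(1 + slope**2) and ignores the drawn X.
    X = torch.randn(n)
    Y = torch.randn(n) * (1 + slope ** 2) ** 0.5
    return X, Y

def cov(a, b):
    return ((a - a.mean()) * (b - b.mean())).mean().item()

torch.manual_seed(0)
slope = 2.0
Xs, Ys = sequential(200_000, slope)
Xi, Yi = interpolated(200_000, slope)
print(Ys.var().item(), Yi.var().item())  # both close to 1 + slope**2 = 5
print(cov(Xs, Ys), cov(Xi, Yi))          # close to 2 vs. close to 0
```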

I am not sure if NUTS is doing anything since you are using condition to fix all the sample sites (so there isn’t any posterior distribution to learn). Maybe you want to pass in observed X and do inference over Y or something like that? Could you paste what you are doing with the importance sampler for comparison?

The intention is to infer the edge and slope variables, which are not included in the condition statement. This is very much like the canonical Bayesian linear regression example, where a model is specified such that parameters are sampled once outside of a plate and then individual instances are repeatedly sampled within a plate as a function of those parameters.

Using the model function in my original post, the importance sampling code is the following.

n = 1000
n_steps = 5000

X, Y, edge = model(n)

conditioned_model = pyro.condition(model, data={'X':X, 'Y':Y})
is_posterior = pyro.infer.Importance(conditioned_model, num_samples=n_steps).run(n)
is_marginal = pyro.infer.EmpiricalMarginal(is_posterior, 'edge')
is_samples = [is_marginal().detach() for _ in range(1000)]
print(edge, sum(is_samples)/len(is_samples))

I see. Sorry, I completely missed that we had “slope” in there. Since edge is enumerated out (summed over) rather than sampled, it isn’t currently possible to get posterior samples of edge from NUTS, but you do get a distribution over slope. I checked that the mean of this distribution is -0.4.
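Even though NUTS does not return edge, for the sequential model in the first post the conditional posterior P(edge = 1 | slope, X, Y) has a closed form (prior log-odds plus the difference of the two branch log-likelihoods). Averaging it over posterior draws of slope therefore gives an estimate of P(edge = 1 | X, Y). A torch-only sketch, assuming means 0, unit scales, and that slope_samples are draws from p(slope | X, Y) for that same model; the helper names are hypothetical, and simulate only generates toy data for the usage example.

```python
import math
import torch
from torch.distributions import Normal

def branch_log_lik(X, Y, edge, slope):
    # log p(X, Y | edge, slope) for the model in the first post (means 0, scales 1).
    if edge == 1:   # X -> Y
        lp = Normal(0., 1.).log_prob(X) + Normal(X * slope, 1.).log_prob(Y)
    else:           # Y -> X
        lp = Normal(0., 1.).log_prob(Y) + Normal(Y * slope, 1.).log_prob(X)
    return lp.sum()

def prob_edge_given_slope(X, Y, slope, theta_edge=0.5):
    # Bayes' rule in log-odds form: prior log-odds + log-likelihood ratio.
    prior_logit = math.log(theta_edge) - math.log1p(-theta_edge)
    log_ratio = branch_log_lik(X, Y, 1, slope) - branch_log_lik(X, Y, 0, slope)
    return torch.sigmoid(prior_logit + log_ratio)

def prob_edge(X, Y, slope_samples, theta_edge=0.5):
    # Average the conditional probabilities over the slope draws.
    probs = [prob_edge_given_slope(X, Y, s, theta_edge) for s in slope_samples]
    return torch.stack(probs).mean()

def simulate(n, edge, slope):
    # Toy data from the model in the first post with a fixed edge and slope.
    if edge == 1:
        X = torch.randn(n)
        Y = X * slope + torch.randn(n)
    else:
        Y = torch.randn(n)
        X = Y * slope + torch.randn(n)
    return X, Y
```

For example, with data from simulate(500, 1, 2.0) and slope draws near 2, the estimate should come out very close to 1, and close to 0 for data from simulate(500, 0, 2.0).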

In fact, it’s placing all of the posterior mass on the incorrect assignment of edge , whereas the ImportanceSampling solution was always correct.

Could you elaborate on what you mean by this? Note that the values at the other sample sites (X, Y) are simply the ones you have conditioned on.

Aha, thanks for pointing out that the inference procedure is enumerating over the discrete variables. If this doesn’t really reflect the posterior, how should I interpret the marginal distribution over edge after running MCMC? If I’m particularly interested in the posterior over these kinds of discrete variables, what are some other options?

Sorry if that statement was vague. Executing the model programs returns a single set of samples for X, Y, and edge. However, I then only use X and Y to condition the program. The intention here is to compare the posterior P(edge | X, Y) to the ground truth edge variable that I originally sampled. When I said the ImportanceSampling solution was always correct, what I meant was that the approximate posterior distribution over the edge variable using Importance Sampling, P(edge=1 | X, Y), is very close to 1.0 if and only if the ground truth edge sample is 1, and very close to 0.0 otherwise. In other words, the latent edge can be recovered using only X and Y. This is great for now, but I don’t expect ImportanceSampling to scale to more complicated versions of this problem.

I hope this is more clear and I really appreciate your help here!

If I’m particularly interested in the posterior over these kinds of discrete variables, what are some other options?

I think one option would be to use variational inference. The GMM example is useful for understanding how SVI handles enumeration, and it allows you to do MAP inference over the discrete latent variables.

In this case, much more simply (though I am unsure if this serves your purpose), you could also place a prior over theta_edge, and look at what the posterior over theta_edge looks like:

    theta_edge = sample('theta_edge', dist.Beta(1., 1.))
    edge = sample('edge', dist.Bernoulli(theta_edge))
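One caveat that follows directly from the model structure: X and Y depend on theta_edge (written θ below) only through the single draw of edge, so with the Beta(1, 1) prior the posterior over θ is a two-component mixture and its mean can only move between 1/3 and 2/3. It can still be read back as a posterior over edge:

```latex
p(\theta \mid X, Y) = \sum_{e \in \{0, 1\}} P(\mathrm{edge} = e \mid X, Y)\,\mathrm{Beta}(\theta \mid 1 + e,\; 2 - e)

\mathbb{E}[\theta \mid X, Y] = \frac{1 + P(\mathrm{edge} = 1 \mid X, Y)}{3}
\quad\Longleftrightarrow\quad
P(\mathrm{edge} = 1 \mid X, Y) = 3\,\mathbb{E}[\theta \mid X, Y] - 1
```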

I would also suggest amplifying the difference between X_mean and Y_mean and maybe fixing the slope to some constant for initial debugging of the model.