Achieving `observe` behavior

Consider the definition of the observe statement in probabilistic programming, as given in [1]:

The observe statement blocks runs which do not satisfy the boolean expression E and does not permit those executions to happen.

Now, consider the following theoretical program:

def f():
    x ~ Normal(0, 1)
    observe(x > 0) # only allow samples x > 0
    return x

which should return values from the truncated Normal(0, 1) distribution.
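
One way to read this definition operationally is rejection sampling: run the program, and discard any run in which the observed condition is false. The following is a minimal plain-Python sketch of that reading (standard library only, not Pyro code). The names `ObserveFailed` and `run_with_observe` are hypothetical helpers introduced here for illustration.

```python
import math
import random


class ObserveFailed(Exception):
    """Raised when an observed condition is false (hypothetical helper)."""


def observe(condition):
    # Block the current run if the boolean expression does not hold.
    if not condition:
        raise ObserveFailed()


def f(rng):
    x = rng.gauss(0.0, 1.0)   # x ~ Normal(0, 1)
    observe(x > 0)            # only allow runs with x > 0
    return x


def run_with_observe(program, rng, num_samples):
    # Rejection sampling: rerun the program until enough runs survive.
    samples = []
    while len(samples) < num_samples:
        try:
            samples.append(program(rng))
        except ObserveFailed:
            continue
    return samples


rng = random.Random(0)
samples = run_with_observe(f, rng, 20000)
sample_mean = sum(samples) / len(samples)
# The mean of Normal(0, 1) truncated to x > 0 is sqrt(2 / pi), about 0.798.
expected_mean = math.sqrt(2.0 / math.pi)
```

Rejection is cheap here because about half of the runs survive; it becomes wasteful when the condition is rarely satisfied.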

Therefore, my question is: how can observe be implemented in Pyro, or what’s its equivalent?
Of course, observe's argument can be any boolean expression E.


[1] Gordon, Andrew D., et al. “Probabilistic programming.” Proceedings of the on Future of Software Engineering. 2014. 167-181.

Hi @AlexD,

Pyro implements only a restricted form of observe statements: observations must be point values drawn from a distribution. This restriction makes it easier to implement efficient inference algorithms, at the cost of limiting expressivity.

Some workarounds for specific cases include:

  • observing a noisy boolean value, e.g.
    def f():
        x = pyro.sample("x", dist.Normal(0, 1))
        pyro.sample("obs", dist.Bernoulli(0.99),
                    obs=(x > 0).float())
        return x
    
  • using custom distributions, e.g.
    def f():
        x = pyro.sample("x", dist.HalfNormal(1))
        return x
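
The second workaround relies on the fact that Normal(0, 1) restricted to x > 0 and renormalized has the same density as HalfNormal(1). A small standard-library check of that identity follows; it is a sketch rather than Pyro code, and the function names are illustrative.

```python
import math


def normal_pdf(x):
    # Density of Normal(0, 1).
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def truncated_normal_pdf(x):
    # Normal(0, 1) conditioned on x > 0: divide by P(x > 0) = 0.5.
    return normal_pdf(x) / 0.5 if x > 0 else 0.0


def half_normal_pdf(x, scale=1.0):
    # Density of HalfNormal(scale), the law of |z| for z ~ Normal(0, scale).
    if x < 0:
        return 0.0
    return math.sqrt(2.0 / math.pi) / scale * math.exp(-0.5 * (x / scale) ** 2)


points = [0.1, 0.5, 1.0, 2.0, 3.5]
# Largest pointwise difference between the two densities on the test points.
max_gap = max(abs(truncated_normal_pdf(x) - half_normal_pdf(x)) for x in points)
```

The two densities agree up to floating-point error, which is why no conditioning is needed in this particular case.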
    

The Pyro team is also actively researching ways to do more symbolic conditioning; this is the aim of the Funsor library.

Let us know if you have particular applied probabilistic models that you would like to translate into Pyro’s restricted modeling language; we’d be happy to help.

Best,
@fritzo

Hi @fritzo,

Thank you very much for your reply!
I have two more questions.


First, regarding this example:

def f():
    x = pyro.sample("x", dist.Normal(0, 1))
    pyro.sample("obs", dist.Bernoulli(0.99), obs=(x > 0).float())
    return x

how is x affected by the second sample of obs?

Furthermore, if you change it to:

def f():
    x = pyro.sample("x", dist.Normal(0, 1))
    y = pyro.sample("obs", dist.Bernoulli(0.99), obs=(x > 0).float())
    return y

this will just return (x > 0).float(), ignoring any dist given to sample().


Second, if my understanding is correct, sampling from a truncated Gaussian is achieved using HalfNormal, and there is (currently) no way to do so via observations (i.e. conditioning), right?

First, x is affected by the obs site by being forced to be “probably positive”, which is strictly weaker than “positive”. This means that when fitting model parameters, the model will be gently discouraged from generating negative x, rather than strictly prohibited from doing so.
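
To put a number on “probably positive”: under the noisy observation, a run with x > 0 is weighted by 0.99 (the Bernoulli probability of the observed value 1), and a run with x <= 0 is weighted by 0.01 (the probability of the observed value 0). Since the Normal(0, 1) prior puts mass 0.5 on each side, the posterior probability of x > 0 is 0.99. The sketch below checks this with self-normalized importance weights in plain Python; it is an illustration only and says nothing about how Pyro computes things internally.

```python
import random


def likelihood(x, p=0.99):
    # Bernoulli(p) probability of the observed value (x > 0).float():
    # the observation is 1 when x > 0 and 0 otherwise.
    return p if x > 0 else 1.0 - p


rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(50000)]   # draws from the prior
ws = [likelihood(x) for x in xs]                    # importance weights

# Posterior probability that x > 0, estimated from weighted prior samples.
posterior_positive = sum(w for x, w in zip(xs, ws) if x > 0) / sum(ws)

# Exact value from Bayes' rule with prior mass 0.5 on each side.
exact = 0.5 * 0.99 / (0.5 * 0.99 + 0.5 * 0.01)
```

Negative x therefore keeps about 1% of the posterior mass instead of being ruled out.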

Furthermore, you are correct: observe statements return their obs= value regardless of the distribution; however, when fitting model parameters, that distribution does matter.
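
A toy picture of why both statements hold: an observed site returns the given value untouched, but the distribution is still used to score that value, and the score is what inference sees. The class below is a hypothetical stand-in written for this illustration (`ToyTrace` and `sample_bernoulli` are invented names); it is not Pyro's implementation.

```python
import math
import random


class ToyTrace:
    """Hypothetical stand-in for a trace; not Pyro's implementation."""

    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        self.log_prob = 0.0

    def sample_bernoulli(self, p, obs=None):
        # With obs given, nothing is drawn: obs is returned as is.
        value = obs if obs is not None else float(self.rng.random() < p)
        # The distribution still matters: it scores the returned value.
        self.log_prob += math.log(p if value == 1.0 else 1.0 - p)
        return value


trace_pos = ToyTrace()
y_pos = trace_pos.sample_bernoulli(0.99, obs=1.0)   # as if x > 0

trace_neg = ToyTrace()
y_neg = trace_neg.sample_bernoulli(0.99, obs=0.0)   # as if x <= 0
```

Here `y_pos` and `y_neg` are exactly the observed values, while the accumulated log-probabilities differ (log 0.99 versus log 0.01), which is the signal that discourages negative x.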

Continuing this discussion, I am experimenting with conditioning and wanted to try this inference example, which may simulate observe(x > 0) (it’s just a toy example).

Suppose I have the following conditioned model and guide:

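import torch as T  # assumption: T is an alias for torch
import pyro
import pyro.distributions as dist

def model():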
    x = pyro.sample('x', dist.Uniform(-1, 1))
    
    if x <= 0:
        y = pyro.sample('y', dist.Bernoulli(0.0))
    else:
        y = pyro.sample('y', dist.Bernoulli(1.0))
        
    return y

cond_model = pyro.condition(model, data={'y': T.tensor(1.0)})

def guide():
    a = pyro.param('a', T.tensor(-1.0))
    b = pyro.param('b', T.tensor(1.0))
    return pyro.sample('x', dist.Uniform(a, b))

Now, is my understanding correct that observing y = 1 implies that x was sampled from the positive half of Uniform(-1, 1), that is, Uniform(0, 1)?

So, is it possible, by inference, to get (a,b) ≈ (0,1)?

Yes, in principle it is possible to learn the posterior interval (a, b) ≈ (0, 1) on x. However, I’m not sure any of Pyro’s inference algorithms will work well here; possibly importance sampling with a good prior?
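
As a rough picture of the importance-sampling idea, the plain-Python sketch below draws x from the Uniform(-1, 1) prior and weights each draw by the probability of observing y = 1, which in this model is 1 for x > 0 and 0 otherwise. The draws with nonzero weight approximate Uniform(0, 1), so their minimum and maximum give crude estimates of (a, b). This is a standard-library illustration under those assumptions, not a Pyro inference call, and the names `a_hat` and `b_hat` are introduced here.

```python
import random


def likelihood_y_is_one(x):
    # y ~ Bernoulli(0.0) if x <= 0, else Bernoulli(1.0);
    # this returns the probability that y == 1.
    return 0.0 if x <= 0 else 1.0


rng = random.Random(0)
xs = [rng.uniform(-1.0, 1.0) for _ in range(10000)]   # prior draws
ws = [likelihood_y_is_one(x) for x in xs]             # importance weights

# Keep draws with nonzero weight; they approximate the posterior on x.
kept = [x for x, w in zip(xs, ws) if w > 0]
a_hat, b_hat = min(kept), max(kept)

# Self-normalized estimate of the posterior mean (0.5 for Uniform(0, 1)).
posterior_mean = sum(w * x for x, w in zip(xs, ws)) / sum(ws)
```

Because the weights are exactly 0 or 1, this reduces to rejection sampling from the prior; about half of the draws are discarded.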

Thank you, @fritzo! This is clear.