How should one do inference when conditioning on a Delta?

I took the toy model from the SVI tutorial…

import torch
import pyro
import pyro.distributions as dist

def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    measurement = pyro.sample("measurement", dist.Normal(weight, 0.75))
    return measurement

conditioned_scale = pyro.condition(scale, data={"measurement": torch.tensor(9.5)})
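Both sample sites in this first model are Normal, so the posterior over `weight` has a closed form (Normal-Normal conjugacy). That gives a ground truth you can check any inference method against. The sketch below is plain PyTorch. The value `guess = 8.5` is an assumption chosen only for illustration, since the snippet above leaves `guess` as an argument.

```python
import torch

def normal_normal_posterior(prior_loc, prior_scale, obs, obs_scale):
    """Posterior of w given obs, where w ~ Normal(prior_loc, prior_scale)
    and obs ~ Normal(w, obs_scale). Returns (loc, scale) of the Normal posterior."""
    prior_prec = 1.0 / prior_scale ** 2   # precision = 1 / variance
    obs_prec = 1.0 / obs_scale ** 2
    post_prec = prior_prec + obs_prec     # precisions add
    post_loc = (prior_loc * prior_prec + obs * obs_prec) / post_prec
    return post_loc, post_prec ** -0.5

guess = torch.tensor(8.5)  # hypothetical prior guess, for illustration only
loc, scale = normal_normal_posterior(guess, torch.tensor(1.0),
                                     torch.tensor(9.5), torch.tensor(0.75))
print(loc.item(), scale.item())  # approximately 9.14 and 0.6
```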

and modified it to be the following:

from pyro.distributions import Delta, Normal

def scale(guess):
    weight = pyro.sample("weight", Normal(guess, 1.0))
    Z = pyro.sample("Z", Normal(0.0, 1.0))
    measurement = pyro.sample("measurement", Delta(weight + Z*0.75))
    return measurement

conditioned_scale = pyro.condition(scale, data={"measurement": torch.tensor(9.5)})

I then tried to infer weight with SVI, HMC, and importance sampling. I expected these methods to fail with a conditioned Delta (though the error messages surprised me). So I also tried an approximation to the Delta that I call Spike (as in spike-and-slab), i.e., a Normal with a very small scale. This also generally failed. The only exception was when the scale parameter was so large that the distribution could no longer reasonably be called an approximate Delta. I suspect HMC would work with Spike if I increased the maximum tree depth for NUTS, but I haven’t tried that yet.
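One plausible reason the Spike approximation struggles is geometry. With observation scale `eps`, the log-likelihood of the observed measurement has curvature of order `1/eps**2` in `weight` and `Z`. The posterior therefore concentrates on a thin ridge `weight + 0.75*Z ≈ 9.5`, which gradient-based samplers and SVI find hard to follow. The sketch below measures how the gradient of that log-likelihood grows as `eps` shrinks. It is plain PyTorch, and the values for `weight`, `Z`, and `eps` are arbitrary assumptions.

```python
import torch
from torch.distributions import Normal

def spike_loglik_grad(weight, z, obs, eps):
    """Gradient of log Normal(obs | weight + 0.75*z, eps) with respect to weight."""
    w = torch.tensor(weight, requires_grad=True)
    loglik = Normal(w + 0.75 * z, eps).log_prob(torch.tensor(obs))
    loglik.backward()
    return w.grad.item()

# A point slightly off the ridge weight + 0.75*z = 9.5 (hypothetical values).
grads = {eps: spike_loglik_grad(9.0, 0.6, 9.5, eps) for eps in [1.0, 0.1, 0.01]}
for eps, g in grads.items():
    # The gradient equals (obs - loc) / eps**2, so it grows like 1/eps**2.
    print(f"eps={eps}: d loglik / d weight = {g:.1f}")
```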

I put everything I tried in a Google Colab notebook.

I haven’t yet tried simply computing MLE or MAP estimates with the Adam optimizer. That would be fine for me, since I don’t need posterior samples for my problem.
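For MAP estimation, one option is to substitute the Delta away by hand. Given the observation, `Z` is determined by `weight` as `Z = (9.5 - weight) / 0.75`, so the log joint becomes a function of `weight` alone. The change of variables adds only a constant Jacobian factor `1/0.75`, which does not move the mode. The sketch below optimizes that function with `torch.optim.Adam`. The values `guess = 8.5`, the learning rate, and the step count are assumptions. Under those assumptions the mode should land near 9.14, which is also the posterior mean of the first (Normal-observation) model.

```python
import torch
from torch.distributions import Normal

guess = torch.tensor(8.5)   # hypothetical prior guess
obs = torch.tensor(9.5)     # the conditioned value of "measurement"

# Start the optimization at the prior mean.
weight = guess.clone().requires_grad_(True)
optimizer = torch.optim.Adam([weight], lr=0.05)

for _ in range(2000):
    optimizer.zero_grad()
    # Solve measurement = weight + 0.75 * Z for Z, given the observation.
    z = (obs - weight) / 0.75
    log_joint = Normal(guess, 1.0).log_prob(weight) + Normal(0.0, 1.0).log_prob(z)
    (-log_joint).backward()   # minimize the negative log joint
    optimizer.step()

print(weight.item())  # MAP estimate of weight
```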

Any advice on how to get inference to work in this setting? There is a class of models I want to work with where I need to condition on Deltas or Spikes.

Hi @osazuwa, you can’t currently condition on or observe a Delta in Pyro. We’re working hard on supporting this in the next major version of Pyro via delayed sampling, but this won’t be ready for a few months.

In the short term I can only recommend hand-implementing delayed sampling, which is essentially the first version of your model.
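The hand-implemented version follows from integrating the auxiliary `Z` out of the Delta site. Solving `m = w + 0.75 z` for `z` contributes a Jacobian factor `1/0.75`:

$$
p(m \mid w) = \int \delta\big(m - (w + 0.75\,z)\big)\,\mathcal{N}(z; 0, 1)\,dz
= \frac{1}{0.75}\,\mathcal{N}\!\left(\frac{m - w}{0.75}; 0, 1\right)
= \mathcal{N}(m; w, 0.75^2).
$$

So conditioning `measurement` on 9.5 in the Delta model should give the same posterior over `weight` as conditioning the Normal observation site in the first model.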


Hi @fritzo, I wanted to follow up on this and ask whether it is still being worked on. I don’t see it in 0.4.0 or 0.4.1.

(For example, there are classes of models where the observed outcome is a deterministic function of latent random variables. In these models you can’t incorporate the observed information, because you can only condition on a pyro.sample statement. If you have ideas for how to model this with the existing capabilities, let me know.)

@logan For now you can hand-implement this via TransformedDistribution. For instance, @osazuwa’s model can be rewritten as

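import torch
import pyro
from pyro.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

def scale(guess):
    weight = pyro.sample("weight", Normal(guess, 1.0))
    z_dist = Normal(0.0, 1.0)
    # measurement = weight + 0.75 * z with z ~ Normal(0, 1), expressed as a density
    m_dist = TransformedDistribution(
        z_dist, AffineTransform(weight, torch.tensor(0.75)))
    measurement = pyro.sample("measurement", m_dist)
    return measurement

conditioned_scale = pyro.condition(scale, data={"measurement": torch.tensor(9.5)})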
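Here is a quick sanity check that this rewrite scores an observed measurement exactly as the Normal observation in the original tutorial model does. Pushing `Normal(0, 1)` through `AffineTransform(weight, 0.75)` yields `Normal(weight, 0.75)`. The check uses plain PyTorch, and the weight and measurement values are arbitrary.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

weight = torch.tensor(9.0)        # arbitrary value of the latent weight
measurement = torch.tensor(9.5)   # the observed value

# Density of weight + 0.75 * z with z ~ Normal(0, 1), via change of variables.
m_dist = TransformedDistribution(Normal(0.0, 1.0),
                                 AffineTransform(weight, torch.tensor(0.75)))
direct = Normal(weight, 0.75)

lp_transformed = m_dist.log_prob(measurement)
lp_direct = direct.log_prob(measurement)
print(lp_transformed.item(), lp_direct.item())  # equal up to float rounding
```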

For more complex bijective transforms you can implement your own Transform class by providing forward, inverse, and log_abs_det_jacobian implementations.
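Below is a minimal sketch of such a custom transform. In PyTorch, a subclass of `torch.distributions.transforms.Transform` supplies the forward map as `_call`, the inverse as `_inverse`, and the log Jacobian as `log_abs_det_jacobian`. The class name `SinhAffineTransform` and the map `y = loc + scale * sinh(x)` are hypothetical. They were chosen only because the map is a smooth bijection on the reals with a simple Jacobian.

```python
import math
import torch
from torch.distributions import Normal, TransformedDistribution, constraints
from torch.distributions.transforms import Transform


class SinhAffineTransform(Transform):
    """Hypothetical bijection y = loc + scale * sinh(x) on the real line."""
    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1  # monotonically increasing, assuming scale > 0

    def __init__(self, loc, scale, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.loc = torch.as_tensor(loc)
        self.scale = torch.as_tensor(scale)  # assumed positive

    def _call(self, x):
        return self.loc + self.scale * torch.sinh(x)

    def _inverse(self, y):
        return torch.asinh((y - self.loc) / self.scale)

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = scale * cosh(x), which is positive when scale > 0
        return torch.log(self.scale) + torch.log(torch.cosh(x))


# Usage: push a standard Normal through the transform and score an observation.
base = Normal(0.0, 1.0)
t = SinhAffineTransform(9.0, 0.75)
dist_y = TransformedDistribution(base, t)
y = torch.tensor(9.5)
x = t.inv(y)
# Change of variables by hand, for comparison with dist_y.log_prob(y).
manual = base.log_prob(x) - (math.log(0.75) + torch.log(torch.cosh(x)))
print(dist_y.log_prob(y).item(), manual.item())
```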

Beyond that, we’re still working on delayed sampling in Pyro. The effort involves a mixed numerical-symbolic backend for Pyro, which is being developed in the funsor repo. Eventually we hope to be able to run @osazuwa’s original code inside a context with pyro_backend('funsor').

Hi @fritzo, has there been any further progress on this? I see delayed sampling mentioned in the funsor README, but are there any working examples of a Pyro implementation for a problem like @osazuwa’s? I tried using the funsor backend via the context manager, but it does not seem to change the behaviour.

Hi @Rei, we are still pretty far from being able to automatically condition on a Delta in Pyro. I would recommend transforming your model by hand or trying to rewrite your model using a TransformedDistribution.

Thanks for the fast response @fritzo, will do :slightly_smiling_face:

Hi @fritzo, I just wanted to check in on this, given the progress that has been made on the funsor backend. Is it possible yet to condition on or observe a Delta?

Hi @Rei, could you open a github issue with your particular model or an example snippet of model code?

We now have an effect handler, poutine.collapse, that can in theory collapse Delta statements, but in practice we’re probably missing the patterns needed to rewrite models. Seeing examples of the transformations that you and others would like to use will help us get this working. Ideally we can add these patterns as unit tests like these.

Hi @fritzo, I opened an issue here with an example. Please let me know if it is adequate as an example and if there is anything else that I can provide…
