Pyro seems to encourage the use of effect handlers and, indeed, I generally prefer to use `poutine.condition` rather than the keyword argument `obs` in `pyro.sample`.
Typical use cases where this shines are quickly converting a generative model into a conditioned model for inference, or comparing a model where a variable is assumed known with one where it is latent.
However, `poutine.condition` does not seem to be compatible with subsampling (e.g. training with mini-batches), so whenever I need mini-batch training I am forced to switch style and use `obs`.
To clarify what I mean, consider the code below: it does not produce a valid model that can be used for inference.
```python
import pyro
import pyro.poutine as poutine

def generative_model():
    feature_plate = pyro.plate("feature_plate", size=90, dim=-1)
    sample_plate = pyro.plate("sample_plate", size=1000, subsample_size=20, dim=-2)
    mu = pyro.sample("mu", pyro.distributions.Normal(0, 1))
    with feature_plate:
        with sample_plate:
            X = pyro.sample("X", pyro.distributions.Normal(mu, 1))

Xdata = pyro.distributions.Normal(0, 1).sample((1000, 90))
inference_model = poutine.condition(generative_model, data={"X": Xdata})
```
Is there a way to use the effect-handler logic to condition a model (i.e. `poutine.condition`) and still allow mini-batch training (automatic or custom)?
I was thinking that a way of doing it consistent with Pyro’s effect-handler composability would be:
```python
inference_model = poutine.condition(generative_model, data={"X": pyro.subsample(Xdata, 0)})
```
That of course does not work, because the `pyro.subsample` statement is not inside a `pyro.plate`.
Is there a solution right now or am I going in a “feature request” direction with this topic?