Equivalent of `numpyro.substitute` in pyro?

Hi,

How can I substitute parameters with fixed, known values in Pyro? In NumPyro, I can do this with

model_with_fixed_param = numpyro.handlers.substitute(model, {"param0": 1.0})

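afaik there is no direct analog. substitute was made necessary in numpyro by jax's functional style (there is no global param store, so parameter values have to be passed into the model through handlers such as substitute).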

what is your use case?

I am trying to generate data from the pyro model itself, which means that I need to fix the parameters to certain values. Is there a way to achieve that in Pyro?

are they parameters (pyro.param) or latent variables (pyro.sample)? if the latter you can use poutine.condition

They are parameters not latent variables.

Currently I am using poutine.lift to replace the parameters with samples from a dist.Delta… Not sure how efficient this is. It would be ideal if Pyro provided a substitute equivalent for ancestral sampling from the model!

can you please describe your specific use case in detail? e.g. where do these parameters that you’re substituting come from? do they come from an elbo-based training procedure? training with the same model? etc

import torch
import pyro
from pyro import distributions as dist
from pyro import poutine
from torch.distributions import constraints


def model(N):
    # init strictly inside (0, 1); at 0 the unconstrained value logit(0) would be -inf
    a = pyro.param("a",
                   init_tensor=torch.tensor([0.5]),
                   constraint=constraints.unit_interval)
    with pyro.plate("N", N):
        x = pyro.sample("x", dist.Bernoulli(probs=a))
    return x


def generate(rng_seed, N, a):
    with poutine.seed(rng_seed=rng_seed):
        # Is there a better alternative than below?
        return poutine.lift(
            model,
            prior={
                "a": dist.Delta(a),
            })(N)

# sample with a = 0.3
x = generate(0, 10, torch.tensor(0.3))
print(x)

# sample with a = 0.8
x = generate(0, 10, torch.tensor(0.8))
print(x)

The use case is independent of any particular inference algorithm or objective (e.g. the ELBO).
As in the example script above, all I would like is a way to forward sample from the model with some parameters fixed.

The particular use case I have is to generate artificial data from the model. Then, I will perform inference on these simulated data to evaluate the performance of various inference algorithms.
This avoids the need to write a separate generative model outside of pyro.

Thanks

why can’t you do something like this:

def model(N, a=None):
    # init inside (0, 1): logit(0) = -inf would break ELBO-based training
    a = pyro.param("a", init_tensor=torch.tensor([0.5]),
                   constraint=constraints.unit_interval) if a is None else a

use model(N, a=None) for e.g. elbo-based training and model(N, a=a_val) to substitute in specific values of a

I had never thought of this; it seems reasonable. But wouldn't it be cumbersome if there are many parameters?
It seems better (IMO) to have a substitute helper.
Is there a reason not to have such support? (maybe the implementation would be complicated)

Nonetheless, thanks for the tip

please make a github issue for a feature request. i imagine it's not too hard to implement, but it might be a while before someone gets around to it; hard to say