Two quick questions related to pyro.deterministic:
Using deterministic primitives leads to lots of instances of RuntimeWarning when sampling from the model outside of an SVI instance (e.g. below). Is there a way to avoid these (aside from warnings.filterwarnings())?
Are there side effects from including deterministic primitives with the same name in the model and guide? I couldn’t quite get my head around what the implications of replaying multiple observed sample statements against one another are.
Using deterministic primitives leads to lots of instances of RuntimeWarning
Thanks for reporting; this is just a bug. pyro.deterministic is fairly new and we haven't worked out all the edge cases. We'll remove the RuntimeWarning before the next release.
Are there side effects from including deterministic primitives with the same name in the model and guide? I couldn’t quite get my head around … the implications
Hmm I think it should be unnecessary to include pyro.deterministic statements in the guide, and I’d expect Pyro to error in such a case. Can you give an example of when you’d like to include a pyro.deterministic statement in a guide?
Re the use of pyro.deterministic in the guide, it came from my other slightly confused question about ways to reuse code in model(). As you say, I think it is unnecessary, though Pyro doesn’t error.
Edit: OK, it seems to work. You have to call pyro.sample on the final observable variable in the guide, which is normally not needed (or even not recommended).
Update: It turns out that deterministic does not work when passed as the mean for Normal (I did not check other distributions). For that purpose I replaced it with a pyro.sample from a Delta distribution (needed in both the model and the guide). @fritzo can you comment on this?
I think deterministic should work in recent releases. There are some problems with your model/guide:
I think obs should not appear in guide
If you want to get values of sigma, which is transformed from multiplier, you should replace multiplier=2 by multiplier = pyro.param('multiplier', ...) in your model. You can use Predictive to get values of any sites that you want, including deterministic sites.
Without obs in guide, multiplier does not optimize.
I don't want to get the values of sigma - it is obvious how to obtain them. I want to pass pyro.deterministic to Normal and optimize the parameters. And it does not work.
Input:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO

def model(x, y=None):
    multiplier = 2
    sigma = pyro.sample("sigma", dist.Exponential(1))
    mean = pyro.deterministic("mean", torch.abs(x * multiplier) + 1)
    with pyro.plate("data", x.shape[0]):
        return pyro.sample("obs", dist.Normal(mean, sigma), obs=y)

def guide(x, y=None):
    sigma_rate = pyro.param("sigma_rate", torch.Tensor([3]))
    multiplier = pyro.param('multiplier', torch.Tensor([7]))
    mean = torch.abs(x * multiplier) + 1
    sigma = pyro.sample("sigma", dist.Exponential(sigma_rate))
    with pyro.plate("data", x.shape[0]):
        return pyro.sample("obs", dist.Normal(mean, sigma))

x = torch.distributions.Bernoulli(0.6).sample((100,))
y = model(x)

pyro.clear_param_store()
svi = SVI(model, guide, pyro.optim.Adam({"lr": 1e-3}), loss=Trace_ELBO())
for _ in range(1000):
    elbo = svi.step(x, y)
dict(pyro.get_param_store())
Could you try that? Currently, multiplier is the constant 2 in your model, hence the parameter multiplier is not optimized. If you want to use 2 to generate data and 7 as init value, you can define a global variable mval:
def model(x, y=None):
    multiplier = pyro.param('multiplier', mval)
    ...

mval = torch.tensor(2.)
y = model(x)
pyro.clear_param_store()
mval = torch.tensor(7.)
... # svi
Another way is
multiplier = pyro.param('multiplier',
                        lambda: torch.tensor(2.) if y is None else torch.tensor(7.))
Can I put pyro.param in a model? I don’t think so.
I get the following error in such a setting:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
I do not care much about the initial value in the guide. In a real-world scenario I would initialize it randomly. Here I wanted to initialize it with a specific value to track the optimization process.
So how is it possible that it worked previously, when multiplier was applied to the scale parameter? The answer is not crucial for further applications of deterministic.