I have two dependent variables, x and y:
x = pyro.sample('x', torch.distributions.Normal(2, 5))
y = pyro.sample('y', torch.distributions.Normal(3, 4))
I would like to sample from y based on its dependency on x. How exactly do I do this in Pyro?
Hi @ebtudpy,
First, note that you should use Pyro's distributions (which wrap PyTorch's) rather than raw PyTorch distributions when using Pyro:
- torch.distributions.Normal(2, 5)
+ pyro.distributions.Normal(2, 5)
Now if you'd like y to depend on x, you'll just need to make the parameters of y depend somehow on x; it's up to you how. Here's an example where y's prior mean is equal to x:
import pyro
import pyro.distributions as dist
x = pyro.sample('x', dist.Normal(2, 5))
y = pyro.sample('y', dist.Normal(x, 4))  # y's mean is whatever value x took
Here's an example where both the location (mean) and scale (standard deviation) parameters of y depend on a more complicated function of x:
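import torch
import pyro
import pyro.distributions as dist
f = torch.nn.Linear(1, 2)  # maps x to two outputs: loc and log_scale
pyro.module("f", f)  # register f's parameters with Pyro's param store
x = pyro.sample('x', dist.Normal(2, 5))
loc, log_scale = f(x.unsqueeze(-1)).unbind(-1)
scale = log_scale.exp()  # exp() keeps the scale strictly positive
y = pyro.sample('y', dist.Normal(loc, scale))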
Good luck!