Ideally, in my understanding of Pyro, it would be something like:
```python
import pyro
from pyro.distributions import Normal

p_z_x = ...  # fancy neural network

def model(X, y):
    with pyro.plate('data'):
        z = pyro.sample('z', Normal(mu, sigs).to_event())
        # other stuff like
        # outputs = p_x_z(z) or something

def guide(X, y):
    pass  # do nothing

# build your optimizer, SVI, etc.
```
Now, I know the above does not work; I get a KeyError doing this silly billy business.
Instead, I tried the following, much hackier solution:
```python
import pyro
from pyro.distributions import Normal

p_z_x = ...  # fancy neural network

def model(X, y):
    pyro.module('p_z_x', p_z_x)  # register the network's parameters with Pyro
    mu, sigs = p_z_x(X)
    with pyro.plate('data'):
        z = pyro.sample('z', Normal(mu, sigs))
        # other stuff like
        # outputs = p_x_z(z), etc.

def guide(X, y):
    pyro.module('p_z_x', p_z_x)
    # just produce the parameters again, maybe detach this time or something
    mu, sigs = p_z_x(X)  # parameters of my "variational" distribution
    with pyro.plate('data'):
        z = pyro.sample('z', Normal(mu, sigs))

# build your optimizer, SVI, etc.
```
It seems to do the trick, but I'm wondering whether there's a better way to go about this. It may seem like a weird request, but to reimplement a paper according to what the authors ACTUALLY do, I need to be able to do something along these lines. I think I'm basically asking the same thing as the passing variables post, but it would be good to clarify. Would I instead do the sampling in the guide and then use a Delta distribution in the model?
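One way to see why the second version "does the trick": `Trace_ELBO` estimates E_q[log p(y, z | X) - log q(z | X)]. When the model and the guide both sample z from the same Normal(mu, sigs), computed by the same network from the same X with nothing detached, log p(z | X) and log q(z | X) are equal and cancel, so the objective reduces to the expected log-likelihood E_q[log p(y | z)], with gradients reaching the encoder through the reparameterized sample. The minimal plain-PyTorch sketch here (it bypasses Pyro entirely; `encoder`, `decoder`, the sizes, the Normal likelihood and the data are hypothetical) spells out that objective and checks that the log-ratio term vanishes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

torch.manual_seed(0)
D_IN, D_Z = 5, 2  # hypothetical sizes
encoder = nn.Linear(D_IN, 2 * D_Z)  # hypothetical stand-in for p_z_x
decoder = nn.Linear(D_Z, 1)  # hypothetical mean network for p(y | z)
X, y = torch.randn(100, D_IN), torch.randn(100)  # hypothetical data


def z_dist(X):
    """Normal over z computed from X, used by both the 'model' and the 'guide'."""
    mu, raw = encoder(X).chunk(2, dim=-1)
    return Normal(mu, F.softplus(raw) + 1e-4)


def negative_elbo(X, y):
    q_z = z_dist(X)  # guide: q(z | X)
    z = q_z.rsample()  # reparameterized sample, so gradients reach the encoder
    p_z = z_dist(X)  # model: p(z | X), same network and inputs
    log_ratio = (p_z.log_prob(z) - q_z.log_prob(z)).sum(-1)  # cancels to zero
    log_lik = Normal(decoder(z).squeeze(-1), 1.0).log_prob(y)
    return -(log_lik + log_ratio).sum(), log_ratio


opt = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=1e-2)
for _ in range(200):
    opt.zero_grad()
    loss, log_ratio = negative_elbo(X, y)
    loss.backward()
    opt.step()
```

If the guide detaches `mu` and `sigs`, the two log-density terms still take equal values, but the sample z no longer carries gradients back to the encoder, so the objective actually being optimized changes.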