Assume I have a model x = S * z in which both S and z are estimated with SVI: S is a model parameter and z is a latent variable with a normal distribution.
Now I would like to get the most likely z_i for an observation x_i. How can I do that with Pyro's built-in functions? My model and guide:
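```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer.autoguide import AutoNormal

G, K = 10, 3  # example sizes: G observed dimensions, K latent dimensions

def model(data):
    # data has shape (G, N): one column per observation x_i, so data.T is (N, G)
    s = pyro.param("s", torch.randn(G, K))
    z0_loc = pyro.param("z0_loc", torch.zeros(K))
    z0_scale = pyro.param("z0_scale", torch.eye(K, K), constraint=constraints.positive_definite)
    # z0 is drawn once, outside the plate, so it is shared by all observations
    z0 = pyro.sample("z0", dist.MultivariateNormal(z0_loc, z0_scale))
    sigma0 = pyro.param("sigma0", torch.eye(G, G), constraint=constraints.positive_definite)
    mean = s @ z0
    with pyro.plate("data", data.shape[1]):  # plate size must match the N rows of data.T
        pyro.sample("obs", dist.MultivariateNormal(mean, sigma0), obs=data.T)

guide = AutoNormal(model)
```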