First, I find initialization to be quite important in SVI. You might try initializing p_s_loc from the data and p_s_scale to a value that deliberately either underestimates or overestimates the variance:
pyro.param("p_s_loc", torch.abs(data_vec).log1p())  # or similar
pyro.param("p_s_scale", 0.1 * torch.ones(N),  # start narrow
           constraint=constraints.positive)
or maybe
pyro.param("p_s_scale", 10.0 * torch.ones(N),  # start wide
           constraint=constraints.positive)
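Second, I believe your guide can be constructed automatically via
guide = pyro.contrib.autoguide.AutoDiagonalNormal(model)
(in Pyro 1.0 and later the module is pyro.infer.autoguide), though you would need to interact with it a bit differently: the autoguide creates and names its own parameters, so you read the fitted posterior through the guide object, e.g. guide.median(), rather than through p_s_loc and p_s_scale.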