Dear Pyro community,
I was experimenting with Pyro, trying to implement inference over a categorical variable, when I noticed strange convergence behavior in the posterior estimates (e.g. the loss, i.e. the negative ELBO returned by `svi.step()`, increases over time instead of decreasing). Below is a simplified version of the code, in the hope that someone can tell me what is going on under the hood, or whether I am doing something wrong.
First, let's look at the version of the code that behaves as expected:
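```python
from torch import ones
from torch.distributions import constraints
import pyro.distributions as dist
from pyro import sample, param

n1 = 2    # number of independent categorical variables
n2 = 3    # number of categories per variable
n = 100   # number of observations (used in the modified model below)
w = ones(n1, n2)
w[n1//2:, -1] = 0.             # the last category is impossible for the second variable
w /= w.sum(dim=-1)[:, None]    # normalize each row to sum to one

def model():
    sample('depth', dist.Categorical(w).independent(1))

def guide():
    probs = param('probs', ones(n1, n2)/n2, constraint=constraints.simplex)
    sample('depth', dist.Categorical(probs=probs).independent(1))
```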
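To make the target concrete, here is a plain-Python mirror of the three tensor operations that build `w` (it only redoes the arithmetic, no torch or Pyro involved). The first row stays uniform and the second row puts zero mass on its last category, which is what the fitted ‘probs’ should approach.

```python
# Plain-Python mirror of the tensor operations that build w (no torch needed).
n1, n2 = 2, 3

# ones(n1, n2): start from an all-ones n1 x n2 matrix
w = [[1.0] * n2 for _ in range(n1)]

# w[n1//2:, -1] = 0.  ->  zero the last category in rows n1//2 .. n1-1
for i in range(n1 // 2, n1):
    w[i][-1] = 0.0

# w /= w.sum(dim=-1)[:, None]  ->  normalize every row to sum to one
w = [[x / sum(row) for x in row] for row in w]

for row in w:
    print([round(x, 4) for x in row])
```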
I use the following code to run inference and estimate the values of the ‘probs’ parameter:
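```python
import sys
import torch
from tqdm import tqdm
from pyro import clear_param_store, get_param_store
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

clear_param_store()
num_particles = 10
n_iterations = 2000
svi = SVI(model=model,
          guide=guide,
          optim=Adam({'lr': 0.1}),
          loss=Trace_ELBO(num_particles=num_particles))
losses = []
with tqdm(total=n_iterations, file=sys.stdout) as pbar:
    for step in range(n_iterations):
        # svi.step() takes a gradient step and returns the estimated loss (negative ELBO)
        losses.append(svi.step())
        pbar.set_description("Mean ELBO %6.2f" % torch.Tensor(losses[-20:]).mean())
        pbar.update(1)

# named_parameters() returns the unconstrained values of the parameters
results = {}
for name, value in get_param_store().named_parameters():
    results[name] = value
```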
In this case I get near-zero loss values and an almost perfect estimate of the ‘probs’ values:
```python
from torch.nn.functional import softmax
print(softmax(results['probs'], dim=-1))
```
```
tensor([[ 0.3333, 0.3333, 0.3333],
        [ 0.5000, 0.5000, 0.0001]])
```
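Because this first model has no observed sites, the loss that `svi.step()` estimates is, in expectation, the KL divergence KL(q || p) between the guide and the model, which is zero exactly when `probs` equals `w`. The sketch below (plain Python; `kl_categorical` and `exact_loss` are hypothetical helpers, not Pyro API) evaluates that quantity in closed form. It also suggests why the fitted probability of the third category in the second row is pushed toward zero: any guide mass there makes the divergence infinite.

```python
import math

def kl_categorical(q, p):
    """KL(q || p) for two categorical distributions given as lists of probabilities."""
    total = 0.0
    for qi, pi in zip(q, p):
        if qi == 0.0:
            continue              # 0 * log(0 / p) is taken as 0
        if pi == 0.0:
            return math.inf       # q puts mass where p has none
        total += qi * math.log(qi / pi)
    return total

def exact_loss(q_rows, p_rows):
    """Loss (negative ELBO) of the first model: the rows are independent,
    so their KL divergences simply add up."""
    return sum(kl_categorical(q, p) for q, p in zip(q_rows, p_rows))

w = [[1/3, 1/3, 1/3], [0.5, 0.5, 0.0]]          # target computed from the model
q_near = [[0.5, 0.25, 0.25], [0.5, 0.5, 0.0]]   # a hypothetical, slightly wrong guide

print(exact_loss(w, w))                          # 0.0 at the optimum q = w
print(exact_loss(q_near, w))                     # ≈ 0.0589
print(kl_categorical([1/3, 1/3, 1/3], w[1]))     # inf: mass on the impossible category
```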
However, a slight change to the model

```python
from pyro import iarange

def model():
    sample('depth', dist.Categorical(w).independent(1))
    with iarange('data', size=n):
        # n observations whose distribution does not depend on 'depth'
        sample('obs', dist.Categorical(probs=ones(n, 2)/2), obs=ones(n, dtype=torch.int64))
```
leads to very noisy estimates of the loss (it does not seem to converge at all) and very poor estimates of the ‘probs’ values:

```python
print(softmax(results['probs'], dim=-1))
```
```
tensor([[ 0.0193, 0.9756, 0.0051],
        [ 0.9391, 0.0605, 0.0003]])
```
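If, as I assume, ‘obs’ only contributes a constant, its size is easy to compute: each of the n = 100 observations has probability 1/2 under `Categorical(probs=ones(n, 2)/2)` regardless of ‘depth’. The arithmetic below (plain Python) shows that this shifts the exact loss by about 69.31, so even a perfectly fitted guide should report a loss near 69.31 rather than near zero.

```python
import math

n = 100
# Each observation has probability 1/2 whatever value 'depth' takes,
# so the total log-likelihood of 'obs' is a constant.
log_lik_obs = n * math.log(0.5)
print(log_lik_obs)            # ≈ -69.31

# loss = -ELBO = KL(q || p) - log p(obs), so its minimum (KL = 0) is:
min_loss = -log_lik_obs
```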
What confuses me is that adding an observation that is independent of the variable I am interested in completely disrupts the inference. Looking carefully at how the loss changes over time, I notice that it slightly increases with each new step.
I would appreciate it if someone could explain what causes this interaction between independent variables (as I understand it, ‘obs’ only adds a constant to the loss) and what could be the source of the increased noise in the loss estimates.
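One way a constant could still matter: ‘depth’ is a discrete site that cannot be reparameterized, so `Trace_ELBO` estimates its gradient with a score-function (REINFORCE) term, in which ∇ log q(z) is multiplied by the sampled cost. As far as I can tell, `Trace_ELBO` does not drop costs that are not downstream of ‘depth’, so the constant from ‘obs’ would enter that multiplier. The sketch below is plain Python, not Pyro: it uses a single hypothetical softmax-parameterized categorical with made-up per-category costs and computes the estimator's exact mean and variance by enumeration, with and without the constant offset.

```python
import math

def softmax_list(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def score_function_stats(logits, cost, offset, k):
    """Exact mean and variance of the REINFORCE estimate of
    d E_q[cost(z) + offset] / d logits[k], i.e. (cost[z] + offset) * d log q(z) / d logits[k]
    with z ~ q = softmax(logits)."""
    q = softmax_list(logits)
    mean = 0.0
    second_moment = 0.0
    for z, qz in enumerate(q):
        # For a softmax parameterization: d log q(z) / d logits[k] = 1[z == k] - q[k]
        grad_log_q = (1.0 if z == k else 0.0) - q[k]
        g = (cost[z] + offset) * grad_log_q
        mean += qz * g
        second_moment += qz * g * g
    return mean, second_moment - mean * mean

logits = [0.0, 0.0, 0.0]          # uniform guide, like the initial 'probs'
cost = [1.0, 2.0, 3.0]            # hypothetical per-category costs
offset = 100 * math.log(0.5)      # constant of the size contributed by 'obs' (about -69.31)

mean0, var0 = score_function_stats(logits, cost, 0.0, k=0)
mean1, var1 = score_function_stats(logits, cost, offset, k=0)
print(mean0, var0)
print(mean1, var1)                # same mean, much larger variance
```

The mean is unchanged because E_q[∇ log q(z)] = 0, so the constant adds no bias; it only inflates the variance, roughly with the square of the offset. A baseline, or an estimator that ignores costs not downstream of ‘depth’, would remove this offset from the multiplier.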