SVI for simple model with Dirichlet distribution

So I have a relatively simple model - including a Dirichlet distribution - that works fine with MLE and MAP estimation (AutoDelta guide) but fails to do anything useful with any other guide. I tried AutoNormal, AutoDiagonalNormal, and a manually specified guide; all yielded similar (wrong) results. Unfortunately, I need uncertainty quantification, so simply doing MAP estimation / AutoDelta is not an option.

This is not a fully reproducible example (I can add one if it helps), but here is the gist of what I am doing:

import numpy as np
import torch

import pyro
import pyro.distributions as dist


def model_IR(x, y=None):
    arr = x.detach().cpu().numpy()
    assert np.all(arr[:-1] <= arr[1:]), "x input is assumed to be sorted"
    if y is not None:
        assert x.shape == y.shape

    steps = pyro.sample("steps", dist.Dirichlet(torch.ones_like(x)))  # flat prior over vector with positive entries that sums to one
    scale = pyro.sample("scale", dist.Beta(1, 1))  # flat prior over [0, 1] variable
    # probs is a (latent) monotonically increasing vector with probs[0] >= 0 and probs[-1] < 1.
    probs = torch.minimum(steps.cumsum(-1) * scale, torch.tensor(1 - 1e-7))  # I do not know why the clipping is necessary

    with pyro.plate("data", x.shape[0]):
        # observed binary outcomes, p(y=1) given by probs
        observation = pyro.sample("obs", dist.Bernoulli(probs=probs), obs=y)

def guide_IR(x, y=None):
    dirichlet_pars = pyro.param("dirichlet_pars", torch.ones_like(x), constraint=dist.constraints.positive)
    steps = pyro.sample("steps", dist.Dirichlet(dirichlet_pars))

    beta_alpha = pyro.param("beta_alpha", torch.tensor(1.0), constraint=dist.constraints.positive)
    beta_beta = pyro.param("beta_beta", torch.tensor(1.0), constraint=dist.constraints.positive)
    scale = pyro.sample("scale", dist.Beta(beta_alpha, beta_beta))

from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoNormal

# SVI estimation: runs through but doesn't yield anywhere near correct results.
# I also tried guide_IR above and an AutoDiagonalNormal guide, all yielding similar (wrong) results.
# When I use AutoDelta(model_IR) here, I get the same (correct) result as with ML estimation.
guide = AutoNormal(model_IR)

scheduler = pyro.optim.MultiStepLR({'optimizer': torch.optim.Adam,
                                    'optim_args': {'lr': 0.005},
                                    'milestones': [20],
                                    'gamma': 0.2})
svi = SVI(model=model_IR, guide=guide, optim=scheduler, loss=Trace_ELBO())
for i in range(5000):
    loss = svi.step(X_train, y_train)

predictive = Predictive(model_IR, guide=guide, num_samples=100)
svi_samples = predictive(X_train)
# recompute probs from the posterior samples (note: this omits the clipping applied in the model)
probs = svi_samples['scale'].squeeze().unsqueeze(-1) * svi_samples['steps'].squeeze().cumsum(-1)
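An aside on that last line: rather than recomputing probs by hand (which also skips the clipping used in the model), one could register probs as a deterministic site and let Predictive return it directly. A minimal sketch, assuming model_IR is modified accordingly (the site name "probs" and the return_sites choice are illustrative, not from the original code):

# Sketch: inside model_IR, expose probs as a deterministic site, e.g.
#     probs = pyro.deterministic("probs", torch.minimum(steps.cumsum(-1) * scale,
#                                                       torch.tensor(1 - 1e-7)))
# Predictive can then return it directly, clipping included:
predictive = Predictive(model_IR, guide=guide, num_samples=100, return_sites=["probs"])
probs_samples = predictive(X_train)["probs"]  # shape: (100, len(X_train))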

Without going into the details, this is what a correct solution for “probs” (y axis) plotted against x would look like in my test case, and what MLE/MAP estimation correctly identifies: 0 in the left half, jumping to 1 in the right half (blue line; the black dashed line is just for reference).

This is what SVI estimates instead (in this case, using the AutoNormal guide): something close to a diagonal.

Any ideas on how I might get closer to a correct solution are highly appreciated!

I believe I already follow most, if not all, of the applicable advice from the Tips and Tricks section of the SVI tutorial.

Am I correct in assuming that the model / prior cannot be the problem because it works with an AutoDelta guide?

Is it possible to evaluate the loss at the true solution, to check whether I’m just stuck in a local minimum and need to play around more with the optimization strategy? (The results were at least robust to different learning rates, LR scheduling on/off, and different initialization attempts, so I am not overly optimistic about this.)
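For concreteness, one way to do such a check is to pin an AutoDelta guide at chosen values via init_to_value and evaluate Trace_ELBO without taking a gradient step. A sketch, where true_steps and true_scale are hypothetical hand-constructed values (a near-step function nudged into the interior of the Dirichlet and Beta supports):

from pyro.infer.autoguide import AutoDelta, init_to_value

# Hypothetical "true" values: nearly all step mass on the midpoint jump,
# with a tiny epsilon elsewhere so the vector stays inside the simplex.
eps = 1e-6
true_steps = torch.full_like(X_train, eps)
true_steps[len(true_steps) // 2] = 1.0
true_steps = true_steps / true_steps.sum()
true_scale = torch.tensor(1.0 - 1e-4)  # Beta support is the open interval (0, 1)

guide_at_truth = AutoDelta(model_IR,
                           init_loc_fn=init_to_value(values={"steps": true_steps,
                                                             "scale": true_scale}))
loss_at_truth = Trace_ELBO().loss(model_IR, guide_at_truth, X_train, y_train)

Comparing loss_at_truth against the loss SVI converges to would indicate whether the optimizer is merely stuck.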

One thing that thoroughly confuses me: when I perform inference using MCMC (NUTS, 4 chains, 1000 warmup steps + 1000 samples), I also get results like those in the second figure above. But MCMC is fully independent of the guide… so maybe the problem is not with the guide after all? But then how can the model/prior be at fault if it works nicely with the AutoDelta guide / MAP? Maybe the problem is somehow in how I use Predictive at the end…? I am basically just interested in the distribution of the latent “probs” variable.

If N = x.shape[0], then you have N binary observations and N+1 latent variables. The true posterior is presumably going to be quite broad, and scale and steps are going to be highly correlated.

How are you producing x? It may just be that your expectations are wrong; I don’t see why you should expect to see a sharp step function in the inferred probs.

Hi @martinjankowiak, my example is created like this:

x = np.linspace(0, 1, 501)
y = np.zeros_like(x)
y[round(len(y)/2):] = 1

# converted to torch tensors for the model, as the X_train / y_train used above
X_train = torch.as_tensor(x, dtype=torch.float)
y_train = torch.as_tensor(y, dtype=torch.float)

So I have only “0” observations in the left half and only “1” observations in the right half, with x equally spaced (and ordered).

For context, I am trying to implement a Bayesian version of Isotonic Regression (hence the “IR”), restricted to [0, 1] in both x and y.

Standard IR (using sklearn.isotonic.IsotonicRegression) yields the expected solution, as does the ML/MAP solution using the model above. (This is expected since classical IR coincides with the ML solution and my priors are all flat.)
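For reference, a minimal sketch of that classical fit, using the x and y arrays from above (the y_min/y_max clipping mirrors the [0, 1] restriction; the parameter choices are mine, not from the original post):

from sklearn.isotonic import IsotonicRegression

# Classical isotonic regression, restricted to [0, 1] like the model above.
ir = IsotonicRegression(y_min=0.0, y_max=1.0)
probs_classical = ir.fit_transform(x, y)  # recovers the step: 0 on the left half, 1 on the right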

Well, you should probably change your prior then, because what you’re doing isn’t proper regression, since it doesn’t depend on some input x.

@martinjankowiak Ahh sorry, I should have explained that before; I see why it looks weird but I am relatively certain that it is fine. The model is learned and will only work for this particular sequence of x values, which is exactly what I need. I will never apply it to any other x values. It is also correct that the values of x are irrelevant - it is just assumed that they are the same as during training. I could actually drop the model’s dependence on x, I think. Alternatively, I could probably make it a “proper regression model” by doing some kind of interpolation scheme (which is also what people do when they want to apply a standard IR model to other x values), but I think that would only unnecessarily complicate the model. (And again, MAP estimation works fine…)

If you think, e.g., that there is a single change point, you should encode that directly into your model assumptions.

I don’t know that, unfortunately. The curve produced in the second image could actually also be the correct solution; it just isn’t in this particular example.

And - sorry to reiterate this - doesn’t the fact that everything works fine with an AutoDelta guide show that the prior is not the issue, or at least that the mode of the variationally learned distribution should be correct?

“Problem” is not well defined, especially if it’s measured with respect to some expectations you’re bringing to the table, expectations which may be wrong.

The fact that SVI and NUTS agree would suggest that SVI is doing an OK job.

I suggest you draw 10 samples of probs from your prior and take a look at them; they will not look like step functions.
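A minimal sketch of that check, assuming probs is exposed via pyro.deterministic as sketched earlier:

# Draw 10 samples of probs from the prior: run Predictive on the model alone (no guide).
prior_predictive = Predictive(model_IR, num_samples=10, return_sites=["probs"])
prior_probs = prior_predictive(X_train)["probs"]  # shape: (10, len(X_train))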

When you observe only N 0/1 bits, that’s not a lot of information, and a lot of solutions for probs will be consistent with that information. As such, the posterior should be broad.