Multilevel logistic regression ("ValueError: can't optimize a non-leaf Tensor")

Multilevel model, hierarchical model, mixed effects model.

I have implemented a multilevel logistic regression model in Stan, but since I am running it on a large dataset, computational speed is an issue, so I have turned to Pyro. I managed to implement a single-level logistic regression in Pyro, but once I add hyperpriors to the slopes in the model I get the following error message:

“ValueError: can’t optimize a non-leaf Tensor”

I provide the model and the guide below:

import torch
import torch.nn.functional as F
import pyro
from pyro.distributions import Normal, HalfCauchy, Bernoulli
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(x, y):

    n = x.shape[1]
    # Hyper priors
    mu = pyro.sample("mu", Normal(loc = torch.zeros((1, n)),
                                            scale = torch.ones((1, n))))
    sigma = pyro.sample("sigma", HalfCauchy(scale = torch.ones((1, n))))

    # Priors
    a = pyro.sample("a", Normal(0, 1))
    b = pyro.sample("b", Normal(mu, sigma))

    # Linear model. Output: Logits (l_est)
    l_est = torch.mm(x, b.t()) + a
    return pyro.sample("y_est", Bernoulli(logits = l_est), obs=y)

def guide(x, y):

    n = x.shape[1]

    # Hyper priors for b_mu
    b_mu_mu_hyp = torch.rand((1, n))
    b_mu_sigma_hyp = torch.rand((1, n))
    b_mu_mu = pyro.param("b_mu_mu", b_mu_mu_hyp)
    b_mu_sigma = F.softplus( pyro.param("b_mu_sigma", b_mu_sigma_hyp) )
    b_mu = pyro.sample("b_mu", Normal(b_mu_mu,
                                      b_mu_sigma))

    # Hyper priors for b_sigma
    b_sigma_sigma_hyp = torch.rand((1, n))
    b_sigma_sigma = F.softplus( pyro.param("b_sigma_sigma", b_sigma_sigma_hyp) )
    b_sigma = pyro.sample("b_sigma", HalfCauchy(b_sigma_sigma))

    # Priors for b
    b_mu_param = pyro.param("b_mu_param", b_mu)
    b_sigma_param = F.softplus( pyro.param("b_sigma_param", b_sigma) )
    b = pyro.sample("b_est", Normal(loc = b_mu_param,
                                    scale = b_sigma_param))

    # Priors for a
    a_mu = torch.rand((1, 1))
    a_sigma = torch.rand((1, 1))
    a_mu_param = pyro.param("a_mu", a_mu)
    a_sigma_param = F.softplus( pyro.param("a_sigma", a_sigma) )
    a = pyro.sample("a_est", Normal(loc = a_mu_param,
                                    scale = a_sigma_param))

    #Linear model. Output: Logits
    return torch.mm(x, b.t()) + a

x_train = train[:][0]
y_train = train[:][1]

# Training the model
svi = SVI(model,
          guide,
          Adam({"lr": .005}),
          loss = Trace_ELBO())

print(x_train.shape, y_train.shape)

for i in range(1000):
    svi.step(x_train, y_train)
    if i % 100 == 0:
        print('.', end='')

The shapes of x_train and y_train are [10, 105] and [10, 1], respectively: 10 data points with 105 variables (this dataset is just for testing out the code; I am not asking any questions regarding inference).

How do I get the code to run?

Can you please provide a more complete code snippet, in particular something that includes your invocation of pyro.infer.SVI, pyro.optim, etc.?

Of course. It is included now.

Hi @KasperFischer,

To start with a disclaimer: I'm also new to Pyro.
My guess is that the confusion is in your guide().
Inside your guide, you simply want to sample from your approximate posterior.
This means you want to sample from the same sites that appear in your model() as priors (those you don't condition on/observe).
A site in this context is the named thing that appears in a sample() statement.
In your case this means that sample("mu", ...), sample("sigma", ...), sample("a", ...), sample("b", ...) should all also appear in your guide().
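
Something along these lines (a rough, untested sketch; the parameter names are my own, and I use a LogNormal for "sigma" just to keep the guide on positive support, which differs from the HalfCauchy prior in the model) would make the guide mirror the model's sites. Since every pyro.param here is initialized with a fresh leaf tensor, it should also avoid the "can't optimize a non-leaf Tensor" error:

import torch
import torch.nn.functional as F
import pyro
from pyro.distributions import Normal, LogNormal

def guide(x, y):
    n = x.shape[1]

    # Variational parameters for the hyperprior site "mu"
    mu_loc = pyro.param("mu_loc", torch.zeros((1, n)))
    mu_scale = F.softplus(pyro.param("mu_scale", torch.ones((1, n))))
    pyro.sample("mu", Normal(mu_loc, mu_scale))

    # "sigma" must stay positive, so approximate it with a LogNormal
    sigma_loc = pyro.param("sigma_loc", torch.zeros((1, n)))
    sigma_scale = F.softplus(pyro.param("sigma_scale", torch.ones((1, n))))
    pyro.sample("sigma", LogNormal(sigma_loc, sigma_scale))

    # Intercept site "a"
    a_loc = pyro.param("a_loc", torch.zeros(1))
    a_scale = F.softplus(pyro.param("a_scale", torch.ones(1)))
    pyro.sample("a", Normal(a_loc, a_scale))

    # Slope site "b"
    b_loc = pyro.param("b_loc", torch.zeros((1, n)))
    b_scale = F.softplus(pyro.param("b_scale", torch.ones((1, n))))
    pyro.sample("b", Normal(b_loc, b_scale))
    # No return needed: SVI only uses the sample/param statements above.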

Also, the linear-model return at the end of your guide (return torch.mm(x, b.t()) + a) is not really necessary.
Note that the return value of your guide is not actually used during inference (of course you can still return something if you want to call guide() yourself, but it is not needed for inference).
It clicked for me when I realized that the only things that matter for inference in model() and guide() are the side effects that the sample(), plate(), etc. statements have when they are called. The return value of the function does not matter at all.
I think this is something that is not yet obvious in the documentation/tutorials for people coming from a non-PPL background.
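
For example (a minimal illustration, assuming the model and guide above are defined), you can trace the guide and look at which sites were recorded; those are what SVI uses, not the return value:

from pyro import poutine

# Record the sites that executing the guide registers with Pyro.
guide_trace = poutine.trace(guide).get_trace(x_train, y_train)
sample_sites = [name for name, site in guide_trace.nodes.items()
                if site["type"] == "sample"]
print(sample_sites)  # these names should match the model's unobserved sample sites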

Hope that helps.
Also hope more experienced users/devs will correct any inexact or plainly wrong statements that I made.


I think this post may help you.
Toy Example of Bayesian Logistic Regression
