Gaussian Process Dynamical Model (GPDM). Indexing latent variables

Hi there! I am new to Pyro and I am trying to implement Gaussian Process Dynamical Models (GPDMs) (see for reference). Basically, a GPDM is a GPLVM with additional GPs governing the dynamics of the latent states. Given a sequence of N observations [y_1,…,y_N] (each y_i has dimension D), I want to find the associated latent sequence [x_1,…,x_N] (each x_i has dimension d << D). The latent mapping x -> y and the dynamics x(t) -> x(t+1) are assumed to be governed by distinct GPs.
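
To make the setup concrete, here is the model as I understand it, written out in my own notation. The names f, g, k_X, k_Y and the noise terms are labels I am introducing for this post; I am assuming zero-mean GPs with Gaussian noise, and that each output dimension of f and g is an independent GP sharing one kernel:

```latex
\begin{aligned}
x_{t+1} &= f(x_t) + \epsilon_x, & f &\sim \mathcal{GP}(0, k_X), \\
y_t     &= g(x_t) + \epsilon_y, & g &\sim \mathcal{GP}(0, k_Y),
\end{aligned}
\qquad t = 1, \dots, N-1 \text{ for the dynamics}, \quad t = 1, \dots, N \text{ for the observations}.
```

Here k_X and k_Y correspond to kernelX and kernelY in the code, and the latent states x_t are treated as free parameters to optimize.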

Following the GPLVM tutorial, I came up with this basic implementation (Y is the N x D observation matrix and X_prior is the PCA initialization of the latent variables).

import torch
from torch.nn import Parameter
import pyro
import pyro.contrib.gp as gp

# kernel of the latent map GP
kernelY = gp.kernels.RBF(input_dim=d, lengthscale=torch.ones(d))
# kernel of the dynamics GP
kernelX = gp.kernels.RBF(input_dim=d, lengthscale=torch.ones(d))

# initialize latent matrix N x d as parameter
X = Parameter(X_prior.clone())

# latent map GP X->Y
gpy = gp.models.GPRegression(X, Y.t(), kernelY, noise=torch.tensor(0.01), jitter=1e-5)
# dynamics GP X(1:N-1)->X(2:N)
gpx = gp.models.GPRegression(X[:-1,:], X[1:,:].t(), kernelX, noise=torch.tensor(0.01), jitter=1e-5)

optimizer = torch.optim.Adam(list(gpy.parameters()) + list(gpx.parameters()), lr=0.01)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
num_steps = 2500
for i in range(num_steps):
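    optimizer.zero_grad()
    # joint loss: latent map GP plus dynamics GP
    loss = loss_fn(gpy.model, gpy.guide) + loss_fn(gpx.model, gpx.guide)
    loss.backward()
    optimizer.step()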

The code works fine: I can fit the models and find a reasonable latent evolution. But if, instead of slicing X as X[:-1,:] and X[1:,:], I use index lists like

in_indices = list(range(0,N-1))
out_indices = list(range(1,N))
gpx = gp.models.GPRegression(X[in_indices,:], X[out_indices,:].t(), kernelX, noise=torch.tensor(0.01), jitter=1e-5)

I get the error

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

and if I use loss.backward(retain_graph=True), the optimization does not converge to a good solution. I also tried Index and Vindex from pyro.ops.indexing, but the problem persists.
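
Here is a minimal PyTorch-only sketch of what I suspect is going on, with Pyro left out entirely. My guess (an assumption, not something confirmed by the Pyro docs) is that X[:-1,:] is a view of X whose autograd node stores no tensors, while X[in_indices,:] is a copy made once, outside the training loop, whose autograd node stores the index tensors and frees them after the first backward pass. The helper name survives_two_backwards and the toy shapes are made up for illustration:

```python
import torch

torch.manual_seed(0)
X = torch.nn.Parameter(torch.randn(6, 2))
idx = list(range(0, 5))

sliced = X[:-1, :]    # basic slicing: a view that shares storage with X
indexed = X[idx, :]   # list indexing: a copy, created once at this point

def survives_two_backwards(t):
    """Return True if two backward passes through the same node succeed."""
    try:
        t.sum().backward()
        t.sum().backward()
        return True
    except RuntimeError:
        return False

slice_ok = survives_two_backwards(sliced)
index_ok = survives_two_backwards(indexed)

# Mimic an optimizer step: update X in place, outside autograd.
with torch.no_grad():
    X.add_(1.0)

# Does each tensor still agree with the current values of X?
slice_tracks_X = torch.equal(sliced, X[:-1, :])
index_tracks_X = torch.equal(indexed, X[idx, :])

print(slice_ok, index_ok, slice_tracks_X, index_tracks_X)
```

If this reading is right, it would explain both symptoms: the second backward fails for the indexed tensor, and with retain_graph=True the dynamics GP keeps seeing the stale copy of X from before the first update. One thing to try would then be re-indexing X inside the training loop so the graph is rebuilt every step, for example with gpx.set_data(X[in_indices,:], X[out_indices,:].t()) before computing the loss; I am treating that as an untested suggestion.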

Does anyone know why slicing and indexing behave differently here, and how I should properly index X in this context?

Also, if someone has implemented GPDM using Pyro before, feel free to give me any kind of suggestion.

P.S. I need to index X to model two different sequences of observations, of lengths N1 and N2, because I must ensure that the first latent state of the 2nd sequence does not depend on the last state of the 1st sequence (in_indices = list(range(0,N1-1))+list(range(N1,N1+N2-1)) and out_indices = list(range(1,N1))+list(range(N1+1,N1+N2))).
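
For clarity, the index construction above can be sketched as a small helper that generalizes to any number of sequences stacked row-wise in X. The function name transition_indices and the example lengths are hypothetical, chosen only for illustration:

```python
def transition_indices(lengths):
    """Build (in_indices, out_indices) for sequences stacked row-wise.

    Each sequence of length n contributes n - 1 transitions, and no
    transition links the end of one sequence to the start of the next.
    """
    in_indices, out_indices = [], []
    offset = 0
    for n in lengths:
        in_indices += list(range(offset, offset + n - 1))
        out_indices += list(range(offset + 1, offset + n))
        offset += n
    return in_indices, out_indices

# Example with N1 = 3 and N2 = 4
in_indices, out_indices = transition_indices([3, 4])
print(in_indices)
print(out_indices)
```

With N1 = 3 and N2 = 4, row 2 (the last state of the 1st sequence) never appears as an input and row 3 (the first state of the 2nd sequence) never appears as an output, which is the independence I want.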