Bayesian linear regression with custom layer

Dear community,
I’m learning Pyro but I do not understand well how to create new layers such as this I share here. Here, I would like to model seasonality with a module I create and I have 3 regressors and a predicted continuous variable (y2)

import torch
from torch import nn
from pyro.nn import PyroModule
import pyro
import pyro.distributions as dist

from pyro.nn import PyroSample
from pyro.infer.autoguide import AutoDiagonalNormal

class seasonality(nn.Module):
    def __init__(self, T):
        shift = torch.randn(1)
        self.shift = nn.Parameter(shift)
        self.pi = torch.acos(torch.zeros(1)).item() * 2
        self.T = T

    def forward(self, x):
        out = torch.sin(2*self.pi/self.T * (x + self.shift))
        return out

class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))
        self.season = PyroModule[seasonality](365)
        self.season.shift = PyroSample(dist.Normal(0., 1).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        x[:, -1] = self.season(x[:, -1])
        mean = (self.linear(x)).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean
model = BayesianRegression(3, 1)
guide = AutoDiagonalNormal(model)

from pyro.infer import SVI, Trace_ELBO

adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

X2 = torch.from_numpy(X).float()
y2 = torch.from_numpy(y).float()

for j in range(500):
    # calculate the loss and take a gradient step
    loss = svi.step(X2, y2)
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(X2)))

However, I have this error:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time.

I think it’s because I replace x[:, -1] but how can I do it?

i think the problem may be that shift is a Parameter but you’re not taking gradients wrt shift. do you intend to learn shift? does it work if you make shift a fixed tensor? if you want to learn shift you probably need to make the seasonality module a PyroModule