Making a trainable distribution using PyTorch and pyro.distributions (very beginner question)

I’m trying to train a Bernoulli distribution with a trainable parameter using pyro.distributions.

I want to learn the Bernoulli distribution’s parameter (the success probability) by minimizing a negative log-likelihood (NLL) loss.

train_data is a one-hot encoded sparse matrix of shape (2034, 19475), and train_labels takes 4 values (4 classes: [0, 1, 2, 3]).
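For reference, here is a rough stand-in for my data (random and dense here, just so the snippet below is self-contained; my real train_data is a sparse one-hot matrix):

import numpy as np

# hypothetical stand-in data, only for making the snippet below reproducible:
# 2034 samples, 19475 binary one-hot features, labels drawn from 4 classes
train_data = (np.random.rand(2034, 19475) < 0.05).astype(np.float32)
train_labels = np.random.randint(0, 4, size=2034)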

import torch
import pyro
pyd = pyro.distributions

print("torch version:", torch.__version__)
print("pyro version:", pyro.__version__)

import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(123)


### 0. define Negative Log Likelihood(NLL) loss function
def nll(x_train, distribution):    
    return -torch.mean(distribution.log_prob(torch.tensor(x_train, dtype=torch.float)))
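# quick sanity check of nll with made-up values (not my real data):
# nll(np.array([[1., 0.]]), pyd.Bernoulli(probs=torch.tensor([0.9, 0.1])))
# should give roughly 0.105, the mean negative log-probability over all entries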


### 1. initialize bernoulli distribution(trainable distribution)
train_vars = (pyd.Uniform(low=torch.FloatTensor([0.01]),
                          high=torch.FloatTensor([0.1])).rsample([train_data.shape[-1]]).squeeze())
distribution = pyd.Bernoulli(probs=train_vars)

### 2. initialize 'label 0' data
class_mask = (train_labels==0)
class_data = train_data[class_mask, :]

### 3. initialize optimizer
optim = torch.optim.Adam([train_vars])

train_vars.requires_grad=True

### 4. train loop
for i in range(0,100):
    
    loss = nll(class_data, distribution)
    
    loss.backward()
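    # this loss.backward() call is where the RuntimeError below is raised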

When I run this code, I get the RuntimeError shown below.

How should I deal with this error?

Is there anything else I should consider?

I’m a beginner with Pyro, so any comments would be very much appreciated.

torch version: 1.9.0+cu102
pyro version: 1.7.0
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-269-0081bb1bb843> in <module>
     25     loss = nll(class_data, distribution)
     26 
---> 27     loss.backward()
     28 

/nf/yes/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    253                 create_graph=create_graph,
    254                 inputs=inputs)
--> 255         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    256 
    257     def register_hook(self, hook):

/nf/yes/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145         retain_graph = create_graph
    146 
--> 147     Variable._execution_engine.run_backward(
    148         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    149         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.