Enquiry About the Derivative of a Non-Leaf Tensor in the Pyro Framework


I have a very basic linear model, which is set up as follows:

def model(is_cont_africa, ruggedness, log_gdp=None):
    # mean and sigma are functions of the model's five latent parameters,
    # sampled earlier in the model body (omitted from this snippet)
    with pyro.plate('data'):
        return pyro.sample('obs', dists.Normal(mean, sigma), obs=log_gdp)

I want to use loss = model(is_cont_africa, ruggedness, log_gdp) and then loss.backward() to compute the derivatives with respect to the five latent parameters.

It reports that sigma is not a leaf tensor. I have already tried the usual PyTorch remedies, such as requires_grad_() and retain_grad(), but I still cannot make sigma a leaf tensor.
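For background on what PyTorch means by a leaf tensor, here is a minimal sketch in plain PyTorch, independent of Pyro, that reproduces the same error:

```python
import torch

# A leaf tensor is one created directly by the user; tensors produced by
# operations on other tensors are non-leaf.
x = torch.randn(3, requires_grad=True)   # leaf, tracked by autograd
y = x * 2                                # non-leaf: result of an op
print(x.is_leaf, y.is_leaf)              # True False

# Reproducing the error: backward() through a graph whose inputs are
# not tracked by autograd at all.
z = torch.randn(3)                       # leaf, but requires_grad=False
try:
    (z * 2).sum().backward()
except RuntimeError as err:
    print(err)  # element 0 of tensors does not require grad ...
```

Since pyro.sample produces sigma inside the model rather than from a user-created tensor, it ends up as a non-leaf (or untracked) tensor, which is consistent with the error below.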

Could someone tell me a relatively easy way to return the log probability such that the gradient with respect to every parameter can be computed, which I could then use for a subsequent Hessian matrix computation?
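To make the goal concrete, here is the kind of thing I am after, written directly with torch.distributions so that every latent is a user-created leaf tensor (the parameter names and priors follow the Pyro introductory regression tutorial and are my assumption; the data below is made up for illustration):

```python
import torch
from torch.distributions import Normal, Uniform

def log_joint(a, b_a, b_r, b_ar, sigma, is_cont_africa, ruggedness, log_gdp):
    # Priors (assumed, following the Pyro intro tutorial)
    lp = Normal(0., 10.).log_prob(a)
    lp = lp + Normal(0., 1.).log_prob(b_a)
    lp = lp + Normal(0., 1.).log_prob(b_r)
    lp = lp + Normal(0., 1.).log_prob(b_ar)
    lp = lp + Uniform(0., 10.).log_prob(sigma)
    # Likelihood
    mean = a + b_a * is_cont_africa + b_r * ruggedness \
           + b_ar * is_cont_africa * ruggedness
    lp = lp + Normal(mean, sigma).log_prob(log_gdp).sum()
    return lp

# Each parameter is a leaf tensor, so every gradient is available.
params = [torch.tensor(v, requires_grad=True)
          for v in (0.0, 0.5, 0.5, 0.5, 1.0)]
is_cont_africa = torch.tensor([1., 0., 1.])
ruggedness = torch.tensor([0.2, 1.3, 0.7])
log_gdp = torch.tensor([8.0, 7.5, 9.1])

lp = log_joint(*params, is_cont_africa, ruggedness, log_gdp)
(-lp).backward()
print([p.grad is not None for p in params])  # all True
```

I would prefer to get the equivalent log probability out of the Pyro model itself rather than duplicating it by hand like this.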

This is the error I get if I use the above model:

RuntimeError                              Traceback (most recent call last)
Cell In[189], line 4
      1 loss = model(is_cont_africa, ruggedness, log_gdp)
      2 loss = loss.sum()
----> 4 loss.backward()

File /opt/conda/lib/python3.10/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File /opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
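For the final Hessian step, my understanding is that once a scalar log probability is differentiable with respect to leaf parameters, torch.autograd.functional.hessian can compute the Hessian directly. A toy sketch with a two-parameter Gaussian likelihood (the model and data here are made up, just to show the mechanics):

```python
import torch
from torch.distributions import Normal

# Toy data from a known line, standing in for the real regression data
x = torch.linspace(-1., 1., 20)
y = 2.0 * x + 0.5

def neg_log_lik(theta):
    # theta = (slope, intercept); a stand-in for the five latents
    mean = theta[0] * x + theta[1]
    return -Normal(mean, 0.1).log_prob(y).sum()

theta0 = torch.tensor([2.0, 0.5])
H = torch.autograd.functional.hessian(neg_log_lik, theta0)
print(H.shape)  # torch.Size([2, 2])
```

If there is a cleaner way to get such a function of the latents straight out of the Pyro model, that is exactly what I am looking for.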