I am trying to implement a custom distribution (a doubly truncated exponential) using the TorchDistribution class, but I wanted to make sure I understand the basics first. What is the difference between the `sample` and `rsample` methods?

Oops, sorry! I just saw it answered on Stack Overflow.

You should try to implement `.rsample()`, which stands for “reparameterized sample”, i.e. a stochastically differentiable sample operator. If you implement only `.rsample()`, then `.sample()` will be defined automatically (the base class's `.sample()` simply calls `.rsample()` under `torch.no_grad()`). If you implement only `.sample()`, then stochastic gradient estimation will generally have higher variance and inference will be slower. If you implement `.rsample()`, you should also set `has_rsample = True`.
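A minimal sketch of the doubly truncated exponential from the question, assuming a rate λ > 0 and bounds low < high. It samples by inverting the CDF, x = low − log(1 − u·(1 − e^(−λ(high − low))))/λ with u ~ Uniform[0, 1), so the noise does not depend on the parameters and `rsample()` is differentiable in them. It subclasses `torch.distributions.Distribution` so it runs with plain PyTorch; Pyro's `TorchDistribution` is built on that same class, so swapping the base class should be the main change for Pyro (an assumption worth checking against your Pyro version). The class name `DoublyTruncatedExponential` and the parameter names `rate`, `low`, `high` are hypothetical, not part of either library.

```python
import torch
from torch.distributions import Distribution, constraints
from torch.distributions.utils import broadcast_all


class DoublyTruncatedExponential(Distribution):
    """Exponential(rate) restricted to [low, high] (hypothetical example class)."""

    arg_constraints = {
        "rate": constraints.positive,
        "low": constraints.real,
        "high": constraints.real,
    }
    # Advertises a reparameterized rsample(); frameworks such as Pyro read this flag.
    has_rsample = True

    def __init__(self, rate, low, high, validate_args=None):
        # Broadcast parameters to a common batch shape; assumes low < high.
        self.rate, self.low, self.high = broadcast_all(rate, low, high)
        super().__init__(batch_shape=self.rate.shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return constraints.interval(self.low, self.high)

    def _mass(self):
        # Z = 1 - exp(-rate * (high - low)): the mass of [low, high] under an
        # exponential shifted to start at low (expm1 keeps this accurate).
        return -torch.expm1(-self.rate * (self.high - self.low))

    def rsample(self, sample_shape=torch.Size()):
        # Inverse-CDF sampling: x = low - log(1 - u * Z) / rate, u ~ U[0, 1).
        # u is parameter-free noise, so gradients flow to rate, low and high.
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, dtype=self.rate.dtype, device=self.rate.device)
        return self.low - torch.log1p(-u * self._mass()) / self.rate

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # log f(x) = log(rate) - rate * (x - low) - log(Z), for x in [low, high]
        return (
            torch.log(self.rate)
            - self.rate * (value - self.low)
            - torch.log(self._mass())
        )


# No sample() override: the base class's sample() runs rsample() under no_grad.
rate = torch.tensor(2.0, requires_grad=True)
d = DoublyTruncatedExponential(rate, 0.0, 1.0)
d.rsample((1000,)).mean().backward()
print(rate.grad)  # populated because rsample() is differentiable in rate
print(d.sample((5,)).requires_grad)  # False
```

Inverse-CDF sampling is used here because it gives a reparameterized sample in closed form; a rejection sampler would be an alternative, but its accept/reject step is not differentiable, so you would be back to only having `.sample()`.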

Please add a link to the Stack Overflow post you found.