I am trying to implement a custom distribution (a doubly truncated exponential) using the `TorchDistribution` class, but I wanted to make sure I understand the basics first. What is the difference between the `sample` and `rsample` methods?
Oops, sorry! I just saw the Stack Overflow post.
You should try to implement `.rsample()`, which stands for “reparametrized sample”, i.e. a stochastically differentiable sample operator. If you implement only `.rsample()`, then `.sample()` will be defined automatically. If you implement only `.sample()`, then stochastic gradient estimation will generally have higher variance and inference will be slower. If you implement `.rsample()`, you should also set `has_rsample = True`.
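Here is a minimal sketch of what that could look like for the doubly truncated exponential, written against plain `torch.distributions.Distribution` (which Pyro's `TorchDistribution` subclasses). The class name `DoublyTruncatedExponential`, its `rate`/`low`/`high` parameters, and the helper `_mass` are hypothetical names for illustration; it assumes `low < high` without checking it and assumes a reasonably recent PyTorch. `rsample()` draws `u ~ Uniform[0, 1)` and pushes it through the inverse CDF `x = low - log(1 - u * (1 - exp(-rate * (high - low)))) / rate`, so gradients flow back to the parameters. `sample()` is not defined here; the base class's default runs `rsample()` under `torch.no_grad()`.

```python
import torch
from torch.distributions import Distribution, constraints
from torch.distributions.utils import broadcast_all


class DoublyTruncatedExponential(Distribution):
    """Exponential(rate) truncated to [low, high]. Hypothetical example class.

    Assumes low < high; this is not validated.
    """

    arg_constraints = {
        "rate": constraints.positive,
        "low": constraints.real,
        "high": constraints.real,
    }
    # rsample() below is differentiable, so advertise it.
    has_rsample = True

    def __init__(self, rate, low, high, validate_args=None):
        self.rate, self.low, self.high = broadcast_all(rate, low, high)
        super().__init__(self.rate.shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return constraints.interval(self.low, self.high)

    def _mass(self):
        # Normalizer 1 - exp(-rate * (high - low)); expm1 keeps it
        # accurate when rate * (high - low) is small.
        return -torch.expm1(-self.rate * (self.high - self.low))

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        # All randomness comes from u ~ Uniform[0, 1), which does not depend
        # on the parameters, so gradients flow through the transform below.
        u = torch.rand(shape, dtype=self.rate.dtype, device=self.rate.device)
        # Inverse CDF: x = low - log(1 - u * mass) / rate
        x = self.low - torch.log1p(-u * self._mass()) / self.rate
        # Guard against rounding pushing x just outside [low, high].
        return torch.minimum(torch.maximum(x, self.low), self.high)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # Density on [low, high]: rate * exp(-rate * (x - low)) / mass
        return self.rate.log() - self.rate * (value - self.low) - self._mass().log()


# sample() is inherited: the base class runs rsample() under torch.no_grad().
rate = torch.tensor(2.0, requires_grad=True)
d = DoublyTruncatedExponential(rate, torch.tensor(0.5), torch.tensor(3.0))

x = d.rsample((1000,))
print(x.requires_grad)  # True
x.mean().backward()
print(rate.grad)  # pathwise (reparametrized) gradient estimate of E[x] w.r.t. rate

y = d.sample((1000,))
print(y.requires_grad)  # False
```

As far as I know, switching the base class to Pyro's `TorchDistribution` should leave the `sample`/`rsample` behaviour shown here unchanged, since it inherits from `torch.distributions.Distribution`.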
Please add a link to the Stack Overflow post you found.