Custom Loss Function Implementation

  1. Gosh you’re right, my model regularizes the wrong x_loc :blush: . I guess this is a little trickier in the guide. I think we’ll need to negate it (again, since the loss is \mathbb E_q [\log q - \log p] and we’re moving the factor from the model p to the guide q) and specify has_rsample=True. Does this work for you?
  def guide(data):
      x_loc = pyro.param("x_loc", torch.rand(N*3,))
      x_scale = pyro.param("x_scale", 0.5*torch.ones(N*3,),
                           constraint=constraints.positive)
+     pyro.factor("regularizer", L2_regularizer(x_loc), has_rsample=True)
      pyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))
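For concreteness, here is one plausible definition of `L2_regularizer` (the thread never shows it, so the name of the `scale` argument and its value are assumptions):

```python
import torch

# Hypothetical stand-in for the L2_regularizer used above: a scaled sum of
# squares, returned as a 0-dim tensor so it can flow through pyro.factor
# and participate in autograd.
def L2_regularizer(x, scale=0.1):
    return scale * (x ** 2).sum()
```

Any function returning a scalar tensor that is differentiable in `x_loc` should work the same way here.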
  2. Yes, you can use multiple pyro.factor statements to add multiple regularizers.
  3. If you want to define a custom loss and still use SVI (rather than the lower-level interface in the tutorial), I think you can define a custom loss function:
elbo_loss_fn = Trace_ELBO().differentiable_loss

def loss_fn(data):
    # model and guide are closed over from the enclosing scope
    elbo_loss = elbo_loss_fn(model, guide, data)
    x_loc = pyro.param("x_loc")
    reg_loss = L2_regularizer(x_loc)
    return elbo_loss + reg_loss

Furthermore, I believe (1) and (3) should be equivalent.
