Custom Loss Function Implementation

Hi,

I am trying to implement a custom loss function which has two different parts, but I am not sure how to put them together.
→ The first part of the loss function is the ELBO, for which I want to use loss = pyro.infer.Trace_ELBO().
→ The second part is an L2 regularizer, defined as:

import torch

def L2_regularizer(my_parameters, lam=torch.tensor(1.)):
    reg_loss = 0.0
    for param in my_parameters:
        reg_loss = reg_loss + param.pow(2.0).sum()
    return lam*reg_loss
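As a quick check of what this function computes: the loop adds up squared entries over whatever iterable it receives, so passing a list of tensors such as `[x_loc]` and passing a single 1-D tensor (iterating it yields its scalar elements) should give the same sum. A minimal sketch with made-up values for `x_loc`:

```python
import torch


def L2_regularizer(my_parameters, lam=torch.tensor(1.)):
    reg_loss = 0.0
    for param in my_parameters:
        reg_loss = reg_loss + param.pow(2.0).sum()
    return lam * reg_loss


x_loc = torch.tensor([1.0, -2.0, 3.0])  # made-up values for illustration

as_list = L2_regularizer([x_loc])        # one tensor inside a list: 1 + 4 + 9
as_tensor = L2_regularizer(x_loc)        # iterates over the 3 scalar elements
vectorized = x_loc.pow(2).sum()          # equivalent single-tensor form
print(as_list, as_tensor, vectorized)    # all three equal 14
```

The list form is the one to use when regularizing several parameter tensors at once; for a single tensor, `lam * x_loc.pow(2).sum()` avoids the Python loop.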

My model and guide look as follows:

def model(data):
    x_loc = torch.zeros(N*3,)
    x_scale = 2*torch.ones(N*3,)
    x = pyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))
    ....
    ....

def guide(data):
    x_loc = pyro.param("x_loc", torch.rand(N*3,))
    x_scale = pyro.param("x_scale", 0.5*torch.ones(N*3,), constraint=constraints.positive)
    x = pyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))

The parameter I want to penalize with L2_regularizer is x_loc, which I define inside the guide.

How can I put these two loss functions together? Any help would be appreciated.

Thank you,
Atharva

One way to combine these is to use pyro.factor in your model with the negated loss (i.e., treat the loss as a negative log-likelihood term):

EDIT: this is wrong; see the discussion below.

  def model(data):
      x_loc = torch.zeros(N*3,)
+     pyro.factor("regularizer", -L2_regularizer(x_loc))
      x_scale = 2*torch.ones(N*3,)
      x = pyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))
      ....
      ....

Another way would be to define a custom objective function, but I think the pyro.factor method is cleaner.


@fritzo, thanks for the quick reply. I have three follow-up questions.

  1. Shouldn’t the pyro.factor("regularizer", -L2_regularizer(x_loc)) statement be inside the guide instead of the model? I ask because -L2_regularizer(x_loc) inside the model will always return zero (since x_loc = torch.zeros(N*3,)).

  2. If I want to add another term to this loss, let’s say a loss function L3_func(...), can I simply use pyro.factor again, as in pyro.factor("regularizer2", -L3_func(...))?

  3. I am having some trouble trying to define a custom objective function. Based on the custom objective function documentation, I understand that I can define the total loss as:

loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
loss = loss_fn(model, guide) + L2_regularizer(my_parameters)

What I am not sure about is where/how to define my_parameters and how to pass the data through the model and guide. In my case, the guide has two sets of parameters, namely x_loc and x_scale. Do I define them outside the guide, as follows?

x_loc = nn.Parameter(torch.rand(N*3))
x_scale = nn.Parameter(0.5*torch.ones(N*3,))

def guide(data):
    # use the module-level tensors directly; writing x_loc = x_loc here
    # would make x_loc a local name and raise UnboundLocalError
    x = pyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))

If you could share an example or show me how to set up the custom objective, that would be great.

  1. Gosh, you’re right: my model regularizes the wrong x_loc. I guess this is a little trickier in the guide. I think we’ll need to negate it (again, since the loss is $\mathbb{E}_q[\log q - \log p]$ and we’re moving the factor from the model $p$ to the guide $q$) and specify has_rsample=True. Does this work for you?
  def guide(data):
      x_loc = pyro.param("x_loc", torch.rand(N*3,))
      x_scale = pyro.param("x_scale", 0.5*torch.ones(N*3,),
                           constraint=constraints.positive)
+     pyro.factor("regularizer", L2_regularizer(x_loc), has_rsample=True)
      pyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))
  2. Yes, you can use multiple pyro.factor statements to add multiple regularizers.
  3. If you want to define a custom objective and still use SVI (rather than the lower-level interface in the tutorial), I think you can define a custom loss function:
from pyro.infer import Trace_ELBO

elbo_loss_fn = Trace_ELBO().differentiable_loss

def loss_fn(data):
    elbo_loss = elbo_loss_fn(model, guide, data)
    x_loc = pyro.param("x_loc")
    reg_loss = L2_regularizer(x_loc)
    return elbo_loss + reg_loss

Furthermore, I believe (1) and (3) should be equivalent.

  1. Setting has_rsample=True raises the following error:
    TypeError: factor() got an unexpected keyword argument 'has_rsample'.

  2. For the custom objective function, I have the following setup:

def model(data):
    x_loc = torch.zeros(N*3,)
    x_scale = 2*torch.ones(N*3,)
    x = pyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))
    ....

def guide(data):
    x_loc = pyro.param("x_loc", torch.rand(N*3,))
    x_scale = pyro.param("x_scale", 0.5*torch.ones(N*3,), constraint=constraints.positive)
    x = pyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))

elbo_loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

def loss_fn(data):
    elbo_loss = elbo_loss_fn(model, guide, data)
    x_loc = pyro.param("x_loc")
    reg_loss = L2_regularizer(x_loc)
    return elbo_loss + reg_loss

# optimizer (torch.optim.Adam takes keyword arguments, not a dict)
optimizer = torch.optim.Adam(my_parameters, lr=0.001, betas=(0.90, 0.999))
for i in range(num_steps):
    loss = loss_fn(data=data_obs)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

What I am confused about is how to obtain my_parameters for the optimizer above. In my case, the parameters are x_loc and x_scale, which are defined in the guide.

I think you’ll need to update Pyro. Older Pyro versions that don’t support pyro.factor(..., has_rsample=True) will give incorrect inference.

I think you should be able to use

my_parameters = list(guide.parameters())

@fritzo,

  1. Setting has_rsample=True did the job.

  2. Using

my_parameters = list(guide.parameters())

raises the following error:

AttributeError: 'function' object has no attribute 'parameters'
