I think model_logits has shape (500, 1) and it is broadcasting. Try adding:
model_logits = lifted_reg_model(x_data).squeeze(-1)
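
To see what I mean, here is a rough sketch with dummy tensors. The name y_data and the 1-D target shape (500,) are my assumptions, not something from your code; only the (500, 1) shape comes from the guess above.

```python
import torch

# Hypothetical stand-ins for the real tensors (500 data points).
model_logits = torch.zeros(500, 1)  # assumed model output with a trailing size-1 dim
y_data = torch.zeros(500)           # assumed 1-D targets

# (500, 1) combined with (500,) silently broadcasts to (500, 500).
broadcast_shape = (model_logits - y_data).shape

# Dropping the trailing dimension makes the two shapes line up elementwise.
fixed_shape = (model_logits.squeeze(-1) - y_data).shape

print(broadcast_shape)  # torch.Size([500, 500])
print(fixed_shape)      # torch.Size([500])
```

If printing model_logits.shape in your own code already gives (500,), then this is not the problem and the squeeze is not needed.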