Problems with custom loss functions #444
cyhltsrjyyc started this conversation in General
I want to use a custom loss function that needs additional inputs. The loss function is shown below:
    class CustomLoss(nn.Module):
        def __init__(self, teacher_preds_2, teacher_preds_3, teacher_preds_4):
            super(CustomLoss, self).__init__()
            self.teacher_preds_2 = torch.tensor(teacher_preds_2, requires_grad=True)
            self.teacher_preds_3 = torch.tensor(teacher_preds_3, requires_grad=True)
            self.teacher_preds_4 = torch.tensor(teacher_preds_4, requires_grad=True)

        def forward(self, y_true, y_pred):
            y_true = torch.tensor(y_true, requires_grad=True)
            y_pred = torch.tensor(y_pred, requires_grad=True)
            main_loss = torch.mean((y_true - y_pred) ** 2)
            distill_loss_2 = sup.get_R2(np.array(y_pred).reshape(1, -1), np.array(self.teacher_preds_2).reshape(1, -1))
            distill_loss_3 = sup.get_R2(np.array(y_pred).reshape(1, -1), np.array(self.teacher_preds_3).reshape(1, -1))
            distill_loss_4 = sup.get_R2(np.array(y_pred).reshape(1, -1), np.array(self.teacher_preds_4).reshape(1, -1))
            distill_loss = (distill_loss_2 + distill_loss_3 * 2 + distill_loss_4 * 7) / 10
            total_loss = main_loss + 0.01 * distill_loss
            return total_loss
The "teacher_preds_2, teacher_preds_3, teacher_preds_4" are three fixed tensors with the same size as the training sample. How can I get the batch of the three tensors and calculate the R2 between them and y_pred during fit process?