Commit 69b384d: Initial commit
Author: Stefan
1 parent 6b93e66 commit 69b384d
File tree: 5 files changed, +725 −0 lines changed


losses.py

Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
from __future__ import print_function, division

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
    from itertools import ifilterfalse
except ImportError:  # py3k
    from itertools import filterfalse as ifilterfalse

def dice_loss(pred, target):
    """This definition generalizes to real-valued pred and target vectors.
    It should be differentiable.
    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    """
    smooth = 1.

    # have to use contiguous since they may come from a torch.view op
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()

    A_sum = torch.sum(iflat * iflat)  # was torch.sum(tflat * iflat), which just duplicates the intersection
    B_sum = torch.sum(tflat * tflat)

    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))

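# Usage sketch (illustrative, not part of the original commit):
#   >>> pred = torch.rand(2, 1, 4, 4)                    # probabilities in [0, 1]
#   >>> target = (torch.rand(2, 1, 4, 4) > 0.5).float()  # binary mask
#   >>> dice_loss(pred, target)                          # scalar tensor, lower is better
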
def dice_loss2(input, target):
    input = torch.sigmoid(input)

    smooth = 1.

    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()

    return 1 - ((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))

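# Note (added): unlike dice_loss, dice_loss2 takes raw logits and applies
# torch.sigmoid itself, and it uses the linear soft-Dice denominator
# sum(p) + sum(t) rather than the squared one above.
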
"""
44+
Lovasz-Softmax and Jaccard hinge loss in PyTorch
45+
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
46+
"""
47+
48+
def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


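# Worked example (added for illustration): with ground truth sorted by
# decreasing error,
#   >>> lovasz_grad(torch.tensor([1, 1, 0]))
#   tensor([0.5000, 0.5000, 0.0000])
# i.e. each entry is the increase in Jaccard loss contributed by counting
# one more of the sorted predictions as an error.
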
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for foreground class
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / union
        ious.append(iou)
    iou = mean(ious)  # mean across images if per_image
    return 100 * iou


def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []
        for i in range(C):
            if i != ignore:  # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / union)
        ious.append(iou)
    # list(...) so this also works on Python 3, where map returns an iterator
    ious = list(map(mean, zip(*ious)))  # mean across images if per_image
    return 100 * np.array(ious)

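# Usage sketch (illustrative, not part of the original commit):
#   >>> pred = torch.randint(0, 2, (2, 4, 4))      # hard binary predictions
#   >>> label = torch.randint(0, 2, (2, 4, 4))
#   >>> iou_binary(pred, label, per_image=True)    # mean foreground IoU, in percent
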
# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss


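# Usage sketch (illustrative, not part of the original commit):
#   >>> logits = torch.randn(2, 4, 4, requires_grad=True)  # raw scores, not probabilities
#   >>> masks = (torch.rand(2, 4, 4) > 0.5).float()
#   >>> loss = lovasz_hinge(logits, masks, per_image=True)
#   >>> loss.backward()
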
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
    logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss


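# Note (added): the hinge errors 1 - sign(y) * logit are sorted in decreasing
# order, clamped with ReLU, and dotted with the Lovasz gradient, giving a
# piecewise-linear convex surrogate of the Jaccard loss (Berman et al., 2018).
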
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels


class StableBCELoss(torch.nn.modules.Module):
    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        neg_abs = - input.abs()
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()


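# Note (added): StableBCELoss implements the numerically stable form of
# binary cross-entropy with logits,
#   max(x, 0) - x * t + log(1 + exp(-|x|)),
# which avoids overflow in exp() for large |x| and matches
# F.binary_cross_entropy_with_logits up to floating-point error.
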
def binary_xloss(logits, labels, ignore=None):
    """
    Binary Cross entropy loss
    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    ignore: void class id
    """
    logits, labels = flatten_binary_scores(logits, labels, ignore)
    loss = StableBCELoss()(logits, Variable(labels.float()))
    return loss


# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
    probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1)
    labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
    only_present: average only on classes present in ground truth
    per_image: compute the loss per image instead of per batch
    ignore: void class labels
    """
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present)
                    for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present)
    return loss


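# Usage sketch (illustrative, not part of the original commit); note that the
# function expects probabilities, so apply softmax to raw network scores first:
#   >>> scores = torch.randn(2, 3, 4, 4, requires_grad=True)  # B, C, H, W
#   >>> probas = F.softmax(scores, dim=1)
#   >>> labels = torch.randint(0, 3, (2, 4, 4))
#   >>> loss = lovasz_softmax(probas, labels)
#   >>> loss.backward()
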
def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
    probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
    labels: [P] Tensor, ground truth labels (between 0 and C - 1)
    only_present: average only on classes present in ground truth
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)  # the original assigned this twice; once is enough
    losses = []
    for c in range(C):
        fg = (labels == c).float()  # foreground for class c
        if only_present and fg.sum() == 0:
            continue
        errors = (Variable(fg) - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    """
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels

def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss
    """
    # honour the `ignore` argument; the original hard-coded ignore_index=255 and left `ignore` unused
    return F.cross_entropy(logits, Variable(labels), ignore_index=255 if ignore is None else ignore)


# --------------------------- HELPER FUNCTIONS ---------------------------

def isnan(x):
    return x != x


def mean(l, ignore_nan=True, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
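
# Usage sketch (illustrative, not part of the original commit):
#   >>> mean([1.0, float('nan'), 3.0])   # NaNs are filtered out
#   2.0
#   >>> mean([], empty=0)                # empty input falls back to `empty`
#   0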
