1
+
2
+ from __future__ import print_function , division
3
+
4
+ import torch
5
+ from torch .autograd import Variable
6
+ import torch .nn .functional as F
7
+ import numpy as np
8
+ try :
9
+ from itertools import ifilterfalse
10
+ except ImportError : # py3k
11
+ from itertools import filterfalse as ifilterfalse
12
+
13
+ def dice_loss (pred , target ):
14
+ """This definition generalize to real valued pred and target vector.
15
+ This should be differentiable.
16
+ pred: tensor with first dimension as batch
17
+ target: tensor with first dimension as batch
18
+ """
19
+
20
+ smooth = 1.
21
+
22
+ # have to use contiguous since they may from a torch.view op
23
+ iflat = pred .contiguous ().view (- 1 )
24
+ tflat = target .contiguous ().view (- 1 )
25
+ intersection = (iflat * tflat ).sum ()
26
+
27
+ A_sum = torch .sum (tflat * iflat )
28
+ B_sum = torch .sum (tflat * tflat )
29
+
30
+ return 1 - ((2. * intersection + smooth ) / (A_sum + B_sum + smooth ) )
31
+
32
+ def dice_loss2 (input ,target ):
33
+ input = torch .sigmoid (input )
34
+
35
+ smooth = 1.
36
+
37
+ iflat = input .view (- 1 )
38
+ tflat = target .view (- 1 )
39
+ intersection = (iflat * tflat ).sum ()
40
+
41
+ return 1 - ((2. * intersection + smooth ) / (iflat .sum () + tflat .sum () + smooth ))
42
+
43
+ """
44
+ Lovasz-Softmax and Jaccard hinge loss in PyTorch
45
+ Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
46
+ """
47
+
48
+ def lovasz_grad (gt_sorted ):
49
+ """
50
+ Computes gradient of the Lovasz extension w.r.t sorted errors
51
+ See Alg. 1 in paper
52
+ """
53
+ p = len (gt_sorted )
54
+ gts = gt_sorted .sum ()
55
+ intersection = gts - gt_sorted .float ().cumsum (0 )
56
+ union = gts + (1 - gt_sorted ).float ().cumsum (0 )
57
+ jaccard = 1. - intersection / union
58
+ if p > 1 : # cover 1-pixel case
59
+ jaccard [1 :p ] = jaccard [1 :p ] - jaccard [0 :- 1 ]
60
+ return jaccard
61
+
62
+
63
+ def iou_binary (preds , labels , EMPTY = 1. , ignore = None , per_image = True ):
64
+ """
65
+ IoU for foreground class
66
+ binary: 1 foreground, 0 background
67
+ """
68
+ if not per_image :
69
+ preds , labels = (preds ,), (labels ,)
70
+ ious = []
71
+ for pred , label in zip (preds , labels ):
72
+ intersection = ((label == 1 ) & (pred == 1 )).sum ()
73
+ union = ((label == 1 ) | ((pred == 1 ) & (label != ignore ))).sum ()
74
+ if not union :
75
+ iou = EMPTY
76
+ else :
77
+ iou = float (intersection ) / union
78
+ ious .append (iou )
79
+ iou = mean (ious ) # mean accross images if per_image
80
+ return 100 * iou
81
+
82
+
83
+ def iou (preds , labels , C , EMPTY = 1. , ignore = None , per_image = False ):
84
+ """
85
+ Array of IoU for each (non ignored) class
86
+ """
87
+ if not per_image :
88
+ preds , labels = (preds ,), (labels ,)
89
+ ious = []
90
+ for pred , label in zip (preds , labels ):
91
+ iou = []
92
+ for i in range (C ):
93
+ if i != ignore : # The ignored label is sometimes among predicted classes (ENet - CityScapes)
94
+ intersection = ((label == i ) & (pred == i )).sum ()
95
+ union = ((label == i ) | ((pred == i ) & (label != ignore ))).sum ()
96
+ if not union :
97
+ iou .append (EMPTY )
98
+ else :
99
+ iou .append (float (intersection ) / union )
100
+ ious .append (iou )
101
+ ious = map (mean , zip (* ious )) # mean accross images if per_image
102
+ return 100 * np .array (ious )
103
+
104
+
105
+ # --------------------------- BINARY LOSSES ---------------------------
106
+
107
+
108
+ def lovasz_hinge (logits , labels , per_image = True , ignore = None ):
109
+ """
110
+ Binary Lovasz hinge loss
111
+ logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
112
+ labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
113
+ per_image: compute the loss per image instead of per batch
114
+ ignore: void class id
115
+ """
116
+ if per_image :
117
+ loss = mean (lovasz_hinge_flat (* flatten_binary_scores (log .unsqueeze (0 ), lab .unsqueeze (0 ), ignore ))
118
+ for log , lab in zip (logits , labels ))
119
+ else :
120
+ loss = lovasz_hinge_flat (* flatten_binary_scores (logits , labels , ignore ))
121
+ return loss
122
+
123
+
124
+ def lovasz_hinge_flat (logits , labels ):
125
+ """
126
+ Binary Lovasz hinge loss
127
+ logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
128
+ labels: [P] Tensor, binary ground truth labels (0 or 1)
129
+ ignore: label to ignore
130
+ """
131
+ if len (labels ) == 0 :
132
+ # only void pixels, the gradients should be 0
133
+ return logits .sum () * 0.
134
+ signs = 2. * labels .float () - 1.
135
+ errors = (1. - logits * Variable (signs ))
136
+ errors_sorted , perm = torch .sort (errors , dim = 0 , descending = True )
137
+ perm = perm .data
138
+ gt_sorted = labels [perm ]
139
+ grad = lovasz_grad (gt_sorted )
140
+ loss = torch .dot (F .relu (errors_sorted ), Variable (grad ))
141
+ return loss
142
+
143
+
144
+ def flatten_binary_scores (scores , labels , ignore = None ):
145
+ """
146
+ Flattens predictions in the batch (binary case)
147
+ Remove labels equal to 'ignore'
148
+ """
149
+ scores = scores .view (- 1 )
150
+ labels = labels .view (- 1 )
151
+ if ignore is None :
152
+ return scores , labels
153
+ valid = (labels != ignore )
154
+ vscores = scores [valid ]
155
+ vlabels = labels [valid ]
156
+ return vscores , vlabels
157
+
158
+
159
+ class StableBCELoss (torch .nn .modules .Module ):
160
+ def __init__ (self ):
161
+ super (StableBCELoss , self ).__init__ ()
162
+ def forward (self , input , target ):
163
+ neg_abs = - input .abs ()
164
+ loss = input .clamp (min = 0 ) - input * target + (1 + neg_abs .exp ()).log ()
165
+ return loss .mean ()
166
+
167
+
168
+ def binary_xloss (logits , labels , ignore = None ):
169
+ """
170
+ Binary Cross entropy loss
171
+ logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
172
+ labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
173
+ ignore: void class id
174
+ """
175
+ logits , labels = flatten_binary_scores (logits , labels , ignore )
176
+ loss = StableBCELoss ()(logits , Variable (labels .float ()))
177
+ return loss
178
+
179
+
180
+ # --------------------------- MULTICLASS LOSSES ---------------------------
181
+
182
+
183
+ def lovasz_softmax (probas , labels , only_present = False , per_image = False , ignore = None ):
184
+ """
185
+ Multi-class Lovasz-Softmax loss
186
+ probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1)
187
+ labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
188
+ only_present: average only on classes present in ground truth
189
+ per_image: compute the loss per image instead of per batch
190
+ ignore: void class labels
191
+ """
192
+ if per_image :
193
+ loss = mean (lovasz_softmax_flat (* flatten_probas (prob .unsqueeze (0 ), lab .unsqueeze (0 ), ignore ), only_present = only_present )
194
+ for prob , lab in zip (probas , labels ))
195
+ else :
196
+ loss = lovasz_softmax_flat (* flatten_probas (probas , labels , ignore ), only_present = only_present )
197
+ return loss
198
+
199
+
200
+ def lovasz_softmax_flat (probas , labels , only_present = False ):
201
+ """
202
+ Multi-class Lovasz-Softmax loss
203
+ probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
204
+ labels: [P] Tensor, ground truth labels (between 0 and C - 1)
205
+ only_present: average only on classes present in ground truth
206
+ """
207
+ if probas .numel () == 0 :
208
+ # only void pixels, the gradients should be 0
209
+ return probas * 0.
210
+ C = probas .size (1 )
211
+
212
+ C = probas .size (1 )
213
+ losses = []
214
+ for c in range (C ):
215
+ fg = (labels == c ).float () # foreground for class c
216
+ if only_present and fg .sum () == 0 :
217
+ continue
218
+ errors = (Variable (fg ) - probas [:, c ]).abs ()
219
+ errors_sorted , perm = torch .sort (errors , 0 , descending = True )
220
+ perm = perm .data
221
+ fg_sorted = fg [perm ]
222
+ losses .append (torch .dot (errors_sorted , Variable (lovasz_grad (fg_sorted ))))
223
+ return mean (losses )
224
+
225
+
226
+ def flatten_probas (probas , labels , ignore = None ):
227
+ """
228
+ Flattens predictions in the batch
229
+ """
230
+ B , C , H , W = probas .size ()
231
+ probas = probas .permute (0 , 2 , 3 , 1 ).contiguous ().view (- 1 , C ) # B * H * W, C = P, C
232
+ labels = labels .view (- 1 )
233
+ if ignore is None :
234
+ return probas , labels
235
+ valid = (labels != ignore )
236
+ vprobas = probas [valid .nonzero ().squeeze ()]
237
+ vlabels = labels [valid ]
238
+ return vprobas , vlabels
239
+
240
+ def xloss (logits , labels , ignore = None ):
241
+ """
242
+ Cross entropy loss
243
+ """
244
+ return F .cross_entropy (logits , Variable (labels ), ignore_index = 255 )
245
+
246
+
247
+ # --------------------------- HELPER FUNCTIONS ---------------------------
248
+ def isnan (x ):
249
+ return x != x
250
+
251
+
252
+ def mean (l , ignore_nan = True , empty = 0 ):
253
+ """
254
+ nanmean compatible with generators.
255
+ """
256
+ l = iter (l )
257
+ if ignore_nan :
258
+ l = ifilterfalse (isnan , l )
259
+ try :
260
+ n = 1
261
+ acc = next (l )
262
+ except StopIteration :
263
+ if empty == 'raise' :
264
+ raise ValueError ('Empty mean' )
265
+ return empty
266
+ for n , v in enumerate (l , 2 ):
267
+ acc += v
268
+ if n == 1 :
269
+ return acc
270
+ return acc / n
0 commit comments