@@ -50,7 +50,7 @@ class MBConvBlock(nn.Module):
     def __init__(self, block_args, global_params, image_size=None):
         super().__init__()
         self._block_args = block_args
-        self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
+        self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
         self._bn_eps = global_params.batch_norm_epsilon
         self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
         self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect
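A note on the momentum flip in this hunk: TensorFlow's batch-norm `momentum` is the weight kept on the running statistics, while PyTorch's `momentum` is the weight given to the new batch statistics, so the two conventions are complements. A minimal sketch with the EfficientNet default of 0.99 (values assumed for illustration):

```python
import torch.nn as nn

tf_batch_norm_momentum = 0.99              # TF: weight kept on the running stats
pt_momentum = 1 - tf_batch_norm_momentum   # PyTorch: weight given to the new batch

# PyTorch update: running = (1 - momentum) * running + momentum * batch_stat
bn = nn.BatchNorm2d(num_features=32, momentum=pt_momentum, eps=1e-3)
```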
@@ -152,9 +152,7 @@ class EfficientNet(nn.Module):
         [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
 
     Example:
-
-
-        import torch
+        >>> import torch
         >>> from efficientnet.model import EfficientNet
         >>> inputs = torch.rand(1, 3, 224, 224)
         >>> model = EfficientNet.from_pretrained('efficientnet-b0')
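The fixed doctest can also run as a plain script; a minimal smoke test, assuming the package imports as `efficientnet` the way the docstring shows:

```python
import torch
from efficientnet.model import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b0')
model.eval()  # inference mode: disables dropout, uses running BN statistics

with torch.no_grad():
    outputs = model(torch.rand(1, 3, 224, 224))
print(outputs.shape)  # torch.Size([1, 1000]) with the default classifier head
```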
@@ -198,7 +196,7 @@ def __init__(self, blocks_args=None, global_params=None):
             # The first block needs to take care of stride and filter size increase.
             self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
             image_size = calculate_output_image_size(image_size, block_args.stride)
-            if block_args.num_repeat > 1: # modify block_args to keep same output size
+            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                 block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
             for _ in range(block_args.num_repeat - 1):
                 self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
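The `_replace` call works because `block_args` is a namedtuple: within a stage, only the first block downsamples and widens, so the remaining `num_repeat - 1` copies take the stage's output width as input and run at stride 1. A toy illustration with a hypothetical, reduced `BlockArgs`:

```python
from collections import namedtuple

# Hypothetical, reduced stand-in for the real BlockArgs namedtuple.
BlockArgs = namedtuple('BlockArgs', 'num_repeat input_filters output_filters stride')

block_args = BlockArgs(num_repeat=3, input_filters=16, output_filters=24, stride=2)
blocks = [block_args]  # first block: 16 -> 24 channels, downsamples by 2

if block_args.num_repeat > 1:
    # later repeats keep the stage's resolution and channel count
    block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
for _ in range(block_args.num_repeat - 1):
    blocks.append(block_args)

for b in blocks:
    print(b.input_filters, '->', b.output_filters, 'stride', b.stride)
# 16 -> 24 stride 2
# 24 -> 24 stride 1
# 24 -> 24 stride 1
```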
@@ -213,16 +211,18 @@ def __init__(self, blocks_args=None, global_params=None):
 
         # Final linear layer
         self._avg_pooling = nn.AdaptiveAvgPool2d(1)
-        self._dropout = nn.Dropout(self._global_params.dropout_rate)
-        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+        if self._global_params.include_top:
+            self._dropout = nn.Dropout(self._global_params.dropout_rate)
+            self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+
+        # set activation to memory efficient swish by default
         self._swish = MemoryEfficientSwish()
 
     def set_swish(self, memory_efficient=True):
         """Sets swish function as memory efficient (for training) or standard (for export).
 
         Args:
             memory_efficient (bool): Whether to use memory-efficient version of swish.
-
         """
         self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
         for block in self._blocks:
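The new `include_top` flag makes the classifier head optional, which matters when the network is used as a feature backbone: with the head skipped, no unused `Dropout`/`Linear` parameters are created, initialized, or exported. A minimal sketch, assuming `include_top` is accepted as an override parameter like the other `GlobalParams` fields:

```python
from efficientnet.model import EfficientNet

# Full model: pooling -> dropout -> fc producing num_classes logits.
clf = EfficientNet.from_name('efficientnet-b0')

# Backbone only: the dropout/fc layers above are never constructed.
backbone = EfficientNet.from_name('efficientnet-b0', include_top=False)

# As the set_swish docstring suggests, swap in the standard Swish before
# tracing or ONNX export; the memory-efficient autograd version is for training.
backbone.set_swish(memory_efficient=False)
```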
@@ -261,17 +261,17 @@ def extract_endpoints(self, inputs):
         for idx, block in enumerate(self._blocks):
             drop_connect_rate = self._global_params.drop_connect_rate
             if drop_connect_rate:
-                drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
+                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
             x = block(x, drop_connect_rate=drop_connect_rate)
             if prev_x.size(2) > x.size(2):
-                endpoints['reduction_{}'.format(len(endpoints)+1)] = prev_x
+                endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
             elif idx == len(self._blocks) - 1:
-                endpoints['reduction_{}'.format(len(endpoints)+1)] = x
+                endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
             prev_x = x
 
         # Head
         x = self._swish(self._bn1(self._conv_head(x)))
-        endpoints['reduction_{}'.format(len(endpoints)+1)] = x
+        endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
 
         return endpoints
 
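`extract_endpoints` records a feature map each time the spatial size shrinks, which is what FPN-style decoders consume. A usage sketch; the shapes in the comments are the standard EfficientNet-B0 reductions for a 224x224 input and may differ across versions:

```python
import torch
from efficientnet.model import EfficientNet

model = EfficientNet.from_name('efficientnet-b0')
endpoints = model.extract_endpoints(torch.rand(1, 3, 224, 224))

for name, feat in endpoints.items():
    print(name, tuple(feat.shape))
# reduction_1 (1, 16, 112, 112)
# reduction_2 (1, 24, 56, 56)
# reduction_3 (1, 40, 28, 28)
# reduction_4 (1, 112, 14, 14)
# reduction_5 (1, 320, 7, 7)
# reduction_6 (1, 1280, 7, 7)   <- recorded after the conv head
```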
@@ -292,7 +292,7 @@ def extract_features(self, inputs):
         for idx, block in enumerate(self._blocks):
             drop_connect_rate = self._global_params.drop_connect_rate
             if drop_connect_rate:
-                drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
+                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
             x = block(x, drop_connect_rate=drop_connect_rate)
 
         # Head
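Both extractors scale drop connect linearly with depth, following the stochastic-depth schedule: block 0 is never dropped, and the rate approaches the configured maximum at the last block. A worked example with the usual defaults (0.2 maximum; B0 has 16 MBConv blocks):

```python
drop_connect_rate = 0.2   # global maximum (EfficientNet default)
num_blocks = 16           # number of MBConv blocks in EfficientNet-B0

for idx in range(num_blocks):
    rate = drop_connect_rate * float(idx) / num_blocks
    print(idx, round(rate, 4))
# 0 -> 0.0, 8 -> 0.1, 15 -> 0.1875: deeper blocks are dropped more often
```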
@@ -322,7 +322,7 @@ def forward(self, inputs):
 
     @classmethod
     def from_name(cls, model_name, in_channels=3, **override_params):
-        """create an efficientnet model according to name.
+        """Create an efficientnet model according to name.
 
         Args:
             model_name (str): Name for efficientnet.
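`from_name` builds a randomly initialized architecture, with `override_params` forwarded into the global parameters. A short sketch (the override names below mirror `GlobalParams` fields and are assumptions on my part):

```python
from efficientnet.model import EfficientNet

# B1 adapted to grayscale input and a 10-way classifier, no weights downloaded.
model = EfficientNet.from_name(
    'efficientnet-b1',
    in_channels=1,    # rebuilds the stem conv for single-channel images
    num_classes=10,   # sizes the final fc layer
)
```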
@@ -348,7 +348,7 @@ def from_name(cls, model_name, in_channels=3, **override_params):
     @classmethod
     def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                         in_channels=3, num_classes=1000, **override_params):
-        """create an efficientnet model according to name.
+        """Create an efficientnet model according to name.
 
         Args:
             model_name (str): Name for efficientnet.
@@ -375,7 +375,8 @@ def from_pretrained(cls, model_name, weights_path=None, advprop=False,
             A pretrained efficientnet model.
         """
         model = cls.from_name(model_name, num_classes=num_classes, **override_params)
-        load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop)
+        load_pretrained_weights(model, model_name, weights_path=weights_path,
+                                load_fc=(num_classes == 1000), advprop=advprop)
         model._change_in_channels(in_channels)
         return model
 
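`load_fc=(num_classes == 1000)` is the transfer-learning switch: pretrained classifier weights are loaded only when the head still has the original 1000 ImageNet classes; with any other `num_classes`, the backbone loads and the fresh fc layer is left for fine-tuning. A minimal sketch:

```python
from efficientnet.model import EfficientNet

# Head matches ImageNet: fc weights are loaded along with the backbone.
imagenet_model = EfficientNet.from_pretrained('efficientnet-b0')

# Custom head: backbone weights load, the 10-way fc stays randomly
# initialized and should be fine-tuned on the new task.
finetune_model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=10)
```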