@@ -238,18 +238,18 @@ def extract_endpoints(self, inputs):
238238 Returns:
239239 Dictionary of last intermediate features
240240 with reduction levels i in [1, 2, 3, 4, 5].
241-
242- Example:
243- >>> import torch
244- >>> from efficientnet.model import EfficientNet
245- >>> inputs = torch.rand(1, 3, 224, 224 )
246- >>> model = EfficientNet.from_pretrained('efficientnet-b0' )
247- >>> endpoints = model.extract_endpoints(inputs )
248- >>> print(endpoints['reduction_1 '].shape) # torch.Size([1, 16, 112, 112 ])
249- >>> print(endpoints['reduction_2 '].shape) # torch.Size([1, 24, 56, 56 ])
250- >>> print(endpoints['reduction_3 '].shape) # torch.Size([1, 40, 28, 28 ])
251- >>> print(endpoints['reduction_4 '].shape) # torch.Size([1, 112, 14, 14 ])
252- >>> print(endpoints['reduction_5 '].shape) # torch.Size([1, 1280, 7, 7])
241+ Example:
242+ >>> import torch
243+ >>> from efficientnet.model import EfficientNet
244+ >>> inputs = torch.rand(1, 3, 224, 224)
245+ >>> model = EfficientNet.from_pretrained('efficientnet-b0' )
246+ >>> endpoints = model.extract_endpoints(inputs )
247+ >>> print( endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112] )
248+ >>> print(endpoints['reduction_2 '].shape) # torch.Size([1, 24, 56, 56 ])
249+ >>> print(endpoints['reduction_3 '].shape) # torch.Size([1, 40, 28, 28 ])
250+ >>> print(endpoints['reduction_4 '].shape) # torch.Size([1, 112, 14, 14 ])
251+ >>> print(endpoints['reduction_5 '].shape) # torch.Size([1, 320, 7, 7 ])
252+ >>> print(endpoints['reduction_6 '].shape) # torch.Size([1, 1280, 7, 7])
253253 """
254254 endpoints = dict ()
255255
@@ -265,6 +265,8 @@ def extract_endpoints(self, inputs):
265265 x = block (x , drop_connect_rate = drop_connect_rate )
266266 if prev_x .size (2 ) > x .size (2 ):
267267 endpoints ['reduction_{}' .format (len (endpoints ) + 1 )] = prev_x
268+ elif idx == len (self ._blocks ) - 1 :
269+ endpoints ['reduction_{}' .format (len (endpoints ) + 1 )] = x
268270 prev_x = x
269271
270272 # Head
0 commit comments