Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 7e8b0d3

Browse files
Merge pull request #250 from rvandeghen/patch-1
Add new checkpoint
2 parents 45834ee + 75ca1bf commit 7e8b0d3

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

‎efficientnet_pytorch/model.py‎

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -238,18 +238,18 @@ def extract_endpoints(self, inputs):
238238
Returns:
239239
Dictionary of last intermediate features
240240
with reduction levels i in [1, 2, 3, 4, 5].
241-
242-
Example:
243-
>>> import torch
244-
>>> from efficientnet.model import EfficientNet
245-
>>> inputs = torch.rand(1, 3, 224, 224)
246-
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
247-
>>> endpoints = model.extract_endpoints(inputs)
248-
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
249-
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
250-
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
251-
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
252-
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
241+
Example:
242+
>>> import torch
243+
>>> from efficientnet.model import EfficientNet
244+
>>> inputs = torch.rand(1, 3, 224, 224)
245+
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
246+
>>> endpoints = model.extract_endpoints(inputs)
247+
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
248+
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
249+
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
250+
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
251+
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
252+
>>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
253253
"""
254254
endpoints = dict()
255255

@@ -265,6 +265,8 @@ def extract_endpoints(self, inputs):
265265
x = block(x, drop_connect_rate=drop_connect_rate)
266266
if prev_x.size(2) > x.size(2):
267267
endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
268+
elif idx == len(self._blocks) - 1:
269+
endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
268270
prev_x = x
269271

270272
# Head

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /