Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit fa0bcd8

Browse files
Merge pull request #250 from rvandeghen/patch-1
Add new checkpoint
2 parents 1b181b1 + 02e295a commit fa0bcd8

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

‎efficientnet_pytorch/model.py‎

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -238,18 +238,18 @@ def extract_endpoints(self, inputs):
238238
Returns:
239239
Dictionary of last intermediate features
240240
with reduction levels i in [1, 2, 3, 4, 5].
241-
242-
Example:
243-
>>> import torch
244-
>>> from efficientnet.model import EfficientNet
245-
>>> inputs = torch.rand(1, 3, 224, 224)
246-
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
247-
>>> endpoints = model.extract_endpoints(inputs)
248-
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
249-
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
250-
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
251-
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
252-
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
241+
Example:
242+
>>> import torch
243+
>>> from efficientnet.model import EfficientNet
244+
>>> inputs = torch.rand(1, 3, 224, 224)
245+
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
246+
>>> endpoints = model.extract_endpoints(inputs)
247+
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
248+
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
249+
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
250+
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
251+
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
252+
>>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
253253
"""
254254
endpoints = dict()
255255

@@ -265,6 +265,8 @@ def extract_endpoints(self, inputs):
265265
x = block(x, drop_connect_rate=drop_connect_rate)
266266
if prev_x.size(2) > x.size(2):
267267
endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
268+
elif idx == len(self._blocks) - 1:
269+
endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
268270
prev_x = x
269271

270272
# Head

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /