|
17 | 17 |
|
18 | 18 |
|
19 | 19 | ################################################################################ |
20 | | -### Help functions for model architecture |
| 20 | +# Help functions for model architecture |
21 | 21 | ################################################################################ |
22 | 22 |
|
23 | 23 | # GlobalParams and BlockArgs: Two namedtuples |
|
50 | 50 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) |
51 | 51 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) |
52 | 52 |
|
53 | | - |
54 | | -# An ordinary implementation of Swish function |
55 | | -class Swish(nn.Module): |
56 | | - def forward(self, x): |
57 | | - return x * torch.sigmoid(x) |
# Swish activation: f(x) = x * sigmoid(x).
# Recent PyTorch releases ship this natively as nn.SiLU, which is the
# preferred (fused, memory-friendly) implementation; older releases get
# an equivalent hand-rolled module instead.
try:
    Swish = nn.SiLU
except AttributeError:
    # Fallback for PyTorch versions that predate nn.SiLU.
    class Swish(nn.Module):
        """Plain Swish module for compatibility with old PyTorch."""

        def forward(self, x):
            return torch.sigmoid(x) * x
58 | 61 |
|
59 | 62 |
|
60 | 63 | # A memory-efficient implementation of Swish function |
@@ -97,10 +100,10 @@ def round_filters(filters, global_params): |
97 | 100 | divisor = global_params.depth_divisor |
98 | 101 | min_depth = global_params.min_depth |
99 | 102 | filters *= multiplier |
100 | | - min_depth = min_depth or divisor # pay attention to this line when using min_depth |
| 103 | + min_depth = min_depth or divisor # pay attention to this line when using min_depth |
101 | 104 | # follow the formula transferred from official TensorFlow implementation |
102 | 105 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) |
103 | | - if new_filters < 0.9 * filters: # prevent rounding by more than 10% |
| 106 | + if new_filters < 0.9 * filters: # prevent rounding by more than 10% |
104 | 107 | new_filters += divisor |
105 | 108 | return int(new_filters) |
106 | 109 |
|
@@ -234,7 +237,7 @@ def forward(self, x): |
234 | 237 | ih, iw = x.size()[-2:] |
235 | 238 | kh, kw = self.weight.size()[-2:] |
236 | 239 | sh, sw = self.stride |
237 | | - oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! ! |
| 240 | + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! ! |
238 | 241 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) |
239 | 242 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) |
240 | 243 | if pad_h > 0 or pad_w > 0: |
@@ -312,6 +315,7 @@ def forward(self, x): |
312 | 315 | return F.max_pool2d(x, self.kernel_size, self.stride, self.padding, |
313 | 316 | self.dilation, self.ceil_mode, self.return_indices) |
314 | 317 |
|
| 318 | + |
315 | 319 | class MaxPool2dStaticSamePadding(nn.MaxPool2d): |
316 | 320 | """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. |
317 | 321 | The padding mudule is calculated in construction function, then used in forward. |
@@ -344,7 +348,7 @@ def forward(self, x): |
344 | 348 |
|
345 | 349 |
|
346 | 350 | ################################################################################ |
347 | | -### Helper functions for loading model params |
| 351 | +# Helper functions for loading model params |
348 | 352 | ################################################################################ |
349 | 353 |
|
350 | 354 | # BlockDecoder: A Class for encoding and decoding BlockArgs |
@@ -577,7 +581,7 @@ def get_model_params(model_name, override_params): |
577 | 581 | # TODO: add the petrained weights url map of 'efficientnet-l2' |
578 | 582 |
|
579 | 583 |
|
580 | | -def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False): |
| 584 | +def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True): |
581 | 585 | """Loads pretrained weights from weights path or download using url. |
582 | 586 | |
583 | 587 | Args: |
@@ -608,4 +612,5 @@ def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, |
608 | 612 | ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) |
609 | 613 | assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) |
610 | 614 |
|
611 | | - print('Loaded pretrained weights for {}'.format(model_name)) |
| 615 | + if verbose: |
| 616 | + print('Loaded pretrained weights for {}'.format(model_name)) |
0 commit comments