Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 76afc05

Browse files
Update cvt.py
1 parent d08a468 commit 76afc05

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

‎timm/models/cvt.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def __init__(
400400
self.feature_info = []
401401

402402
self.use_cls_token = use_cls_token
403+
self.global_pool = 'token' if use_cls_token else 'avg'
403404

404405
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
405406

@@ -448,6 +449,21 @@ def __init__(
448449
self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
449450

450451

452+
453+
@torch.jit.ignore
454+
def get_classifier(self) -> nn.Module:
455+
return self.head
456+
457+
def reset_classifier(self, num_classes: int, global_pool = None) -> None:
458+
self.num_classes = num_classes
459+
if global_pool is not None:
460+
assert global_pool in ('', 'avg', 'token')
461+
if global_pool == 'token' and not self.use_cls_token:
462+
assert False, 'Model not configured to use class token'
463+
self.global_pool = global_pool
464+
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
465+
466+
451467
def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
452468
# nn.Sequential forward can't accept tuple intermediates
453469
# TODO grad checkpointing
@@ -457,12 +473,13 @@ def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
457473
return x
458474

459475
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
476+
# get feature map, not always used
460477
x = self._forward_features(x)
461478

462479
return x[0] if self.use_cls_token else x
463480

464481
def forward_head(self, x: torch.Tensor) -> torch.Tensor:
465-
if self.use_cls_token:
482+
if self.global_pool=='token':
466483
return self.head(self.norm(x[1].flatten(1)))
467484
else:
468485
return self.head(self.norm(x.mean(dim=(2,3))))

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /