Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d08a468

Browse files
Update cvt.py
1 parent 6c31ec1 commit d08a468

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

‎timm/models/cvt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,18 +388,18 @@ def __init__(
388388
mlp_layer: nn.Module = Mlp,
389389
mlp_ratio: float = 4.,
390390
mlp_act_layer: nn.Module = QuickGELU,
391-
use_cls_token: Tuple[bool, ...] = (False, False, True),
391+
use_cls_token: bool=True,
392392
drop_rate: float = 0.,
393393
) -> None:
394394
super().__init__()
395395
num_stages = len(dims)
396396
assert num_stages == len(depths) == len(embed_kernel_size) == len(embed_stride)
397-
assert num_stages == len(embed_padding) == len(num_heads)==len(use_cls_token)
397+
assert num_stages == len(embed_padding) == len(num_heads)
398398
self.num_classes = num_classes
399399
self.num_features = dims[-1]
400400
self.feature_info = []
401401

402-
self.use_cls_token = use_cls_token[-1]
402+
self.use_cls_token = use_cls_token
403403

404404
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
405405

@@ -437,7 +437,7 @@ def __init__(
437437
mlp_layer = mlp_layer,
438438
mlp_ratio = mlp_ratio,
439439
mlp_act_layer = mlp_act_layer,
440-
use_cls_token = use_cls_token[stage_idx],
440+
use_cls_token = use_cls_tokenandstage_idx==num_stages-1,
441441
)
442442
in_chs = dim
443443
stages.append(stage)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /