@@ -388,18 +388,18 @@ def __init__(
             mlp_layer: nn.Module = Mlp,
             mlp_ratio: float = 4.,
             mlp_act_layer: nn.Module = QuickGELU,
-            use_cls_token: Tuple[bool, ...] = (False, False, True),
+            use_cls_token: bool = True,
             drop_rate: float = 0.,
     ) -> None:
         super().__init__()
         num_stages = len(dims)
         assert num_stages == len(depths) == len(embed_kernel_size) == len(embed_stride)
-        assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token)
+        assert num_stages == len(embed_padding) == len(num_heads)
         self.num_classes = num_classes
         self.num_features = dims[-1]
         self.feature_info = []

-        self.use_cls_token = use_cls_token[-1]
+        self.use_cls_token = use_cls_token

         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]

@@ -437,7 +437,7 @@ def __init__(
                 mlp_layer=mlp_layer,
                 mlp_ratio=mlp_ratio,
                 mlp_act_layer=mlp_act_layer,
-                use_cls_token=use_cls_token[stage_idx],
+                use_cls_token=use_cls_token and stage_idx == num_stages - 1,
             )
             in_chs = dim
             stages.append(stage)