@@ -421,14 +421,25 @@ def build_model_with_cfg(
421421 if 'feature_cls' in kwargs :
422422 feature_cfg ['feature_cls' ] = kwargs .pop ('feature_cls' )
423423
424+ # use meta-device init to speed up loading pretrained weights.
425+ # device context manager is only available for PyTorch>=2.0
426+ # when num_classes is changed, we rely on __init__() logic to initialize head weights.
427+ # thus, we can't use meta-device init in that case.
428+ num_classes = 0 if features else kwargs .get ("num_classes" , pretrained_cfg ["num_classes" ])
429+ use_meta_init = (
430+ pretrained
431+ and hasattr (torch .device ("meta" ), "__enter__" )
432+ and (num_classes == 0 or num_classes == pretrained_cfg ["num_classes" ])
433+ )
434+
424435 # Instantiate the model
425- meta_device = torch .device ("meta" )
426- with meta_device if hasattr (meta_device , "__enter__" ) and pretrained else nullcontext ():
436+ with torch .device ("meta" ) if use_meta_init else nullcontext ():
427437 if model_cfg is None :
428438 model = model_cls (** kwargs )
429439 else :
430440 model = model_cls (cfg = model_cfg , ** kwargs )
431- if pretrained :
441+
442+ if use_meta_init :
432443 # .to_empty() will also move cpu params/buffers to uninitialized storage.
433444 # this is problematic for non-persistent buffers, since they don't get loaded
434445 # from pretrained weights later (not part of state_dict). hence, we have
0 commit comments