Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 0842093

Browse files
committed
update train.py
1 parent 15af4fc commit 0842093

File tree

1 file changed

+31
-1
lines changed
  • ImageNet/training_scripts/imagenet_training

1 file changed

+31
-1
lines changed

‎ImageNet/training_scripts/imagenet_training/train.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from contextlib import suppress
2626
from datetime import datetime
2727
import matplotlib.pyplot as plt
28-
28+
from fvcore.nn import FlopCountAnalysis, flop_count_table
2929

3030
import pickle
3131
import torch
@@ -208,6 +208,11 @@
208208
parser.add_argument('--netidx', type=int, default=0, help='which network to use (5mil or 8mil)')
209209
parser.add_argument('--netmode', type=int, default=1, help='which stride mode to use(1 to 5)')
210210

211+
parser.add_argument('--freeze-top', action='store_true', default=False, help='Freeze early layers up to idx36')
212+
parser.add_argument('--freeze-bot', action='store_true', default=False, help='Freeze late layers from idx36 to the end')
213+
parser.add_argument('--debug', action='store_true', default=False, help='whether to enable set_detect_anomaly() or not')
214+
parser.add_argument('--use-avgs', default='', type=str, metavar='PATH',help='Resume full model and optimizer state from checkpoint (default: none)')
215+
211216
# torch.autograd.set_detect_anomaly(True)
212217
def _parse_args():
213218
# Do we have a config file to parse?
@@ -304,6 +309,9 @@ def main():
304309
if args.fuser:
305310
set_jit_fuser(args.fuser)
306311

312+
if args.debug:
313+
torch.autograd.set_detect_anomaly(True)
314+
307315
# convert into int
308316
args.drop_rates = {int(key):float(value) for key,value in args.drop_rates.items()}
309317
print(f'args.drop_rates: {args.drop_rates}')
@@ -334,7 +342,10 @@ def main():
334342
model.set_grad_checkpointing(enable=True)
335343

336344
if args.local_rank == 0:
345+
flops = FlopCountAnalysis(model, torch.randn(size=(1,3,224,224)))
337346
_logger.info(f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()]):,}')
347+
_logger.info(f'Model {safe_model_name(args.model)} created, FLOPS count:{flops.total():,}')
348+
# _logger.info(f'Model {safe_model_name(args.model)} Flops table\n{flop_count_table(flops)}')
338349

339350
_logger.info(f'Model: {model}')
340351
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
@@ -570,6 +581,25 @@ def main():
570581
f.write(args_text)
571582

572583
try:
584+
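# optionally freeze part of model.features, split at layer index 36 (NOTE(review): --freeze-bot freezes idx <= 36 and --freeze-top freezes idx >= 36, the reverse of what their help text says - confirm which is intended)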
585+
if args.freeze_bot:
586+
for n,v in model.features.named_parameters():
587+
if int(n.split('.')[0]) <=36:
588+
v.requires_grad=False
589+
elif args.freeze_top:
590+
for n,v in model.features.named_parameters():
591+
if int(n.split('.')[0]) >=36:
592+
v.requires_grad=False
593+
594+
if args.freeze_top or args.freeze_bot:
595+
for n,v in model.features.named_parameters():
596+
print(f'{n}.requires_grad: {v.requires_grad}')
597+
598+
if args.use_avgs:
599+
checkpoint_avgs = torch.load(args.use_avgs,map_location='cuda')
600+
model.load_state_dict(checkpoint_avgs)
601+
print(f'avg model loaded')
602+
573603
for epoch in range(start_epoch, num_epochs):
574604
if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
575605
loader_train.sampler.set_epoch(epoch)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /