Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 54e2fb4

Browse files
Merge: [resnet/mxnet] Apply horovod patch for hvd init
2 parents 2a7c251 + 810bcf3 commit 54e2fb4

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

‎MxNet/Classification/RN50v1.5/dali.py‎

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def add_dali_args(parser):
3131
group.add_argument('--dali-validation-threads', type=int, default=10, help="number of threads" +\
3232
"per GPU for DALI for validation")
3333
group.add_argument('--dali-prefetch-queue', type=int, default=5, help="DALI prefetch queue depth")
34-
group.add_argument('--dali-nvjpeg-memory-padding', type=int, default=256, help="Memory padding value for nvJPEG (in MB)")
34+
group.add_argument('--dali-nvjpeg-memory-padding', type=int, default=64, help="Memory padding value for nvJPEG (in MB)")
3535
group.add_argument('--dali-fuse-decoder', type=int, default=1, help="0 or 1 whether to fuse decoder or not")
3636

3737
group.add_argument('--dali-nvjpeg-width-hint', type=int, default=5980, help="Width hint value for nvJPEG (in pixels)")

‎MxNet/Classification/RN50v1.5/fit.py‎

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -483,11 +483,6 @@ def fit(args, model, data_loader):
483483
# select gpu for horovod process
484484
if 'horovod' in args.kv_store:
485485
args.gpus = [args.gpus[hvd.local_rank()]]
486-
ctx = mx.gpu(hvd.local_rank())
487-
488-
tensor1 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
489-
tensor2 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
490-
tensor1, tensor2 = hvd.grouped_allreduce([tensor1,tensor2])
491486

492487
if args.amp:
493488
amp.init()
@@ -579,6 +574,11 @@ def fit(args, model, data_loader):
579574
params = model.collect_params()
580575
if params is not None:
581576
hvd.broadcast_parameters(params, root_rank=0)
577+
ctx = mx.gpu(hvd.local_rank())
578+
tensor1 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
579+
tensor2 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
580+
tensor1, tensor2 = hvd.grouped_allreduce([tensor1,tensor2])
581+
582582
global_metrics = CompositeMeter()
583583
if args.mode in ['train_val', 'train']:
584584
global_metrics.register_metric('train.loss', MinMeter())

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /