@@ -483,11 +483,6 @@ def fit(args, model, data_loader):
483483 # select gpu for horovod process
484484 if 'horovod' in args .kv_store :
485485 args .gpus = [args .gpus [hvd .local_rank ()]]
486- ctx = mx .gpu (hvd .local_rank ())
487- 488- tensor1 = mx .nd .zeros (shape = (1 ,), dtype = 'float32' , ctx = ctx )
489- tensor2 = mx .nd .zeros (shape = (1 ,), dtype = 'float32' , ctx = ctx )
490- tensor1 , tensor2 = hvd .grouped_allreduce ([tensor1 ,tensor2 ])
491486
492487 if args .amp :
493488 amp .init ()
@@ -579,6 +574,11 @@ def fit(args, model, data_loader):
579574 params = model .collect_params ()
580575 if params is not None :
581576 hvd .broadcast_parameters (params , root_rank = 0 )
577+ ctx = mx .gpu (hvd .local_rank ())
578+ tensor1 = mx .nd .zeros (shape = (1 ,), dtype = 'float32' , ctx = ctx )
579+ tensor2 = mx .nd .zeros (shape = (1 ,), dtype = 'float32' , ctx = ctx )
580+ tensor1 , tensor2 = hvd .grouped_allreduce ([tensor1 ,tensor2 ])
581+ 582582 global_metrics = CompositeMeter ()
583583 if args .mode in ['train_val' , 'train' ]:
584584 global_metrics .register_metric ('train.loss' , MinMeter ())
0 commit comments