# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
+ # Modifications Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

import json
import subprocess

- from tokenization import BertTokenizer
import modeling
from schedulers import PolyWarmUpScheduler

- from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from utils import format_step, get_world_size, get_rank
from utils import is_main_process
from apex.parallel import DistributedDataParallel as DDP
import amp_C
import apex_C
from apex.amp import _amp_state
+ # SMP: Import smp library
import smdistributed.modelparallel.torch as smp
import configparser

@@ -89,10 +89,7 @@ def __call__(self, id):

def create_pretraining_dataset(input_file, max_pred_length, shared_list, args, worker_init):
train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length)
- if args.horovod > 0:
- train_sampler = torch.utils.data.distributed.DistributedSampler(train_data, num_replicas=hvd.size(), rank=hvd.rank())
- else:
- train_sampler = RandomSampler(train_data)
+ train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler,
batch_size=args.train_batch_size * args.n_gpu,
num_workers=4, worker_init_fn=worker_init,
@@ -290,7 +287,6 @@ def parse_arguments():
help='Disable tqdm progress bar')
parser.add_argument('--steps_this_run', type=int, default=1000,
help='If provided, only run this many steps before exiting')
- parser.add_argument('--horovod', type=int, default=0)
parser.add_argument('--ddp', type=int, default=0)
parser.add_argument('--smp', type=int, default=0)
parser.add_argument('--num_microbatches', type=int, default=1)
@@ -323,21 +319,12 @@ def setup_training(args):
assert (torch.cuda.is_available())

if args.smp > 0:
- '''
- cfg = {
- 'microbatches': args.num_microbatches,
- 'placement_strategy': 'cluster',
- 'pipeline': args.pipeline,
- 'optimize': 'speed',
- 'partitions': 2,
- 'horovod': args.horovod > 0,
- 'ddp': args.ddp > 0,
- 'memory_weight': args.param_weight,
- 'overlapping_allreduce': args.overlapping_allreduce > 0,
- }
- '''
+ # SMP: Initialize SMP. The configuration is obtained from the parameters passed to
+ # the SageMaker PyTorch estimator.
smp.init()

+ # SMP: Set the device to the GPU ID used by the current process.
+ # Input tensors should be transferred to this device.
torch.cuda.set_device(smp.local_rank())
device = torch.device("cuda", smp.local_rank())
args.n_gpu = 1
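
Note: for readers new to the library, the hunk above boils down to the following minimal initialization pattern. This is an illustrative sketch only; in this script the model-parallel configuration is supplied through the SageMaker PyTorch estimator rather than hard-coded.

    # Minimal sketch of SMP initialization and device placement (not part of this diff).
    import torch
    import smdistributed.modelparallel.torch as smp

    smp.init()  # picks up the model-parallel configuration from the SageMaker estimator

    # Each process drives the single GPU that SMP assigned to it; input tensors
    # must be moved to this device before the forward pass.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda", smp.local_rank())
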
@@ -407,6 +394,10 @@ def prepare_model_and_optimizer(args, device):
model = modeling.BertForPreTraining(config)
model.checkpoint_activations(args.checkpoint_activations)
if args.smp > 0:
+ # SMP: Use the DistributedModel container to provide the model
+ # to be partitioned across different ranks. For the rest of the script,
+ # the returned DistributedModel object should be used in place of
+ # the model that was passed to DistributedModel.
model = smp.DistributedModel(model)

checkpoint = None
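
As a reference for the DistributedModel change above, a self-contained sketch of wrapping a model looks roughly like this; TinyNet is a stand-in module for illustration, not part of this script.

    # Illustrative only: wrap an ordinary nn.Module so SMP can partition it across ranks.
    import torch
    import torch.nn as nn
    import smdistributed.modelparallel.torch as smp

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(128, 128)
            self.fc2 = nn.Linear(128, 2)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    model = smp.DistributedModel(TinyNet())
    # From this point on, use `model` (the wrapper), not the original TinyNet instance.
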
@@ -419,6 +410,7 @@ def prepare_model_and_optimizer(args, device):

global_step = args.resume_step if not args.init_checkpoint else 0

+ # SMP: Load a checkpoint that was saved with smp.save
if not args.init_checkpoint:
checkpoint = smp.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), partial=args.partial_checkpoint)
else:
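
For context on the smp.load call above: loading mirrors saving. A rough sketch follows; the file name and dictionary keys are assumptions, since the full save_dict construction is not shown in this hunk.

    # Illustrative only: resume from a checkpoint written with smp.save.
    checkpoint = smp.load("ckpt_1000.pt", partial=True)   # partial=True for per-mp_rank checkpoints
    model.load_state_dict(checkpoint["model"])            # "model"/"optimizer" keys are assumptions
    optimizer.load_state_dict(checkpoint["optimizer"])
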
@@ -441,6 +433,8 @@ def prepare_model_and_optimizer(args, device):
optimizer = FusedLAMB(optimizer_grouped_parameters,
lr=args.learning_rate)
if args.smp > 0:
+ # SMP: Use DistributedOptimizer, which allows loading the optimizer state for a distributed model
+ # and provides APIs to obtain the local optimizer state for the current mp_rank.
optimizer = smp.DistributedOptimizer(optimizer)
lr_scheduler = PolyWarmUpScheduler(optimizer,
warmup=args.warmup_proportion,
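
A short sketch of the optimizer wrapping referenced above, assuming `optimizer` is any torch.optim-style optimizer built on the wrapped model's parameters; the variable name for the local state is illustrative.

    # Illustrative only: wrap the optimizer so its state can be saved/loaded per partition.
    optimizer = smp.DistributedOptimizer(optimizer)
    local_opt_state = optimizer.local_state_dict()  # optimizer state for parameters owned by this mp_rank
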
@@ -539,8 +533,6 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
param.grad = None
else:
if args.apply_optimizer > 0:
- if args.horovod > 0 and args.overlapping_allreduce == 0:
- smp.hvd_average_grads()
optimizer.step()
# optimizer.zero_grad()
for param in model.parameters():
@@ -549,6 +541,8 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):

return global_step

+ # SMP: Define the smp step function. Pass the necessary arguments for the train_step call
+ # and return any tensors needed outside.
@smp.step
def smp_step(args, device, input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, model, optimizer, criterion, step):
rval = train_step(args, device, input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, model, optimizer, criterion, step)
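
For readers unfamiliar with the decorator, the general smp.step pattern that the hunk above relies on is sketched below; the step function, loss computation, and tensor names are simplified stand-ins, not this script's train_step.

    # Illustrative only: a generic smp.step function and how its output is consumed.
    import torch
    import torch.nn.functional as F
    import smdistributed.modelparallel.torch as smp

    @smp.step
    def simple_step(model, batch, target):
        output = model(batch)
        loss = F.nll_loss(output, target)
        model.backward(loss)  # inside smp.step, call model.backward instead of loss.backward
        return loss

    # `model`, `batch`, and `target` come from the surrounding training setup.
    loss_mb = simple_step(model, batch, target)  # per-microbatch results, returned as a StepOutput
    loss = loss_mb.reduce_mean()                 # average the loss across microbatches
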
@@ -742,6 +736,7 @@ def main():
'data_loader': None if global_step >= args.steps_this_run else train_dataloader}
if args.fp16:
save_dict['master params'] = list(amp.master_params(optimizer))
+ # SMP: Checkpoint mp_rank-specific state
smp.save(save_dict, output_save_file, partial=True)

most_recent_ckpts_paths.append(output_save_file)
@@ -764,6 +759,7 @@ def main():
'data_loader': None if global_step >= args.steps_this_run else train_dataloader}
if args.fp16:
save_dict['master params'] = list(amp.master_params(optimizer))
+ # SMP: Save a single checkpoint containing the full model parameters
smp.save(save_dict, output_save_file, partial=False)
smp.barrier()
return args, final_loss, train_time_raw, global_step
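
Putting the two smp.save calls above side by side, the checkpointing pattern is roughly the sketch below; the dictionary keys, file names, and rank check are assumptions, since the full save_dict construction is not shown in these hunks.

    # Illustrative only: partial (per-mp_rank) vs. full checkpointing with SMP.
    save_dict = {"model": model.local_state_dict(),        # only this mp_rank's partition
                 "optimizer": optimizer.local_state_dict()}
    smp.save(save_dict, "ckpt_partial.pt", partial=True)

    full_dict = {"model": model.state_dict()}              # gathers the full parameters across mp_ranks
    if smp.rank() == 0:                                    # write the single full file from one rank
        smp.save(full_dict, "ckpt_full.pt", partial=False)
    smp.barrier()                                          # keep ranks in sync before continuing
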