
Commit c75523e

Merge pull request #120 from anirudh2290/changes_smp
SMP Changes: Doc, copyright and LICENSE changes
2 parents 8abb59e + 6774f83 commit c75523e

File tree: 5 files changed (+23, −415 lines)
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+BERT PyTorch
+
+This repository includes software from https://github.com/huggingface/pytorch-pretrained-BERT
+licensed under the Apache License 2.0.

Renamed: training/distributed_training/pytorch/model_parallel/bert/bert_example/sagemaker_rbk_pretrain.py
      → training/distributed_training/pytorch/model_parallel/bert/bert_example/sagemaker_smp_pretrain.py

Lines changed: 18 additions & 22 deletions
@@ -1,6 +1,7 @@
 # coding=utf-8
 # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
 # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
+# Modifications Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -41,11 +42,9 @@
 import json
 import subprocess
 
-from tokenization import BertTokenizer
 import modeling
 from schedulers import PolyWarmUpScheduler
 
-from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
 from utils import format_step, get_world_size, get_rank
 from utils import is_main_process
 from apex.parallel import DistributedDataParallel as DDP
@@ -55,6 +54,7 @@
 import amp_C
 import apex_C
 from apex.amp import _amp_state
+# SMP: Import smp library
 import smdistributed.modelparallel.torch as smp
 import configparser
 
@@ -89,10 +89,7 @@ def __call__(self, id):
 
 def create_pretraining_dataset(input_file, max_pred_length, shared_list, args, worker_init):
     train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length)
-    if args.horovod > 0:
-        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data, num_replicas=hvd.size(), rank=hvd.rank())
-    else:
-        train_sampler = RandomSampler(train_data)
+    train_sampler = RandomSampler(train_data)
     train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                   batch_size=args.train_batch_size * args.n_gpu,
                                   num_workers=4, worker_init_fn=worker_init,
@@ -290,7 +287,6 @@ def parse_arguments():
                         help='Disable tqdm progress bar')
    parser.add_argument('--steps_this_run', type=int, default=1000,
                        help='If provided, only run this many steps before exiting')
-   parser.add_argument('--horovod', type=int, default=0)
    parser.add_argument('--ddp', type=int, default=0)
    parser.add_argument('--smp', type=int, default=0)
    parser.add_argument('--num_microbatches', type=int, default=1)
@@ -323,21 +319,12 @@ def setup_training(args):
     assert (torch.cuda.is_available())
 
     if args.smp > 0:
-        '''
-        cfg = {
-            'microbatches': args.num_microbatches,
-            'placement_strategy': 'cluster',
-            'pipeline': args.pipeline,
-            'optimize': 'speed',
-            'partitions': 2,
-            'horovod': args.horovod > 0,
-            'ddp': args.ddp > 0,
-            'memory_weight': args.param_weight,
-            'overlapping_allreduce': args.overlapping_allreduce > 0,
-        }
-        '''
+        # Initialize SMP. The configuration is obtained from the parameters passed to
+        # the Sagemaker PyTorch estimator.
         smp.init()
 
+        # SMP: Set the device to the GPU ID used by the current process.
+        # Input tensors should be transferred to this device.
         torch.cuda.set_device(smp.local_rank())
         device = torch.device("cuda", smp.local_rank())
         args.n_gpu = 1
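
Note on this hunk: the removed in-script cfg dict and the new comment point at the same idea, namely that the SMP configuration now arrives from the SageMaker PyTorch estimator instead of being hard-coded. Below is a hedged sketch of what the launch side might look like. The parameter names mirror the deleted cfg block; the estimator arguments, instance type, framework versions, and distribution schema are assumptions to verify against the SageMaker Python SDK release you use.

```python
# Hedged sketch of launching sagemaker_smp_pretrain.py with SMP enabled.
# Parameter names follow the cfg dict removed in this diff; the estimator
# arguments and `distribution` schema are assumptions to double-check.
from sagemaker.pytorch import PyTorch

smp_parameters = {
    "microbatches": 8,               # corresponds to --num_microbatches in the script
    "placement_strategy": "cluster",
    "pipeline": "interleaved",
    "optimize": "speed",
    "partitions": 2,
    "ddp": True,                     # corresponds to --ddp
}

estimator = PyTorch(
    entry_point="sagemaker_smp_pretrain.py",
    source_dir="bert_example",                  # assumed repository layout
    role="<your-iam-role>",                     # placeholder
    instance_type="ml.p3.16xlarge",             # assumed instance type
    instance_count=1,
    framework_version="1.6.0",                  # assumed framework/Python versions
    py_version="py36",
    hyperparameters={"smp": 1, "num_microbatches": 8, "ddp": 1},
    distribution={
        "smdistributed": {"modelparallel": {"enabled": True,
                                            "parameters": smp_parameters}},
        "mpi": {"enabled": True, "processes_per_host": 8},
    },
)
# estimator.fit({"train": "s3://<bucket>/<prefix>/phase1"})  # placeholder channel
```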
@@ -407,6 +394,10 @@ def prepare_model_and_optimizer(args, device):
     model = modeling.BertForPreTraining(config)
     model.checkpoint_activations(args.checkpoint_activations)
     if args.smp > 0:
+        # SMP: Use the DistributedModel container to provide the model
+        # to be partitioned across different ranks. For the rest of the script,
+        # the returned DistributedModel object should be used in place of
+        # the model provided for DistributedModel class instantiation.
         model = smp.DistributedModel(model)
 
     checkpoint = None
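
A minimal sketch of the wrap-once pattern the new comments describe: build the plain module, hand it to smp.DistributedModel, and from then on touch only the returned object. modeling.BertForPreTraining and args come from the script; the helper below is illustrative only.

```python
# Illustrative helper only: wraps a freshly built BERT model for SMP, as the
# change above does. The wrapped object replaces `model` everywhere downstream.
import smdistributed.modelparallel.torch as smp
import modeling

def build_model(config, args):
    model = modeling.BertForPreTraining(config)
    model.checkpoint_activations(args.checkpoint_activations)
    if args.smp > 0:
        # Partitioning across mp_ranks happens inside the wrapper; downstream
        # code (optimizer construction, checkpointing) must use this object.
        model = smp.DistributedModel(model)
    return model
```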
@@ -419,6 +410,7 @@
 
     global_step = args.resume_step if not args.init_checkpoint else 0
 
+    # SMP: Load a model that was saved with smp.save
     if not args.init_checkpoint:
         checkpoint = smp.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), partial=args.partial_checkpoint)
     else:
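
The smp.load comment pairs with the two smp.save calls later in the file: a checkpoint written with partial=True is an mp_rank-specific shard and is reloaded the same way, while partial=False refers to a full checkpoint. A sketch of the resume path follows; the 'model' key and the load_state_dict call are assumptions, since only the 'master params' and 'data_loader' keys are visible in this diff.

```python
# Sketch of resuming from an SMP checkpoint. The 'model' key is an assumption
# about the script's save_dict contents; adapt it to the keys you actually save.
import os
import smdistributed.modelparallel.torch as smp

def resume(args, model, global_step):
    ckpt_path = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step))
    # partial=True expects a shard written with smp.save(..., partial=True)
    checkpoint = smp.load(ckpt_path, partial=args.partial_checkpoint)
    model.load_state_dict(checkpoint["model"], strict=False)  # assumed key
    return checkpoint
```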
@@ -441,6 +433,8 @@
     optimizer = FusedLAMB(optimizer_grouped_parameters,
                           lr=args.learning_rate)
     if args.smp > 0:
+        # SMP: Use Distributed Optimizer which allows the loading of optimizer state for a distributed model
+        # Also provides APIs to obtain local optimizer state for the current mp_rank.
         optimizer = smp.DistributedOptimizer(optimizer)
     lr_scheduler = PolyWarmUpScheduler(optimizer,
                                        warmup=args.warmup_proportion,
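
A short sketch of the optimizer wrapping plus the "local optimizer state" idea from the comment. local_state_dict() is an assumption about the SMP API name for that per-mp_rank state; confirm it against the smdistributed.modelparallel documentation for your release.

```python
# Sketch: wrap the base optimizer so its state can be handled per mp_rank.
# `local_state_dict()` is an assumed API name for "local optimizer state".
import smdistributed.modelparallel.torch as smp
from apex.optimizers import FusedLAMB

def build_optimizer(optimizer_grouped_parameters, args):
    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    if args.smp > 0:
        optimizer = smp.DistributedOptimizer(optimizer)
    return optimizer

# Hypothetical use when assembling a partial checkpoint:
# save_dict["optimizer"] = optimizer.local_state_dict()
```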
@@ -539,8 +533,6 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
             param.grad = None
     else:
         if args.apply_optimizer > 0:
-            if args.horovod > 0 and args.overlapping_allreduce == 0:
-                smp.hvd_average_grads()
             optimizer.step()
         # optimizer.zero_grad()
         for param in model.parameters():
@@ -549,6 +541,8 @@
 
     return global_step
 
+# SMP: Define smp step. Pass the necessary arguments for the train_step call
+# and return any tensors needed outside
 @smp.step
 def smp_step(args, device, input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, model, optimizer, criterion, step):
     rval = train_step(args, device, input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, model, optimizer, criterion, step)
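
smp.step is where pipelining happens: the decorated function runs once per microbatch, and whatever it returns comes back as per-microbatch outputs that are combined outside the call. A hedged sketch of that pattern follows; reduce_mean() on the returned output is an assumption about the SMP output API, and train_step is the script's own function.

```python
# Sketch of the smp.step pattern. Inside the decorated function the tensors
# are microbatch slices; outside, the return value holds one entry per
# microbatch. reduce_mean() is an assumed name for the aggregation step.
import smdistributed.modelparallel.torch as smp

@smp.step
def smp_step(args, device, input_ids, segment_ids, input_mask,
             masked_lm_labels, next_sentence_labels, model, optimizer,
             criterion, step):
    return train_step(args, device, input_ids, segment_ids, input_mask,
                      masked_lm_labels, next_sentence_labels, model,
                      optimizer, criterion, step)

# Hypothetical call site in the training loop:
# per_mb_loss = smp_step(args, device, input_ids, segment_ids, input_mask,
#                        masked_lm_labels, next_sentence_labels, model,
#                        optimizer, criterion, step)
# loss = per_mb_loss.reduce_mean()   # average across microbatches
```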
@@ -742,6 +736,7 @@ def main():
                             'data_loader': None if global_step >= args.steps_this_run else train_dataloader}
                if args.fp16:
                    save_dict['master params'] = list(amp.master_params(optimizer))
+               # SMP: Checkpoint mp_rank specific state
                smp.save(save_dict, output_save_file, partial=True)
 
                most_recent_ckpts_paths.append(output_save_file)
@@ -764,6 +759,7 @@
                  'data_loader': None if global_step >= args.steps_this_run else train_dataloader}
     if args.fp16:
         save_dict['master params'] = list(amp.master_params(optimizer))
+    # SMP: Save a single checkpoint containing entire model parameters
     smp.save(save_dict, output_save_file, partial=False)
     smp.barrier()
     return args, final_loss, train_time_raw, global_step
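
The last two hunks show the two checkpointing modes: during training each model-parallel rank saves only its own shard (partial=True), while the final save writes one checkpoint with the entire model (partial=False) and then synchronizes ranks with smp.barrier(). A compact sketch of both, using only the keys visible in this diff; everything else is an illustrative helper, not the script's exact code.

```python
# Sketch of the two checkpoint modes in this change. partial=True writes a
# per-mp_rank shard (reloaded later with smp.load(..., partial=True));
# partial=False writes a single checkpoint with the full model, followed by a
# barrier so all ranks finish together.
from apex import amp
import smdistributed.modelparallel.torch as smp

def save_checkpoint(args, optimizer, train_dataloader, global_step,
                    output_save_file, final=False):
    save_dict = {'data_loader': None if global_step >= args.steps_this_run
                 else train_dataloader}
    if args.fp16:
        save_dict['master params'] = list(amp.master_params(optimizer))
    smp.save(save_dict, output_save_file, partial=not final)
    if final:
        smp.barrier()
```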
