import argparse
import time
import numpy as np
import mxnet as mx

import gluonnlp as nlp
from gluonnlp.data import SQuAD
from model import BertForQALoss, BertForQA
from data import SQuADTransform, preprocess_dataset

import smdebug.mxnet as smd
from smdebug import modes

def get_dataloaders(batch_size, vocab, train_dataset_size, val_dataset_size):

    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Stack(),
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token]),
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token]),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack(),
    )
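    # Each batch produced with this batchify_fn is a tuple of
    # (example id, padded input token ids, padded token type ids, valid length,
    #  start position, end position), matching how batches are unpacked in the
    # training and validation loops below.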

    train_data = SQuAD("train", version='2.0')[:train_dataset_size]

    train_data_transform, _ = preprocess_dataset(
        train_data, SQuADTransform(
            nlp.data.BERTTokenizer(vocab=vocab, lower=True),
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_pad=True,
            is_training=True))

    train_dataloader = mx.gluon.data.DataLoader(
        train_data_transform, batchify_fn=batchify_fn,
        batch_size=batch_size, num_workers=4, shuffle=True)

    # use only the first `val_dataset_size` samples of the dev set for validation
    dev_data = SQuAD("dev", version='2.0')[:val_dataset_size]
    dev_data = mx.gluon.data.SimpleDataset(dev_data)

    # keep the transformed dev features as a dataset as well, so the input tokens
    # of each example can be looked up by example id during validation
    dev_dataset = dev_data.transform(
        SQuADTransform(
            nlp.data.BERTTokenizer(vocab=vocab, lower=True),
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_pad=False,
            is_training=False)._transform, lazy=False)

    dev_data_transform, _ = preprocess_dataset(
        dev_data, SQuADTransform(
            nlp.data.BERTTokenizer(vocab=vocab, lower=True),
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_pad=False,
            is_training=False))

    dev_dataloader = mx.gluon.data.DataLoader(
        dev_data_transform,
        batchify_fn=batchify_fn,
        num_workers=1, batch_size=batch_size,
        shuffle=False, last_batch='keep')

    return train_dataloader, dev_dataloader, dev_dataset

def train_model(epochs, batch_size, learning_rate, train_dataset_size, val_dataset_size):

    # use a GPU context if one is available
    ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()

    # load pretrained BERT model weights (trained on BookCorpus and English Wikipedia)
    bert, vocab = nlp.model.get_model(
        name='bert_12_768_12',
        dataset_name='book_corpus_wiki_en_uncased',
        vocab=None,
        pretrained=True,
        ctx=ctx,
        use_pooler=False,
        use_decoder=False,
        use_classifier=False,
        output_attention=True)

    # create BERT model for question answering
    net = BertForQA(bert=bert)
    net.span_classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)

    # create smdebug hook; its configuration (output path, save intervals, collections)
    # is read from the JSON file provided by the training environment
    hook = smd.Hook.create_from_json_file()

    # register the model so the hook can capture its tensors, including the
    # attention outputs enabled with output_attention=True
    hook.register_block(net)

    # loss function for BERT model training
    loss_function = BertForQALoss()

    # trainer
    trainer = mx.gluon.Trainer(net.collect_params(),
                               'bertadam',
                               {'learning_rate': learning_rate},
                               update_on_kvstore=False)
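    # Note: 'bertadam' is GluonNLP's Adam variant commonly used for BERT fine-tuning;
    # update_on_kvstore=False keeps gradients accessible so they can be clipped
    # manually before trainer.update() is called in the training loop.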

    # create dataloaders
    train_dataloader, dev_dataloader, dev_dataset = get_dataloaders(
        batch_size, vocab, train_dataset_size, val_dataset_size)

    # do not apply weight decay to LayerNorm parameters and biases
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0

    params = [p for p in net.collect_params().values()
              if p.grad_req != 'null']

    # start training loop
    for epoch_id in range(epochs):

        for batch_id, data in enumerate(train_dataloader):
            hook.set_mode(modes.TRAIN)
            with mx.autograd.record():
                _, inputs, token_types, valid_length, start_label, end_label = data

                # forward pass
                out = net(inputs.astype('float32').as_in_context(ctx),
                          token_types.astype('float32').as_in_context(ctx),
                          valid_length.astype('float32').as_in_context(ctx))

                # compute loss
                ls = loss_function(out, [
                    start_label.astype('float32').as_in_context(ctx),
                    end_label.astype('float32').as_in_context(ctx)]).mean()

            # backpropagation and gradient clipping
            ls.backward()
            nlp.utils.clip_grad_global_norm(params, 1)

            # update model parameters
            trainer.update(1)

        # validation loop
        hook.set_mode(modes.EVAL)
        for data in dev_dataloader:

            example_ids, inputs, token_types, valid_length, _, _ = data

            # forward pass
            out = net(inputs.astype('float32').as_in_context(ctx),
                      token_types.astype('float32').as_in_context(ctx),
                      valid_length.astype('float32').as_in_context(ctx))

            # record the input tokens of each example in the batch
            input_tokens = np.array([])
            for example_id in example_ids.asnumpy().tolist():
                array = np.array(dev_dataset[example_id][0].tokens, dtype=str)
                array = array.reshape(1, array.shape[0])
                input_tokens = np.append(input_tokens, array)

            # write the tokens whenever the debugger saves tensors for this step, so the
            # captured outputs can later be matched with their input tokens
            if hook.get_collections()['all'].save_config.should_save_step(modes.EVAL, hook.mode_steps[modes.EVAL]):
                hook._write_raw_tensor_simple("input_tokens", input_tokens)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--val_dataset_size', type=int, default=64)
    parser.add_argument('--train_dataset_size', type=int, default=1024)
    # not used directly in this script; the hook reads its configuration from the JSON file
    parser.add_argument('--smdebug_dir', type=str, default=None)

    # parse arguments
    args, _ = parser.parse_known_args()
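    # parse_known_args is used (instead of parse_args) so that any additional
    # arguments injected by the launching environment, e.g. SageMaker, are
    # ignored rather than causing an error.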

    # train model
    train_model(epochs=args.epochs,
                batch_size=args.batch_size,
                learning_rate=args.learning_rate,
                train_dataset_size=args.train_dataset_size,
                val_dataset_size=args.val_dataset_size)
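    # Example invocation (script name is hypothetical):
    #   python train_bert_squad.py --epochs 1 --batch_size 8 \
    #          --train_dataset_size 256 --val_dataset_size 64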