
Commit a7e9bc3

NRauschmayr authored and ddavydenko committed
smdebug custom analysis: bert example (#998)
Adding an example that demonstrates how to monitor attention scores in BERT model training with SageMaker Debugger. The example uses the GluonNLP tutorial on fine-tuning BERT for Question Answering. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
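
The training script in this commit saves the attention outputs of each multi-head attention cell through an smdebug hook, and the accompanying notebook loads and visualizes them. A minimal sketch of how such tensors could be read back with smdebug's trial API; the output path and the tensor-name regex below are assumptions, not taken from this commit:

# sketch only: read back attention tensors saved by the Debugger hook
from smdebug import modes
from smdebug.trials import create_trial

# hypothetical location where the Debugger hook wrote its output
trial = create_trial("s3://my-bucket/smdebug-output")

# attention outputs of the multi-head attention cells, saved during EVAL steps
names = trial.tensor_names(regex=".*multiheadattentioncell.*output_1")
for step in trial.steps(mode=modes.EVAL):
    scores = trial.tensor(names[0]).value(step, mode=modes.EVAL)
    print(step, scores.shape)  # e.g. (batch, heads, seq_len, seq_len)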
1 parent 0917140 commit a7e9bc3

File tree

8 files changed, +1645 -0 lines changed


sagemaker-debugger/model_specific_realtime_analysis/bert_attention_head_view/bert_attention_head_view.ipynb

Lines changed: 582 additions & 0 deletions
Large diffs are not rendered by default.

sagemaker-debugger/model_specific_realtime_analysis/bert_attention_head_view/entry_point/data.py

Lines changed: 527 additions & 0 deletions
Large diffs are not rendered by default.
sagemaker-debugger/model_specific_realtime_analysis/bert_attention_head_view/entry_point/model.py

Lines changed: 115 additions & 0 deletions
# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""BertForQA models."""

__all__ = ['BertForQA', 'BertForQALoss']

from mxnet.gluon import HybridBlock, loss, nn
from mxnet.gluon.loss import Loss


class BertForQA(HybridBlock):
    """Model for SQuAD task with BERT.

    The model feeds token ids and token type ids into BERT to get the
    pooled BERT sequence representation, then applies a Dense layer for the QA task.

    Parameters
    ----------
    bert: BERTModel
        Bidirectional encoder with transformer.
    prefix : str or None
        See document of `mx.gluon.Block`.
    params : ParameterDict or None
        See document of `mx.gluon.Block`.
    """

    def __init__(self, bert, prefix=None, params=None):
        super(BertForQA, self).__init__(prefix=prefix, params=params)
        self.bert = bert
        with self.name_scope():
            self.span_classifier = nn.Dense(units=2, flatten=False)

    def __call__(self, inputs, token_types, valid_length=None):
        # pylint: disable=arguments-differ, dangerous-default-value
        """Generate the unnormalized score for the given input sequences."""
        # XXX Temporary hack for hybridization as HybridBlock does not support None inputs
        valid_length = [] if valid_length is None else valid_length
        return super(BertForQA, self).__call__(inputs, token_types, valid_length)

    def hybrid_forward(self, F, inputs, token_types, valid_length=None):
        # pylint: disable=arguments-differ
        """Generate the unnormalized score for the given input sequences.

        Parameters
        ----------
        inputs : NDArray, shape (batch_size, seq_length)
            Input words for the sequences.
        token_types : NDArray, shape (batch_size, seq_length)
            Token types for the sequences, used to indicate whether the word belongs to the
            first sentence or the second one.
        valid_length : NDArray or None, shape (batch_size,)
            Valid length of the sequence. This is used to mask the padded tokens.

        Returns
        -------
        outputs : NDArray
            Shape (batch_size, seq_length, 2)
        """
        # XXX Temporary hack for hybridization as HybridBlock does not support None inputs
        if isinstance(valid_length, list) and len(valid_length) == 0:
            valid_length = None
        bert_output = self.bert(inputs, token_types, valid_length)[0]

        output = self.span_classifier(bert_output)
        return output


class BertForQALoss(Loss):
    """Loss for SQuAD task with BERT."""

    def __init__(self, weight=None, batch_axis=0, **kwargs):  # pylint: disable=unused-argument
        super(BertForQALoss, self).__init__(
            weight=None, batch_axis=0, **kwargs)
        self.loss = loss.SoftmaxCELoss()

    def hybrid_forward(self, F, pred, label):  # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        pred : NDArray, shape (batch_size, seq_length, 2)
            BERTSquad forward output.
        label : list, length is 2, each shape is (batch_size, 1)
            label[0] is the starting position of the answer,
            label[1] is the ending position of the answer.

        Returns
        -------
        outputs : NDArray
            Shape (batch_size,)
        """
        # split the (batch, seq_len, 2) predictions into start and end scores
        pred = F.split(pred, axis=2, num_outputs=2)
        start_pred = pred[0].reshape((0, -3))
        start_label = label[0]
        end_pred = pred[1].reshape((0, -3))
        end_label = label[1]
        # average the softmax cross-entropy loss over the start and end positions
        return (self.loss(start_pred, start_label) + self.loss(
            end_pred, end_label)) / 2
Lines changed: 1 addition & 0 deletions

gluonnlp
Lines changed: 180 additions & 0 deletions
import argparse
import time
import numpy as np
import mxnet as mx

import gluonnlp as nlp
from gluonnlp.data import SQuAD
from model import BertForQALoss, BertForQA
from data import SQuADTransform, preprocess_dataset

import smdebug.mxnet as smd
from smdebug import modes


def get_dataloaders(batch_size, vocab, train_dataset_size, val_dataset_size):

    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Stack(),
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token]),
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token]),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack(),
    )

    train_data = SQuAD("train", version='2.0')[:train_dataset_size]

    train_data_transform, _ = preprocess_dataset(
        train_data, SQuADTransform(
            nlp.data.BERTTokenizer(vocab=vocab, lower=True),
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_pad=True,
            is_training=True))

    train_dataloader = mx.gluon.data.DataLoader(
        train_data_transform, batchify_fn=batchify_fn,
        batch_size=batch_size, num_workers=4, shuffle=True)

    # use only the first val_dataset_size validation samples
    dev_data = SQuAD("dev", version='2.0')[:val_dataset_size]
    dev_data = mx.gluon.data.SimpleDataset(dev_data)

    dev_dataset = dev_data.transform(
        SQuADTransform(
            nlp.data.BERTTokenizer(vocab=vocab, lower=True),
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_pad=False,
            is_training=False)._transform, lazy=False)

    dev_data_transform, _ = preprocess_dataset(
        dev_data, SQuADTransform(
            nlp.data.BERTTokenizer(vocab=vocab, lower=True),
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_pad=False,
            is_training=False))

    dev_dataloader = mx.gluon.data.DataLoader(
        dev_data_transform,
        batchify_fn=batchify_fn,
        num_workers=1, batch_size=batch_size,
        shuffle=False, last_batch='keep')

    return train_dataloader, dev_dataloader, dev_dataset


def train_model(epochs, batch_size, learning_rate, train_dataset_size, val_dataset_size):

    # use a GPU context for training
    ctx = mx.gpu()

    # load pretrained BERT model weights (trained on the BookCorpus and English Wikipedia datasets)
    bert, vocab = nlp.model.get_model(
        name='bert_12_768_12',
        dataset_name='book_corpus_wiki_en_uncased',
        vocab=None,
        pretrained='true',
        ctx=ctx,
        use_pooler=False,
        use_decoder=False,
        use_classifier=False,
        output_attention=True)

    # create BERT model for Question Answering
    net = BertForQA(bert=bert)
    net.span_classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)

    # create smdebug hook from the JSON config provided by SageMaker and register the model
    hook = smd.Hook.create_from_json_file()
    hook.register_block(net)

    # loss function for BERT model training
    loss_function = BertForQALoss()

    # trainer
    trainer = mx.gluon.Trainer(net.collect_params(),
                               'bertadam',
                               {'learning_rate': learning_rate},
                               update_on_kvstore=False)

    # create dataloaders
    train_dataloader, dev_dataloader, dev_dataset = get_dataloaders(
        batch_size, vocab, train_dataset_size, val_dataset_size)

    # do not apply weight decay to LayerNorm and bias parameters
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0

    params = [p for p in net.collect_params().values()
              if p.grad_req != 'null']

    # start training loop
    for epoch_id in range(epochs):

        for batch_id, data in enumerate(train_dataloader):
            hook.set_mode(modes.TRAIN)
            with mx.autograd.record():
                _, inputs, token_types, valid_length, start_label, end_label = data

                # forward pass
                out = net(inputs.astype('float32').as_in_context(ctx),
                          token_types.astype('float32').as_in_context(ctx),
                          valid_length.astype('float32').as_in_context(ctx))

                # compute loss
                ls = loss_function(out, [
                    start_label.astype('float32').as_in_context(ctx),
                    end_label.astype('float32').as_in_context(ctx)]).mean()

            # backpropagation
            ls.backward()
            nlp.utils.clip_grad_global_norm(params, 1)

            # update model parameters
            trainer.update(1)

        # validation loop
        hook.set_mode(modes.EVAL)
        for data in dev_dataloader:

            example_ids, inputs, token_types, valid_length, _, _ = data

            # forward pass
            out = net(inputs.astype('float32').as_in_context(ctx),
                      token_types.astype('float32').as_in_context(ctx),
                      valid_length.astype('float32').as_in_context(ctx))

            # record input tokens
            input_tokens = np.array([])
            for example_id in example_ids.asnumpy().tolist():
                array = np.array(dev_dataset[example_id][0].tokens, dtype=np.str)
                array = array.reshape(1, array.shape[0])
                input_tokens = np.append(input_tokens, array)

            if hook.get_collections()['all'].save_config.should_save_step(modes.EVAL, hook.mode_steps[modes.EVAL]):
                hook._write_raw_tensor_simple("input_tokens", input_tokens)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--val_dataset_size', type=int, default=64)
    parser.add_argument('--train_dataset_size', type=int, default=1024)
    parser.add_argument('--smdebug_dir', type=str, default=None)

    # parse arguments
    args, _ = parser.parse_known_args()

    # train model
    model = train_model(epochs=args.epochs, batch_size=args.batch_size,
                        learning_rate=args.learning_rate,
                        train_dataset_size=args.train_dataset_size,
                        val_dataset_size=args.val_dataset_size)
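
A sketch, not part of this commit, of how this training script could be launched as a SageMaker training job with a Debugger hook configuration; the bucket, role, entry-point name, framework version, and tensor regex below are assumptions. smd.Hook.create_from_json_file() in the script picks up the configuration that the estimator generates at runtime.

# sketch only: launch the entry point with a Debugger hook configuration
import sagemaker
from sagemaker.mxnet import MXNet
from sagemaker.debugger import DebuggerHookConfig, CollectionConfig

hook_config = DebuggerHookConfig(
    s3_output_path="s3://my-bucket/smdebug-output",   # hypothetical output location
    collection_configs=[
        CollectionConfig(
            name="all",
            parameters={
                # save the attention outputs and the recorded input tokens
                "include_regex": ".*multiheadattentioncell.*output_1|.*input_tokens",
                "eval.save_interval": "1",
            })
    ])

estimator = MXNet(
    entry_point="train.py",            # assumed script name
    source_dir="entry_point",
    role=sagemaker.get_execution_role(),
    instance_count=1,
    instance_type="ml.p3.2xlarge",
    framework_version="1.6.0",         # assumed framework/Python versions
    py_version="py3",
    hyperparameters={"epochs": 2, "batch_size": 16},
    debugger_hook_config=hook_config)

estimator.fit(wait=False)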
Lines changed: 92 additions & 0 deletions

from bokeh.plotting import show, figure
from bokeh.models.annotations import Title
from bokeh.models import ColumnDataSource, Label, Range1d
from bokeh.io import show, output_notebook, push_notebook
from bokeh.models.glyphs import Line
import numpy as np

output_notebook()


class AttentionHeadView():
    def __init__(self,
                 input_tokens=None,
                 tensors=None,
                 layer='bertencoder0_transformer0_multiheadattentioncell0_output_1',
                 step=0,
                 n_tokens=20):
        self.head = 0
        self.step = step
        self.input_tokens = input_tokens[:n_tokens]
        self.n_tokens = n_tokens
        self.tensors = tensors
        self.p = None
        self.layer = layer
        self.sources = []
        self.create()

    def update(self):

        tensor = self.tensors[self.layer][self.step][0, self.head, :, :]

        counter = 0
        for x in range(self.n_tokens):
            for y in range(self.n_tokens):
                source = self.sources[counter]
                source.line_width = tensor[x, y] * 2
                counter += 1

    def select_layer(self, layer):
        self.layer = layer
        self.update()
        push_notebook()

    def select_head(self, head):
        self.head = head
        self.update()
        push_notebook()

    def select_step(self, step):
        self.step = step
        self.update()
        push_notebook()

    def create(self):

        # set size of figure
        self.p = figure(width=450,
                        plot_height=50 * self.n_tokens,
                        x_range=Range1d(0, self.n_tokens + 2),
                        y_range=Range1d(0, self.n_tokens))

        self.p.xgrid.visible = False
        self.p.ygrid.visible = False
        self.p.axis.visible = False

        x = np.zeros(self.n_tokens) + 2
        y = np.flip(np.arange(0, self.n_tokens), axis=0)

        # set input tokens in plot
        for token, x_i, y_i in zip(self.input_tokens, x, y):
            text1 = Label(x=x_i - 1,
                          y=y_i,
                          text=token,
                          text_font_size='10pt')
            text2 = Label(x=x_i + 10,
                          y=y_i,
                          text=token,
                          text_font_size='10pt')
            self.p.add_layout(text2)
            self.p.add_layout(text1)

        tensor = self.tensors[self.layer][self.step][0, self.head, :, :]

        # plot attention weights
        for x in range(self.n_tokens):
            for y in range(self.n_tokens):
                source = ColumnDataSource(data=dict(x=[2, 12],
                                                    y=[self.n_tokens - x - 1, self.n_tokens - y - 1]))
                line = Line(x="x", y="y", line_width=tensor[x, y], line_color="blue")
                self.p.add_glyph(source, line)
                self.sources.append(line)

        show(self.p, notebook_handle=True)
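
The class above expects a dictionary of attention tensors indexed by layer name and step. A hedged usage sketch of how the notebook might build that dictionary from the Debugger output and drive the view; the trial path, the dictionary layout, and reading the raw input_tokens tensor back through the trial API are assumptions:

# sketch only: wire smdebug output into AttentionHeadView
from smdebug import modes
from smdebug.trials import create_trial

trial = create_trial("s3://my-bucket/smdebug-output")  # hypothetical output location
layer = "bertencoder0_transformer0_multiheadattentioncell0_output_1"
steps = trial.steps(mode=modes.EVAL)

# tensors[layer][step] is expected to have shape (batch, heads, seq_len, seq_len)
tensors = {layer: {s: trial.tensor(layer).value(s, mode=modes.EVAL) for s in steps}}
# assumes the raw "input_tokens" tensor written by the training script can be read back
tokens = trial.tensor("input_tokens").value(steps[0], mode=modes.EVAL)

view = AttentionHeadView(input_tokens=tokens, tensors=tensors, layer=layer,
                         step=steps[0], n_tokens=20)
view.select_head(3)  # switch to another attention head and redraw the plot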

0 commit comments
