from __future__ import print_function

import logging
import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn
import numpy as np
import json
import time
import re
from mxnet.io import DataIter, DataBatch, DataDesc
import bisect, random
from collections import Counter
from itertools import chain, islice


logging.basicConfig(level=logging.DEBUG)

# ------------------------------------------------------------ #
# Training methods                                              #
# ------------------------------------------------------------ #

def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir, hyperparameters, **kwargs):
    # retrieve the hyperparameters we set in notebook (with some defaults)
    batch_size = hyperparameters.get('batch_size', 8)
    epochs = hyperparameters.get('epochs', 2)
    learning_rate = hyperparameters.get('learning_rate', 0.01)
    log_interval = hyperparameters.get('log_interval', 1000)
    embedding_size = hyperparameters.get('embedding_size', 50)
    wd = hyperparameters.get('wd', 0.0001)

    if len(hosts) == 1:
        kvstore = 'device' if num_gpus > 0 else 'local'
    else:
        kvstore = 'dist_sync'

    ctx = mx.gpu() if num_gpus > 0 else mx.cpu()

    training_dir = channel_input_dirs['training']
    train_sentences, train_labels, _ = get_dataset(training_dir + '/train')
    val_sentences, val_labels, _ = get_dataset(training_dir + '/test')

    num_classes = len(set(train_labels))
    vocab = create_vocab(train_sentences)
    vocab_size = len(vocab)

    train_sentences = [[vocab.get(token, 1) for token in line] for line in train_sentences]
    val_sentences = [[vocab.get(token, 1) for token in line] for line in val_sentences]

    train_iterator = BucketSentenceIter(train_sentences, train_labels, batch_size)
    val_iterator = BucketSentenceIter(val_sentences, val_labels, batch_size)

    # define the network
    net = TextClassifier(vocab_size, embedding_size, num_classes)

    # Collect all parameters from net and its children, then initialize them.
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    # Trainer is for updating parameters with gradient.
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': learning_rate})
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    for epoch in range(epochs):
        # reset data iterator and metric at beginning of epoch.
        metric.reset()
        btic = time.time()
        i = 0
        for batch in train_iterator:
            # Copy data to ctx if necessary
            data = batch.data[0].as_in_context(ctx)
            label = batch.label[0].as_in_context(ctx)

            # Start recording computation graph with record() section.
            # Recorded graphs can then be differentiated with backward.
            with autograd.record():
                output = net(data)
                L = loss(output, label)
                L.backward()
            # take a gradient step with batch_size equal to data.shape[0]
            trainer.step(data.shape[0])
            # update metric at last.
            metric.update([label], [output])

            if i % log_interval == 0 and i > 0:
                name, acc = metric.get()
                print('[Epoch %d Batch %d] Training: %s=%f, %f samples/s' %
                      (epoch, i, name, acc, batch_size / (time.time() - btic)))

            btic = time.time()
            i += 1

        name, acc = metric.get()
        print('[Epoch %d] Training: %s=%f' % (epoch, name, acc))

        name, val_acc = test(ctx, net, val_iterator)
        print('[Epoch %d] Validation: %s=%f' % (epoch, name, val_acc))
        train_iterator.reset()
    return net, vocab


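# The sketch below is not part of the original script; it only illustrates how
# train() could be smoke-tested outside of a SageMaker container. The directory
# layout, paths and hyperparameter values are assumptions: it expects
# <data_dir>/train and <data_dir>/test files in the format read by
# get_dataset() further down. Defined but never called, so importing this
# module is unaffected.
def _example_local_train(data_dir='/tmp/text-clf-data'):
    hyperparameters = {'batch_size': 8, 'epochs': 1, 'learning_rate': 0.01,
                       'log_interval': 100, 'embedding_size': 50}
    net, vocab = train(current_host='local', hosts=['local'], num_cpus=1, num_gpus=0,
                       channel_input_dirs={'training': data_dir},
                       model_dir='/tmp/text-clf-model',
                       hyperparameters=hyperparameters)
    return net, vocab

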
class BucketSentenceIter(DataIter):
    """Simple bucketing iterator for text classification.
    Each encoded sentence is padded to the length of its bucket and
    paired with a single class label.
    Parameters
    ----------
    sentences : list of list of int
        Encoded sentences.
    labels : list of int
        Corresponding labels.
    batch_size : int
        Batch size of the data.
    buckets : list of int, optional
        Size of the data buckets. Automatically generated if None.
    invalid_label : int, optional
        Padding value used to fill sentences shorter than their bucket,
        e.g. the <pad> id. The default is 0.
    dtype : str, optional
        Data type of the encoding. The default data type is 'float32'.
    data_name : str, optional
        Name of the data. The default name is 'data'.
    label_name : str, optional
        Name of the label. The default name is 'softmax_label'.
    layout : str, optional
        Format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, sentences, labels, batch_size, buckets=None, invalid_label=0,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NT'):
        super(BucketSentenceIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        self.labels = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)
            self.labels[buck].append(labels[i])

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]
        self.labels = [np.asarray(i, dtype=dtype) for i in self.labels]

        print("WARNING: discarded %d sentences longer than the largest bucket." % ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.layout = layout
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [DataDesc(
                name=self.data_name, shape=(batch_size, self.default_bucket_key),
                layout=self.layout)]
            self.provide_label = [DataDesc(
                name=self.label_name, shape=(batch_size,),
                layout=self.layout)]
        elif self.major_axis == 1:
            self.provide_data = [DataDesc(
                name=self.data_name, shape=(self.default_bucket_key, batch_size),
                layout=self.layout)]
            self.provide_label = [DataDesc(
                name=self.label_name, shape=(self.default_bucket_key, batch_size),
                layout=self.layout)]
        else:
            raise ValueError("Invalid layout %s: Must be NT (batch major) or TN (time major)" % layout)

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0
        self.reset()
    def reset(self):
        """Resets the iterator to the beginning of the data."""
        self.curr_idx = 0
        random.shuffle(self.idx)
        for i in range(len(self.data)):
            data, labels = self.data[i], self.labels[i]
            p = np.random.permutation(len(data))
            self.data[i], self.labels[i] = data[p], labels[p]

        self.nddata = []
        self.ndlabel = []
        for buck, label_buck in zip(self.data, self.labels):
            self.nddata.append(nd.array(buck, dtype=self.dtype))
            self.ndlabel.append(nd.array(label_buck, dtype=self.dtype))

    def next(self):
        """Returns the next batch of data."""
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label], pad=0,
                         bucket_key=self.buckets[i],
                         provide_data=[DataDesc(
                             name=self.data_name, shape=data.shape,
                             layout=self.layout)],
                         provide_label=[DataDesc(
                             name=self.label_name, shape=label.shape,
                             layout=self.layout)])


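# A minimal sketch (not from the original example) of how BucketSentenceIter
# behaves on toy data; the token ids and labels below are made up. With the
# default bucket generation, only sentence lengths that occur at least
# batch_size times become buckets, so the length-4 sentence is discarded.
# Defined but never called.
def _example_bucket_iter():
    toy_sentences = [[2, 3, 4], [5, 6, 7, 8], [9, 10, 11], [4, 5, 6]]
    toy_labels = [0, 1, 0, 1]
    it = BucketSentenceIter(toy_sentences, toy_labels, batch_size=2)
    for batch in it:
        # data is padded to the bucket length; label holds one class id per row
        print(batch.data[0].shape, batch.label[0].shape, batch.bucket_key)

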
class TextClassifier(gluon.HybridBlock):
    def __init__(self, vocab_size, embedding_size, classes, **kwargs):
        super(TextClassifier, self).__init__(**kwargs)
        with self.name_scope():
            self.dense = gluon.nn.Dense(classes)
            self.embedding = gluon.nn.Embedding(input_dim=vocab_size, output_dim=embedding_size)

    def hybrid_forward(self, F, x):
        x = self.embedding(x)
        x = F.mean(x, axis=1)
        x = self.dense(x)
        return x


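# A small sketch (not part of the original file) of what TextClassifier
# computes: token ids are embedded, the embeddings are averaged over the
# sequence axis (a bag-of-embeddings representation), and a dense layer maps
# the result to class scores. The sizes and ids below are arbitrary.
def _example_text_classifier_forward():
    net = TextClassifier(vocab_size=100, embedding_size=16, classes=2)
    net.initialize(mx.init.Xavier(magnitude=2.24))
    token_ids = mx.nd.array([[5, 7, 9, 0, 0]])  # one padded sentence of 5 positions
    scores = net(token_ids)                     # shape (1, 2): one row of class scores
    return scores

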
def get_dataset(filename):
    labels = []
    sentences = []
    max_length = -1
    with open(filename) as f:
        for line in f:
            tokens = line.split()
            label = int(tokens[0])
            words = tokens[1:]
            max_length = max(max_length, len(words))
            labels.append(label)
            sentences.append(words)
    return sentences, labels, max_length


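# Illustration only (not from the original script): get_dataset() expects one
# example per line, an integer class label first, then the whitespace-separated
# tokens of the sentence. The path and the two review lines below are made-up
# examples of that format; the function is defined but never called.
def _example_dataset_round_trip(path='/tmp/example_dataset.txt'):
    with open(path, 'w') as f:
        f.write('1 this movie was surprisingly good\n')
        f.write('0 a dull and predictable plot\n')
    sentences, labels, max_length = get_dataset(path)
    # sentences -> [['this', ...], ['a', ...]], labels -> [1, 0], max_length -> 5
    return sentences, labels, max_length

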
def create_vocab(sentences, min_count=5, num_words=100000):
    BOS_SYMBOL = "<s>"
    EOS_SYMBOL = "</s>"
    UNK_SYMBOL = "<unk>"
    PAD_SYMBOL = "<pad>"
    PAD_ID = 0
    TOKEN_SEPARATOR = " "
    VOCAB_SYMBOLS = [PAD_SYMBOL, UNK_SYMBOL, BOS_SYMBOL, EOS_SYMBOL]
    VOCAB_ENCODING = "utf-8"
    vocab_symbols_set = set(VOCAB_SYMBOLS)
    raw_vocab = Counter(token for line in sentences for token in line)
    pruned_vocab = sorted(((c, w) for w, c in raw_vocab.items() if c >= min_count), reverse=True)
    vocab = islice((w for c, w in pruned_vocab), num_words)
    word_to_id = {word: idx for idx, word in enumerate(chain(VOCAB_SYMBOLS, vocab))}
    return word_to_id


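# A toy illustration (not in the original script) of create_vocab(): ids 0-3
# are reserved for <pad>, <unk>, <s> and </s>, and only tokens seen at least
# min_count times make it into the vocabulary. The corpus below is made up,
# and the function is defined but never called.
def _example_create_vocab():
    corpus = [['good', 'movie'], ['good', 'plot'], ['good', 'fun']]
    vocab = create_vocab(corpus, min_count=2)
    # -> {'<pad>': 0, '<unk>': 1, '<s>': 2, '</s>': 3, 'good': 4}
    return vocab

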
def vocab_to_json(vocab, path):
    with open(path, "w") as out:
        json.dump(vocab, out, indent=4, ensure_ascii=False)
        print('Vocabulary saved to "%s"' % path)


def vocab_from_json(path):
    with open(path) as inp:
        vocab = json.load(inp)
        print('Vocabulary (%d words) loaded from "%s"' % (len(vocab), path))
        return vocab


def save(net, model_dir):
    # save the model
    net, vocab = net
    y = net(mx.sym.var('data'))
    y.save('%s/model.json' % model_dir)
    net.collect_params().save('%s/model.params' % model_dir)
    vocab_to_json(vocab, '%s/vocab.json' % model_dir)


def test(ctx, net, val_data):
    val_data.reset()
    metric = mx.metric.Accuracy()
    for batch in val_data:
        data = batch.data[0].as_in_context(ctx)
        label = batch.label[0].as_in_context(ctx)
        output = net(data)
        metric.update([label], [output])
    return metric.get()


# ------------------------------------------------------------ #
# Hosting methods                                               #
# ------------------------------------------------------------ #

def model_fn(model_dir):
    """
    Load the gluon model. Called once when hosting service starts.

    :param model_dir: The directory where model files are stored.
    :return: a model (in this case a Gluon network)
    """
    symbol = mx.sym.load('%s/model.json' % model_dir)
    vocab = vocab_from_json('%s/vocab.json' % model_dir)
    outputs = mx.symbol.softmax(data=symbol, name='softmax_label')
    inputs = mx.sym.var('data')
    param_dict = gluon.ParameterDict('model_')
    net = gluon.SymbolBlock(outputs, inputs, param_dict)
    net.load_params('%s/model.params' % model_dir, ctx=mx.cpu())
    return net, vocab


def transform_fn(net, data, input_content_type, output_content_type):
    """
    Transform a request using the Gluon model. Called once per request.

    :param net: The Gluon model.
    :param data: The request payload.
    :param input_content_type: The request content type.
    :param output_content_type: The (desired) response content type.
    :return: response payload and content type.
    """
    # we can use content types to vary input/output handling, but
    # here we just assume json for both
    net, vocab = net
    parsed = json.loads(data)
    outputs = []
    for row in parsed:
        tokens = [vocab.get(token, 1) for token in row.split()]
        nda = mx.nd.array([tokens])
        output = net(nda)
        prediction = mx.nd.argmax(output, axis=1)
        outputs.append(int(prediction.asscalar()))
    response_body = json.dumps(outputs)
    return response_body, output_content_type
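

# A hedged, end-to-end sketch (not part of the original file) of how the two
# hosting hooks fit together outside of a real endpoint: model_fn() loads the
# exported symbol, parameters and vocabulary from model_dir, and transform_fn()
# turns a JSON list of sentences into a JSON list of predicted class ids.
# The model_dir path and the review texts are assumptions for illustration,
# and the function is defined but never called.
def _example_local_inference(model_dir='/tmp/text-clf-model'):
    model = model_fn(model_dir)
    request = json.dumps(['this movie was surprisingly good',
                          'a dull and predictable plot'])
    response, content_type = transform_fn(model, request, 'application/json',
                                          'application/json')
    print(content_type, response)  # a JSON list with one class id per input sentence
    return response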