|
133 | 133 | "Please note that it is common practice to split words into subwords using Byte Pair Encoding (BPE). Refer to [this](https://github.com/awslabs/sockeye/tree/master/tutorials/wmt) tutorial if you are interested in performing BPE."
|
134 | 134 | ]
|
135 | 135 | },
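To make the BPE idea above concrete, here is a toy sketch of the merge-learning loop on the classic four-word example from the BPE literature. This is illustrative only; `learn_bpe` and `merge_word` are hypothetical helpers, and real pipelines use the subword-nmt or Sockeye tooling linked above.

```python
from collections import Counter

def merge_word(symbols, pair):
    # Merge every adjacent occurrence of `pair` into a single symbol.
    out, i = [], 0
    while i < len(symbols):
        if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == pair:
            out.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return out

def learn_bpe(word_freqs, num_merges):
    # word_freqs: {tuple of symbols (characters + '</w>'): corpus frequency}
    vocab = dict(word_freqs)
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for word, freq in vocab.items():
            for a, b in zip(word, word[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent adjacent pair
        merges.append(best)
        vocab = {tuple(merge_word(list(w), best)): f for w, f in vocab.items()}
    return merges

corpus = {('l', 'o', 'w', '</w>'): 5,
          ('l', 'o', 'w', 'e', 'r', '</w>'): 2,
          ('n', 'e', 'w', 'e', 's', 't', '</w>'): 6,
          ('w', 'i', 'd', 'e', 's', 't', '</w>'): 3}
merges = learn_bpe(corpus, 3)
```

On this corpus the first learned merges build up the frequent `est</w>` suffix, which is exactly how BPE discovers reusable subwords.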
|
| 136 | + { |
| 137 | + "cell_type": "markdown", |
| 138 | + "metadata": {}, |
| 139 | + "source": [ |
| 140 | + "Since training on the whole dataset might take several hours or days, let us train on the **first 10,000 lines only** for this demo. Don't run the next cell if you want to train on the complete dataset." |
| 141 | + ] |
| 142 | + }, |
| 143 | + { |
| 144 | + "cell_type": "code", |
| 145 | + "execution_count": null, |
| 146 | + "metadata": { |
| 147 | + "collapsed": true |
| 148 | + }, |
| 149 | + "outputs": [], |
| 150 | + "source": [ |
| 151 | + "!head -n 10000 corpus.tc.en > corpus.tc.en.small\n", |
| 152 | + "!head -n 10000 corpus.tc.de > corpus.tc.de.small" |
| 153 | + ] |
| 154 | + }, |
136 | 155 | {
|
137 | 156 | "cell_type": "markdown",
|
138 | 157 | "metadata": {},
|
|
155 | 174 | "cell_type": "markdown",
|
156 | 175 | "metadata": {},
|
157 | 176 | "source": [
|
158 |
| - "Let's do the preprocessing now. Sit back and relax as the script below might take around 10-15 min" |
| 177 | + "The cell below does the preprocessing. If you are using the complete dataset, the script might take around 10 to 15 minutes on an ml.m4.xlarge notebook instance. Remove \".small\" from the file names to train on the full dataset." |
159 | 178 | ]
|
160 | 179 | },
|
161 | 180 | {
|
162 | 181 | "cell_type": "code",
|
163 | 182 | "execution_count": null,
|
164 |
| - "metadata": { |
165 |
| - "collapsed": true |
166 |
| - }, |
| 183 | + "metadata": {}, |
167 | 184 | "outputs": [],
|
168 | 185 | "source": [
|
169 | 186 | "%%time\n",
|
170 | 187 | "%%bash\n",
|
171 | 188 | "python3 create_vocab_proto.py \\\n",
|
172 |
| - " --train-source corpus.tc.en \\\n", |
173 |
| - " --train-target corpus.tc.de \\\n", |
| 189 | + " --train-source corpus.tc.en.small \\\n", |
| 190 | + " --train-target corpus.tc.de.small \\\n", |
174 | 191 | " --val-source validation/newstest2014.tc.en \\\n",
|
175 | 192 | " --val-target validation/newstest2014.tc.de"
|
176 | 193 | ]
|
|
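As a rough illustration of what the `vocab.src.json`/`vocab.trg.json` files produced above contain (a token-to-integer-id mapping), here is a hedged sketch. The actual special tokens, ordering, and protobuf output of `create_vocab_proto.py` are not reproduced here, and `build_vocab` is a hypothetical helper.

```python
from collections import Counter

def build_vocab(lines, num_words=50000):
    # Count whitespace tokens, then assign ids after reserved tokens.
    counts = Counter(tok for line in lines for tok in line.split())
    vocab = {'<pad>': 0, '<unk>': 1, '<s>': 2, '</s>': 3}
    for tok, _ in counts.most_common(num_words):
        vocab.setdefault(tok, len(vocab))
    return vocab

lines = ['the cat sat .', 'the dog sat .']
vocab = build_vocab(lines)
```

Frequent tokens get small ids, and anything outside the top `num_words` is mapped to `<unk>` at lookup time.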
222 | 239 | {
|
223 | 240 | "cell_type": "code",
|
224 | 241 | "execution_count": null,
|
225 |
| - "metadata": { |
226 |
| - "collapsed": true |
227 |
| - }, |
| 242 | + "metadata": {}, |
228 | 243 | "outputs": [],
|
229 | 244 | "source": [
|
230 | 245 | "containers = {'us-west-2': '433757028032.dkr.ecr.us-west-2.amazonaws.com/seq2seq:latest',\n",
|
|
245 | 260 | {
|
246 | 261 | "cell_type": "code",
|
247 | 262 | "execution_count": null,
|
248 |
| - "metadata": { |
249 |
| - "collapsed": true |
250 |
| - }, |
| 263 | + "metadata": {}, |
251 | 264 | "outputs": [],
|
252 | 265 | "source": [
|
253 |
| - "job_name = 'seq2seq-en-de-small-p2-16x-' + strftime(\"%Y-%m-%d-%H\", gmtime())\n", |
| 266 | + "job_name = 'seq2seq-en-de-p2-xlarge-' + strftime(\"%Y-%m-%d-%H\", gmtime())\n", |
254 | 267 | "print(\"Training job\", job_name)\n",
|
255 | 268 | "\n",
|
256 | 269 | "create_training_params = \\\n",
|
|
266 | 279 | " \"ResourceConfig\": {\n",
|
267 | 280 | " # Seq2Seq does not support multiple machines. Currently, it only supports single machine, multiple GPUs\n",
|
268 | 281 | " \"InstanceCount\": 1,\n",
|
269 |
| - " \"InstanceType\": \"ml.p2.16xlarge\", # We suggest one of [\"ml.p2.16xlarge\", \"ml.p2.8xlarge\", \"ml.p2.xlarge\"]\n", |
| 282 | + " \"InstanceType\": \"ml.p2.xlarge\", # We suggest one of [\"ml.p2.16xlarge\", \"ml.p2.8xlarge\", \"ml.p2.xlarge\"]\n", |
270 | 283 | " \"VolumeSizeInGB\": 50\n",
|
271 | 284 | " },\n",
|
272 | 285 | " \"TrainingJobName\": job_name,\n",
|
|
275 | 288 | " \"max_seq_len_source\": \"60\",\n",
|
276 | 289 | " \"max_seq_len_target\": \"60\",\n",
|
277 | 290 | " \"optimized_metric\": \"bleu\",\n",
|
278 |
| - " \"batch_size\": \"256\",\n", |
| 291 | + " \"batch_size\": \"64\", # Please use a larger batch size (256 or 512) if using ml.p2.8xlarge or ml.p2.16xlarge\n", |
279 | 292 | " \"checkpoint_frequency_num_batches\": \"1000\",\n",
|
280 | 293 | " \"rnn_num_hidden\": \"512\",\n",
|
281 | 294 | " \"num_layers_encoder\": \"1\",\n",
|
282 | 295 | " \"num_layers_decoder\": \"1\",\n",
|
283 | 296 | " \"num_embed_source\": \"512\",\n",
|
284 | 297 | " \"num_embed_target\": \"512\",\n",
|
285 | 298 | " \"checkpoint_threshold\": \"3\",\n",
|
| 299 | + " \"max_num_batches\": \"2100\"\n", |
| 300 | + " # Training will stop after 2100 iterations/batches.\n", |
| 301 | + " # This is just for demo purposes. Remove the above parameter if you want a better model.\n", |
286 | 302 | " },\n",
|
287 | 303 | " \"StoppingCondition\": {\n",
|
288 | 304 | " \"MaxRuntimeInSeconds\": 48 * 3600\n",
|
|
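To see what the demo-sized hyperparameters above imply, a quick back-of-envelope check, assuming the 10,000-line subset and the `batch_size` of 64 set above:

```python
# Assumes the 10,000-line demo subset and batch_size 64 from the cell above.
num_examples = 10_000
batch_size = 64
batches_per_epoch = num_examples // batch_size
# max_num_batches of 2100 then corresponds to roughly this many epochs:
epochs_covered = 2100 / batches_per_epoch
print(batches_per_epoch, round(epochs_covered, 1))
```

So the 2100-batch cap gives the demo model roughly 13 passes over the small subset, enough for a quick sanity check but well short of a converged translation model.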
331 | 347 | {
|
332 | 348 | "cell_type": "code",
|
333 | 349 | "execution_count": null,
|
334 |
| - "metadata": { |
335 |
| - "collapsed": true |
336 |
| - }, |
| 350 | + "metadata": {}, |
337 | 351 | "outputs": [],
|
338 | 352 | "source": [
|
339 | 353 | "status = sagemaker_client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n",
|
|
352 | 366 | "> Now wait for the training job to complete and proceed to the next step after you see model artifacts in your S3 bucket."
|
353 | 367 | ]
|
354 | 368 | },
|
| 369 | + { |
| 370 | + "cell_type": "markdown", |
| 371 | + "metadata": {}, |
| 372 | + "source": [ |
| 373 | + "You can jump to [Use a pretrained model](#Use-a-pretrained-model) as training might take some time." |
| 374 | + ] |
| 375 | + }, |
355 | 376 | {
|
356 | 377 | "cell_type": "markdown",
|
357 | 378 | "metadata": {},
|
|
366 | 387 | "- Perform Inference - Perform inference on some input data using the endpoint.\n",
|
367 | 388 | "\n",
|
368 | 389 | "### Create model\n",
|
369 |
| - "We now create a SageMaker Model from the training output. Using the model, we can then create an Endpoint Configuration.\n", |
370 |
| - "\n", |
371 |
| - "#### Note: Please uncomment and run the lines below if you want to use a pretrained model, as training might take several hours/days to complete." |
| 390 | + "We now create a SageMaker Model from the training output. Using the model, we can then create an Endpoint Configuration." |
372 | 391 | ]
|
373 | 392 | },
|
374 | 393 | {
|
|
379 | 398 | },
|
380 | 399 | "outputs": [],
|
381 | 400 | "source": [
|
| 401 | + "use_pretrained_model = False" |
| 402 | + ] |
| 403 | + }, |
| 404 | + { |
| 405 | + "cell_type": "markdown", |
| 406 | + "metadata": {}, |
| 407 | + "source": [ |
| 408 | + "### Use a pretrained model\n", |
| 409 | + "#### Please uncomment and run the cell below if you want to use a pretrained model, as training might take several hours/days to complete." |
| 410 | + ] |
| 411 | + }, |
| 412 | + { |
| 413 | + "cell_type": "code", |
| 414 | + "execution_count": null, |
| 415 | + "metadata": {}, |
| 416 | + "outputs": [], |
| 417 | + "source": [ |
| 418 | + "# use_pretrained_model = True\n", |
| 419 | + "# model_name = \"pretrained-en-de-model\"\n", |
382 | 420 | "# !curl https://s3-us-west-2.amazonaws.com/gsaur-seq2seq-data/seq2seq/eng-german/full-nb-translation-eng-german-p2-16x-2017-11-24-22-25-53/output/model.tar.gz > model.tar.gz\n",
|
383 | 421 | "# !curl https://s3-us-west-2.amazonaws.com/gsaur-seq2seq-data/seq2seq/eng-german/full-nb-translation-eng-german-p2-16x-2017-11-24-22-25-53/output/vocab.src.json > vocab.src.json\n",
|
384 | 422 | "# !curl https://s3-us-west-2.amazonaws.com/gsaur-seq2seq-data/seq2seq/eng-german/full-nb-translation-eng-german-p2-16x-2017-11-24-22-25-53/output/vocab.trg.json > vocab.trg.json\n",
|
385 | 423 | "# upload_to_s3(bucket, prefix, 'pretrained_model', 'model.tar.gz')\n",
|
386 |
| - "# use_pretrained_model = True\n", |
387 | 424 | "# model_data = \"s3://{}/{}/pretrained_model/model.tar.gz\".format(bucket, prefix)"
|
388 | 425 | ]
|
389 | 426 | },
|
390 | 427 | {
|
391 | 428 | "cell_type": "code",
|
392 | 429 | "execution_count": null,
|
393 |
| - "metadata": { |
394 |
| - "collapsed": true |
395 |
| - }, |
| 430 | + "metadata": {}, |
396 | 431 | "outputs": [],
|
397 | 432 | "source": [
|
398 | 433 | "%%time\n",
|
399 | 434 | "\n",
|
400 | 435 | "sage = boto3.client('sagemaker')\n",
|
401 | 436 | "\n",
|
402 |
| - "model_name=job_name\n", |
403 |
| - "print(model_name)\n", |
404 |
| - "\n", |
405 |
| - "info = sage.describe_training_job(TrainingJobName=job_name)\n", |
406 |
| - "\n", |
407 |
| - "try:\n", |
408 |
| - " if use_pretrained_model:\n", |
409 |
| - " model_data\n", |
410 |
| - "except:\n", |
| 437 | + "if not use_pretrained_model:\n", |
| 438 | + " info = sage.describe_training_job(TrainingJobName=job_name)\n", |
| 439 | + " model_name = job_name\n", |
411 | 440 | " model_data = info['ModelArtifacts']['S3ModelArtifacts']\n",
|
412 |
| - " \n", |
| 441 | + "\n", |
| 442 | + "print(model_name)\n", |
413 | 443 | "print(model_data)\n",
|
414 | 444 | "\n",
|
415 | 445 | "primary_container = {\n",
|
|
438 | 468 | {
|
439 | 469 | "cell_type": "code",
|
440 | 470 | "execution_count": null,
|
441 |
| - "metadata": { |
442 |
| - "collapsed": true |
443 |
| - }, |
| 471 | + "metadata": {}, |
444 | 472 | "outputs": [],
|
445 | 473 | "source": [
|
446 | 474 | "from time import gmtime, strftime\n",
|
|
469 | 497 | {
|
470 | 498 | "cell_type": "code",
|
471 | 499 | "execution_count": null,
|
472 |
| - "metadata": { |
473 |
| - "collapsed": true |
474 |
| - }, |
| 500 | + "metadata": {}, |
475 | 501 | "outputs": [],
|
476 | 502 | "source": [
|
477 | 503 | "%%time\n",
|
|
547 | 573 | {
|
548 | 574 | "cell_type": "code",
|
549 | 575 | "execution_count": null,
|
550 |
| - "metadata": { |
551 |
| - "collapsed": true |
552 |
| - }, |
| 576 | + "metadata": {}, |
553 | 577 | "outputs": [],
|
554 | 578 | "source": [
|
555 | 579 | "sentences = [\"you are so good !\",\n",
|
556 | 580 | " \"can you drive a car ?\",\n",
|
| 581 | + " \"i want to watch a movie .\"\n", |
557 | 582 | " ]\n",
|
558 | 583 | "\n",
|
559 | 584 | "payload = {\"instances\" : []}\n",
|
|
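The `payload` dict started above gets one entry per sentence; a minimal sketch of the finished request body, assuming the `{"instances": [{"data": ...}]}` JSON shape used in this notebook (the endpoint invocation itself is omitted):

```python
import json

sentences = ["you are so good !",
             "can you drive a car ?"]

# Build the request body in the shape the cell above starts constructing.
payload = {"instances": [{"data": sent} for sent in sentences]}
body = json.dumps(payload)
```

This serialized `body` is what gets passed to the runtime `invoke_endpoint` call later in the notebook.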
586 | 611 | {
|
587 | 612 | "cell_type": "code",
|
588 | 613 | "execution_count": null,
|
589 |
| - "metadata": { |
590 |
| - "collapsed": true |
591 |
| - }, |
| 614 | + "metadata": {}, |
592 | 615 | "outputs": [],
|
593 | 616 | "source": [
|
594 | 617 | "sentence = 'can you drive a car ?'\n",
|
|
639 | 662 | {
|
640 | 663 | "cell_type": "code",
|
641 | 664 | "execution_count": null,
|
642 |
| - "metadata": { |
643 |
| - "collapsed": true |
644 |
| - }, |
| 665 | + "metadata": {}, |
645 | 666 | "outputs": [],
|
646 | 667 | "source": [
|
647 | 668 | "plot_matrix(attention_matrix, target, source)"
|
|
666 | 687 | {
|
667 | 688 | "cell_type": "code",
|
668 | 689 | "execution_count": null,
|
669 |
| - "metadata": { |
670 |
| - "collapsed": true |
671 |
| - }, |
| 690 | + "metadata": {}, |
672 | 691 | "outputs": [],
|
673 | 692 | "source": [
|
674 | 693 | "import io\n",
|
|
772 | 791 | {
|
773 | 792 | "cell_type": "code",
|
774 | 793 | "execution_count": null,
|
775 |
| - "metadata": { |
776 |
| - "collapsed": true |
777 |
| - }, |
| 794 | + "metadata": {}, |
778 | 795 | "outputs": [],
|
779 | 796 | "source": [
|
780 | 797 | "targets = _parse_proto_response(response)\n",
|
|
795 | 812 | {
|
796 | 813 | "cell_type": "code",
|
797 | 814 | "execution_count": null,
|
798 |
| - "metadata": { |
799 |
| - "collapsed": true |
800 |
| - }, |
| 815 | + "metadata": {}, |
801 | 816 | "outputs": [],
|
802 | 817 | "source": [
|
803 | 818 | "# sage.delete_endpoint(EndpointName=endpoint_name)"
|
|