
Commit 7c2974e

Merge pull request aws#92 from awslabs/seq2seq
Incorporate comments from notebook bash
2 parents b7a0b6d + 81d6b97 commit 7c2974e

File tree

1 file changed: +74 additions, -59 deletions


introduction_to_amazon_algorithms/seq2seq_translation_en-de/SageMaker-Seq2Seq-Translation-English-German.ipynb

Lines changed: 74 additions & 59 deletions
@@ -133,6 +133,25 @@
 "Please note that it is a common practise to split words into subwords using Byte Pair Encoding (BPE). Please refer to [this](https://github.com/awslabs/sockeye/tree/master/tutorials/wmt) tutorial if you are interested in performing BPE."
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Since training on the whole dataset might take several hours/days, for this demo, let us train on the **first 10,000 lines only**. Don't run the next cell if you want to train on the complete dataset."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": true
+},
+"outputs": [],
+"source": [
+"!head -n 10000 corpus.tc.en > corpus.tc.en.small\n",
+"!head -n 10000 corpus.tc.de > corpus.tc.de.small"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
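The new cell shells out to `head` to keep the first 10,000 lines of each corpus. For reference, an equivalent pure-Python sketch (file names taken from the notebook; the helper name is illustrative):

```python
from itertools import islice

def truncate_corpus(src_path, dst_path, n_lines=10000):
    """Copy the first n_lines of src_path to dst_path (like `head -n`)."""
    with open(src_path, encoding="utf-8") as src, \
         open(dst_path, "w", encoding="utf-8") as dst:
        dst.writelines(islice(src, n_lines))

# Mirrors the notebook's shell commands:
# truncate_corpus("corpus.tc.en", "corpus.tc.en.small")
# truncate_corpus("corpus.tc.de", "corpus.tc.de.small")
```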
@@ -155,22 +174,20 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Let's do the preprocessing now. Sit back and relax as the script below might take around 10-15 min"
+"The cell below does the preprocessing. If you are using the complete dataset, the script might take around 10-15 min on an m4.xlarge notebook instance. Remove \".small\" from the file names for training on full datasets."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "%%time\n",
 "%%bash\n",
 "python3 create_vocab_proto.py \\\n",
-" --train-source corpus.tc.en \\\n",
-" --train-target corpus.tc.de \\\n",
+" --train-source corpus.tc.en.small \\\n",
+" --train-target corpus.tc.de.small \\\n",
 " --val-source validation/newstest2014.tc.en \\\n",
 " --val-target validation/newstest2014.tc.de"
 ]
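Since the demo and full-dataset runs differ only by the `.small` suffix, a small helper can build the `create_vocab_proto.py` argument list for either case (the helper name and its flag are illustrative, not part of the notebook):

```python
def vocab_proto_args(use_small=True):
    """Build the create_vocab_proto.py command line; pass use_small=False
    to drop the .small suffix and preprocess the full dataset.
    (Illustrative helper, not part of the notebook.)"""
    suffix = ".small" if use_small else ""
    return [
        "python3", "create_vocab_proto.py",
        "--train-source", "corpus.tc.en" + suffix,
        "--train-target", "corpus.tc.de" + suffix,
        "--val-source", "validation/newstest2014.tc.en",
        "--val-target", "validation/newstest2014.tc.de",
    ]
```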
@@ -222,9 +239,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "containers = {'us-west-2': '433757028032.dkr.ecr.us-west-2.amazonaws.com/seq2seq:latest',\n",
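Only the `us-west-2` image appears in this hunk; the other region entries are elided. A defensive lookup sketch, so an unsupported region fails with a clear message rather than a bare `KeyError` (the helper function is illustrative):

```python
# Only the us-west-2 entry appears in the diff; other regions are elided.
containers = {
    'us-west-2': '433757028032.dkr.ecr.us-west-2.amazonaws.com/seq2seq:latest',
}

def seq2seq_container(region):
    """Resolve the seq2seq training image for a region (illustrative helper)."""
    try:
        return containers[region]
    except KeyError:
        raise ValueError("No seq2seq container registered for region %r" % region)
```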
@@ -245,12 +260,10 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
-"job_name = 'seq2seq-en-de-small-p2-16x-' + strftime(\"%Y-%m-%d-%H\", gmtime())\n",
+"job_name = 'seq2seq-en-de-p2-xlarge-' + strftime(\"%Y-%m-%d-%H\", gmtime())\n",
 "print(\"Training job\", job_name)\n",
 "\n",
 "create_training_params = \\\n",
@@ -266,7 +279,7 @@
 " \"ResourceConfig\": {\n",
 " # Seq2Seq does not support multiple machines. Currently, it only supports single machine, multiple GPUs\n",
 " \"InstanceCount\": 1,\n",
-" \"InstanceType\": \"ml.p2.16xlarge\", # We suggest one of [\"ml.p2.16xlarge\", \"ml.p2.8xlarge\", \"ml.p2.xlarge\"]\n",
+" \"InstanceType\": \"ml.p2.xlarge\", # We suggest one of [\"ml.p2.16xlarge\", \"ml.p2.8xlarge\", \"ml.p2.xlarge\"]\n",
 " \"VolumeSizeInGB\": 50\n",
 " },\n",
 " \"TrainingJobName\": job_name,\n",
@@ -275,14 +288,17 @@
 " \"max_seq_len_source\": \"60\",\n",
 " \"max_seq_len_target\": \"60\",\n",
 " \"optimized_metric\": \"bleu\",\n",
-" \"batch_size\": \"256\",\n",
+" \"batch_size\": \"64\", # Please use a larger batch size (256 or 512) if using ml.p2.8xlarge or ml.p2.16xlarge\n",
 " \"checkpoint_frequency_num_batches\": \"1000\",\n",
 " \"rnn_num_hidden\": \"512\",\n",
 " \"num_layers_encoder\": \"1\",\n",
 " \"num_layers_decoder\": \"1\",\n",
 " \"num_embed_source\": \"512\",\n",
 " \"num_embed_target\": \"512\",\n",
 " \"checkpoint_threshold\": \"3\",\n",
+" \"max_num_batches\": \"2100\"\n",
+" # Training will stop after 2100 iterations/batches.\n",
+" # This is just for demo purposes. Remove the above parameter if you want a better model.\n",
 " },\n",
 " \"StoppingCondition\": {\n",
 " \"MaxRuntimeInSeconds\": 48 * 3600\n",
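The new `batch_size` comment ties the batch size to the instance type. That suggestion can be captured as a small lookup table (values taken from the diff's comment; the table and helper themselves are hypothetical):

```python
# Batch sizes from the diff's comment: 64 on ml.p2.xlarge, 256 or 512 on the
# larger multi-GPU instances. This table and helper are hypothetical.
SUGGESTED_BATCH_SIZE = {
    "ml.p2.xlarge": 64,
    "ml.p2.8xlarge": 256,
    "ml.p2.16xlarge": 512,
}

def suggested_batch_size(instance_type):
    """Fall back to the conservative single-GPU batch size when unsure."""
    return SUGGESTED_BATCH_SIZE.get(instance_type, 64)
```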
@@ -331,9 +347,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "status = sagemaker_client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n",
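The status cell reads `TrainingJobStatus` once; in practice the job is polled until it leaves `InProgress`. A sketch of such a loop, with the status callable injected so the example runs without AWS credentials:

```python
import time

def wait_for_training_job(get_status, poll_seconds=60, sleep=time.sleep):
    """Poll until the job leaves 'InProgress'. get_status is any callable
    returning the TrainingJobStatus string, e.g.
      lambda: sagemaker_client.describe_training_job(
          TrainingJobName=job_name)['TrainingJobStatus']
    (injected here so the sketch is testable offline)."""
    status = get_status()
    while status == 'InProgress':
        sleep(poll_seconds)
        status = get_status()
    return status
```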
@@ -352,6 +366,13 @@
 "> Now wait for the training job to complete and proceed to the next step after you see model artifacts in your S3 bucket."
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"You can jump to [Use a pretrained model](#Use-a-pretrained-model) as training might take some time."
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -366,9 +387,7 @@
 "- Perform Inference - Perform inference on some input data using the endpoint.\n",
 "\n",
 "### Create model\n",
-"We now create a SageMaker Model from the training output. Using the model, we can then create an Endpoint Configuration.\n",
-"\n",
-"#### Note: Please uncomment and run the lines below if you want to use a pretrained model, as training might take several hours/days to complete."
+"We now create a SageMaker Model from the training output. Using the model, we can then create an Endpoint Configuration."
 ]
 },
 {
@@ -379,37 +398,48 @@
 },
 "outputs": [],
 "source": [
+"use_pretrained_model = False"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Use a pretrained model\n",
+"#### Please uncomment and run the cell below if you want to use a pretrained model, as training might take several hours/days to complete."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# use_pretrained_model = True\n",
+"# model_name = \"pretrained-en-de-model\"\n",
 "# !curl https://s3-us-west-2.amazonaws.com/gsaur-seq2seq-data/seq2seq/eng-german/full-nb-translation-eng-german-p2-16x-2017-11-24-22-25-53/output/model.tar.gz > model.tar.gz\n",
 "# !curl https://s3-us-west-2.amazonaws.com/gsaur-seq2seq-data/seq2seq/eng-german/full-nb-translation-eng-german-p2-16x-2017-11-24-22-25-53/output/vocab.src.json > vocab.src.json\n",
 "# !curl https://s3-us-west-2.amazonaws.com/gsaur-seq2seq-data/seq2seq/eng-german/full-nb-translation-eng-german-p2-16x-2017-11-24-22-25-53/output/vocab.trg.json > vocab.trg.json\n",
 "# upload_to_s3(bucket, prefix, 'pretrained_model', 'model.tar.gz')\n",
-"# use_pretrained_model = True\n",
 "# model_data = \"s3://{}/{}/pretrained_model/model.tar.gz\".format(bucket, prefix)"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "%%time\n",
 "\n",
 "sage = boto3.client('sagemaker')\n",
 "\n",
-"model_name=job_name\n",
-"print(model_name)\n",
-"\n",
-"info = sage.describe_training_job(TrainingJobName=job_name)\n",
-"\n",
-"try:\n",
-" if use_pretrained_model:\n",
-" model_data\n",
-"except:\n",
+"if not use_pretrained_model:\n",
+" info = sage.describe_training_job(TrainingJobName=job_name)\n",
+" model_name=job_name\n",
 " model_data = info['ModelArtifacts']['S3ModelArtifacts']\n",
-" \n",
+"\n",
+"print(model_name)\n",
 "print(model_data)\n",
 "\n",
 "primary_container = {\n",
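The rewritten cell replaces the old `try`/`except` trick with an explicit `if not use_pretrained_model` branch. The same logic, factored into a function with `describe_training_job` injected so it can be exercised offline (the function name and parameters are illustrative):

```python
def resolve_model_artifacts(use_pretrained_model, job_name, describe_training_job,
                            pretrained_data=None,
                            pretrained_name="pretrained-en-de-model"):
    """Return (model_name, model_data), mirroring the rewritten branch.
    describe_training_job is injected (e.g. sage.describe_training_job)
    so the sketch is testable without AWS; names here are illustrative."""
    if use_pretrained_model:
        return pretrained_name, pretrained_data
    info = describe_training_job(TrainingJobName=job_name)
    return job_name, info['ModelArtifacts']['S3ModelArtifacts']
```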
@@ -438,9 +468,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "from time import gmtime, strftime\n",
@@ -469,9 +497,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "%%time\n",
@@ -547,13 +573,12 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "sentences = [\"you are so good !\",\n",
 " \"can you drive a car ?\",\n",
+" \"i want to watch a movie .\"\n",
 " ]\n",
 "\n",
 "payload = {\"instances\" : []}\n",
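The cell initializes `payload = {"instances": []}` and (per the surrounding notebook) fills it with one record per sentence. A sketch of the assembled payload, assuming the `{"data": ...}` record shape used by the algorithm's JSON inference interface (the shape is not shown in this hunk):

```python
def build_payload(sentences):
    """Assemble the endpoint request body: one {"data": sentence} record
    per source sentence (record shape assumed, not shown in this hunk)."""
    return {"instances": [{"data": s} for s in sentences]}

payload = build_payload(["you are so good !",
                         "can you drive a car ?",
                         "i want to watch a movie ."])
```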
@@ -586,9 +611,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "sentence = 'can you drive a car ?'\n",
@@ -639,9 +662,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "plot_matrix(attention_matrix, target, source)"
@@ -666,9 +687,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "import io\n",
@@ -772,9 +791,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "targets = _parse_proto_response(response)\n",
@@ -795,9 +812,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "# sage.delete_endpoint(EndpointName=endpoint_name)"
