Skip to content

Commit db3e8b5

Browse files
authored
Merge pull request #402 from Aloha106/master
Add batch transform to image-classification notebook
2 parents b342e65 + 25ce707 commit db3e8b5

File tree

1 file changed

+161
-22
lines changed

1 file changed

+161
-22
lines changed

introduction_to_amazon_algorithms/imageclassification_caltech/Image-classification-lst-format.ipynb

Lines changed: 161 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"cell_type": "markdown",
55
"metadata": {},
66
"source": [
7-
"# Image classification training with image format demo\n",
7+
"# Image classification training with image format\n",
88
"\n",
99
"1. [Introduction](#Introduction)\n",
1010
"2. [Prerequisites and Preprocessing](#Prequisites-and-Preprocessing)\n",
@@ -13,11 +13,14 @@
1313
"3. [Fine-tuning The Image Classification Model](#Fine-tuning-the-Image-classification-model)\n",
1414
" 1. [Training parameters](#Training-parameters)\n",
1515
" 2. [Training](#Training)\n",
16-
"4. [Set Up Hosting For The Model](#Set-up-hosting-for-the-model)\n",
16+
"4. [Deploy The Model](#Deploy-the-model)\n",
1717
" 1. [Create model](#Create-model)\n",
18-
" 2. [Create endpoint configuration](#Create-endpoint-configuration)\n",
19-
" 3. [Create endpoint](#Create-endpoint)\n",
20-
" 4. [Perform inference](#Perform-inference)"
18+
" 2. [Batch transform](#Batch-transform)\n",
19+
" 3. [Realtime inference](#Realtime-inference)\n",
20+
" 1. [Create endpoint configuration](#Create-endpoint-configuration) \n",
21+
" 2. [Create endpoint](#Create-endpoint) \n",
22+
" 3. [Perform inference](#Perform-inference) \n",
23+
" 4. [Clean up](#Clean-up)"
2124
]
2225
},
2326
{
@@ -163,10 +166,10 @@
163166
"outputs": [],
164167
"source": [
165168
"# Four channels: train, validation, train_lst, and validation_lst\n",
166-
"s3train = 's3://{}/train/'.format(bucket)\n",
167-
"s3validation = 's3://{}/validation/'.format(bucket)\n",
168-
"s3train_lst = 's3://{}/train_lst/'.format(bucket)\n",
169-
"s3validation_lst = 's3://{}/validation_lst/'.format(bucket)\n",
169+
"s3train = 's3://{}/image-classification/train/'.format(bucket)\n",
170+
"s3validation = 's3://{}/image-classification/validation/'.format(bucket)\n",
171+
"s3train_lst = 's3://{}/image-classification/train_lst/'.format(bucket)\n",
172+
"s3validation_lst = 's3://{}/image-classification/validation_lst/'.format(bucket)\n",
170173
"\n",
171174
"# upload the image files to train and validation channels\n",
172175
"!aws s3 cp caltech_256_train_60 $s3train --recursive --quiet\n",
@@ -343,7 +346,7 @@
343346
" \"DataSource\": {\n",
344347
" \"S3DataSource\": {\n",
345348
" \"S3DataType\": \"S3Prefix\",\n",
346-
" \"S3Uri\": 's3://{}/train/'.format(bucket),\n",
349+
" \"S3Uri\": s3train,\n",
347350
" \"S3DataDistributionType\": \"FullyReplicated\"\n",
348351
" }\n",
349352
" },\n",
@@ -355,7 +358,7 @@
355358
" \"DataSource\": {\n",
356359
" \"S3DataSource\": {\n",
357360
" \"S3DataType\": \"S3Prefix\",\n",
358-
" \"S3Uri\": 's3://{}/validation/'.format(bucket),\n",
361+
" \"S3Uri\": s3validation,\n",
359362
" \"S3DataDistributionType\": \"FullyReplicated\"\n",
360363
" }\n",
361364
" },\n",
@@ -367,7 +370,7 @@
367370
" \"DataSource\": {\n",
368371
" \"S3DataSource\": {\n",
369372
" \"S3DataType\": \"S3Prefix\",\n",
370-
" \"S3Uri\": 's3://{}/train_lst/'.format(bucket),\n",
373+
" \"S3Uri\": s3train_lst,\n",
371374
" \"S3DataDistributionType\": \"FullyReplicated\"\n",
372375
" }\n",
373376
" },\n",
@@ -379,7 +382,7 @@
379382
" \"DataSource\": {\n",
380383
" \"S3DataSource\": {\n",
381384
" \"S3DataType\": \"S3Prefix\",\n",
382-
" \"S3Uri\": 's3://{}/validation_lst/'.format(bucket),\n",
385+
" \"S3Uri\": s3validation_lst,\n",
383386
" \"S3DataDistributionType\": \"FullyReplicated\"\n",
384387
" }\n",
385388
" },\n",
@@ -452,16 +455,15 @@
452455
"cell_type": "markdown",
453456
"metadata": {},
454457
"source": [
455-
"## Set Up Hosting For The Model\n",
458+
"## Deploy The Model\n",
456459
"\n",
457460
"A trained model does nothing on its own. We now want to use the model to perform inference. For this example, that means predicting the class label given an input image.\n",
458461
"\n",
459462
"This section involves several steps,\n",
460463
"\n",
461464
"1. [Create model](#CreateModel) - Create model for the training output\n",
462-
"1. [Create endpoint configuration](#CreateEndpointConfiguration) - Create a configuration defining an endpoint.\n",
463-
"1. [Create endpoint](#CreateEndpoint) - Use the configuration to create an inference endpoint.\n",
464-
"1. [Perform inference](#Perform Inference) - Perform inference on some input data using the endpoint."
465+
"1. [Batch Transform](#BatchTransform) - Create a transform job to perform batch inference.\n",
466+
"1. [Host the model for realtime inference](#HostTheModel) - Create an inference endpoint and perform realtime inference."
465467
]
466468
},
467469
{
@@ -513,7 +515,144 @@
513515
"cell_type": "markdown",
514516
"metadata": {},
515517
"source": [
516-
"### Create endpoint configuration\n",
518+
"### Batch transform\n",
519+
"\n",
520+
"We now create a SageMaker Batch Transform job using the model created above to perform batch prediction."
521+
]
522+
},
523+
{
524+
"cell_type": "code",
525+
"execution_count": null,
526+
"metadata": {
527+
"collapsed": true
528+
},
529+
"outputs": [],
530+
"source": [
531+
"timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
532+
"batch_job_name=\"image-classification-model\" + timestamp\n",
533+
"batch_input = s3validation + \"001.ak47/\"\n",
534+
"request = \\\n",
535+
"{\n",
536+
" \"TransformJobName\": batch_job_name,\n",
537+
" \"ModelName\": model_name,\n",
538+
" \"MaxConcurrentTransforms\": 16,\n",
539+
" \"MaxPayloadInMB\": 6,\n",
540+
" \"BatchStrategy\": \"SingleRecord\",\n",
541+
" \"TransformOutput\": {\n",
542+
" \"S3OutputPath\": 's3://{}/{}/output'.format(bucket, batch_job_name)\n",
543+
" },\n",
544+
" \"TransformInput\": {\n",
545+
" \"DataSource\": {\n",
546+
" \"S3DataSource\": {\n",
547+
" \"S3DataType\": \"S3Prefix\",\n",
548+
" \"S3Uri\": batch_input\n",
549+
" }\n",
550+
" },\n",
551+
" \"ContentType\": \"application/x-image\",\n",
552+
" \"SplitType\": \"None\",\n",
553+
" \"CompressionType\": \"None\"\n",
554+
" },\n",
555+
" \"TransformResources\": {\n",
556+
" \"InstanceType\": \"ml.p2.xlarge\",\n",
557+
" \"InstanceCount\": 1\n",
558+
" }\n",
559+
"}\n",
560+
"\n",
561+
"print('Transform job name: {}'.format(batch_job_name))\n",
562+
"print('\\nInput Data Location: {}'.format(batch_input))"
563+
]
564+
},
565+
{
566+
"cell_type": "code",
567+
"execution_count": null,
568+
"metadata": {
569+
"collapsed": true
570+
},
571+
"outputs": [],
572+
"source": [
573+
"sagemaker = boto3.client('sagemaker')\n",
574+
"sagemaker.create_transform_job(**request)\n",
575+
"\n",
576+
"print(\"Created Transform job with name: \", batch_job_name)\n",
577+
"\n",
578+
"while(True):\n",
579+
" response = sagemaker.describe_transform_job(TransformJobName=batch_job_name)\n",
580+
" status = response['TransformJobStatus']\n",
581+
" if status == 'Completed':\n",
582+
" print(\"Transform job ended with status: \" + status)\n",
583+
" break\n",
584+
" if status == 'Failed':\n",
585+
" message = response['FailureReason']\n",
586+
" print('Transform failed with the following error: {}'.format(message))\n",
587+
" raise Exception('Transform job failed') \n",
588+
" time.sleep(30) "
589+
]
590+
},
591+
{
592+
"cell_type": "markdown",
593+
"metadata": {},
594+
"source": [
595+
"After the job completes, let's check the prediction results."
596+
]
597+
},
598+
{
599+
"cell_type": "code",
600+
"execution_count": null,
601+
"metadata": {},
602+
"outputs": [],
603+
"source": [
604+
"from urllib.parse import urlparse\n",
605+
"import json\n",
606+
"import numpy as np\n",
607+
"\n",
608+
"s3_client = boto3.client('s3')\n",
609+
"object_categories = ['ak47', 'american-flag', 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat', 'bathtub', 'bear', 'beer-mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai-101', 'boom-box', 'bowling-ball', 'bowling-pin', 'boxing-glove', 'brain-101', 'breadmaker', 'buddha-101', 'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car-tire', 'cartman', 'cd', 'centipede', 'cereal-box', 'chandelier-101', 'chess-board', 'chimp', 'chopsticks', 'cockroach', 'coffee-mug', 'coffin', 'coin', 'comet', 'computer-keyboard', 'computer-monitor', 'computer-mouse', 'conch', 'cormorant', 'covered-wagon', 'cowboy-hat', 'crab-101', 'desk-globe', 'diamond-ring', 'dice', 'dog', 'dolphin-101', 'doorknob', 'drinking-straw', 'duck', 'dumb-bell', 'eiffel-tower', 'electric-guitar-101', 'elephant-101', 'elk', 'ewer-101', 'eyeglasses', 'fern', 'fighter-jet', 'fire-extinguisher', 'fire-hydrant', 'fire-truck', 'fireworks', 'flashlight', 'floppy-disk', 'football-helmet', 'french-horn', 'fried-egg', 'frisbee', 'frog', 'frying-pan', 'galaxy', 'gas-pump', 'giraffe', 'goat', 'golden-gate-bridge', 'goldfish', 'golf-ball', 'goose', 'gorilla', 'grand-piano-101', 'grapes', 'grasshopper', 'guitar-pick', 'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord', 'hawksbill-101', 'head-phones', 'helicopter-101', 'hibiscus', 'homer-simpson', 'horse', 'horseshoe-crab', 'hot-air-balloon', 'hot-dog', 'hot-tub', 'hourglass', 'house-fly', 'human-skeleton', 'hummingbird', 'ibis-101', 'ice-cream-cone', 'iguana', 'ipod', 'iris', 'jesus-christ', 'joy-stick', 'kangaroo-101', 'kayak', 'ketch-101', 'killer-whale', 'knife', 'ladder', 'laptop-101', 'lathe', 'leopards-101', 'license-plate', 'lightbulb', 'light-house', 'lightning', 'llama-101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah-101', 'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes-101', 'mountain-bike', 'mushroom', 'mussels', 'necktie', 'octopus', 'ostrich', 'owl', 'palm-pilot', 'palm-tree', 'paperclip', 'paper-shredder', 'pci-card', 'penguin', 'people', 'pez-dispenser', 'photocopier', 'picnic-table', 'playing-card', 'porcupine', 'pram', 'praying-mantis', 'pyramid', 'raccoon', 'radio-telescope', 'rainbow', 'refrigerator', 'revolver-101', 'rifle', 'rotary-phone', 'roulette-wheel', 'saddle', 'saturn', 'school-bus', 'scorpion-101', 'screwdriver', 'segway', 'self-propelled-lawn-mower', 'sextant', 'sheet-music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile', 'soccer-ball', 'socks', 'soda-can', 'spaghetti', 'speed-boat', 'spider', 'spoon', 'stained-glass', 'starfish-101', 'steering-wheel', 'stirrups', 'sunflower-101', 'superman', 'sushi', 'swan', 'swiss-army-knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy-bear', 'teepee', 'telephone-box', 'tennis-ball', 'tennis-court', 'tennis-racket', 'theodolite', 'toaster', 'tomato', 'tombstone', 'top-hat', 'touring-bike', 'tower-pisa', 'traffic-light', 'treadmill', 'triceratops', 'tricycle', 'trilobite-101', 'tripod', 't-shirt', 'tuning-fork', 'tweezer', 'umbrella-101', 'unicorn', 'vcr', 'video-projector', 'washing-machine', 'watch-101', 'waterfall', 'watermelon', 'welding-mask', 'wheelbarrow', 'windmill', 'wine-bottle', 'xylophone', 'yarmulke', 'yo-yo', 'zebra', 'airplanes-101', 'car-side-101', 'faces-easy-101', 'greyhound', 'tennis-shoes', 'toad', 'clutter']\n",
610+
"\n",
611+
"def list_objects(s3_client, bucket, prefix):\n",
612+
" response = s3_client.list_objects(Bucket=bucket, Prefix=prefix)\n",
613+
" objects = [content['Key'] for content in response['Contents']]\n",
614+
" return objects\n",
615+
"\n",
616+
"def get_label(s3_client, bucket, prefix):\n",
617+
" filename = prefix.split('/')[-1]\n",
618+
" s3_client.download_file(bucket, prefix, filename)\n",
619+
" with open(filename) as f:\n",
620+
" data = json.load(f)\n",
621+
" index = np.argmax(data['prediction'])\n",
622+
" probability = data['prediction'][index]\n",
623+
" print(\"Result: label - \" + object_categories[index] + \", probability - \" + str(probability))\n",
624+
" return object_categories[index], probability\n",
625+
"\n",
626+
"inputs = list_objects(s3_client, bucket, urlparse(batch_input).path.lstrip('/'))\n",
627+
"print(\"Sample inputs: \" + str(inputs[:2]))\n",
628+
"\n",
629+
"outputs = list_objects(s3_client, bucket, batch_job_name + \"/output\")\n",
630+
"print(\"Sample output: \" + str(outputs[:2]))\n",
631+
"\n",
632+
"# Check prediction result of the first 2 images\n",
633+
"[get_label(s3_client, bucket, prefix) for prefix in outputs[0:2]]"
634+
]
635+
},
636+
{
637+
"cell_type": "markdown",
638+
"metadata": {},
639+
"source": [
640+
"### Realtime inference\n",
641+
"\n",
642+
"We now host the model with an endpoint and perform realtime inference.\n",
643+
"\n",
644+
"This section involves several steps,\n",
645+
"1. [Create endpoint configuration](#CreateEndpointConfiguration) - Create a configuration defining an endpoint.\n",
646+
"1. [Create endpoint](#CreateEndpoint) - Use the configuration to create an inference endpoint.\n",
647+
"1. [Perform inference](#PerformInference) - Perform inference on some input data using the endpoint.\n",
648+
"1. [Clean up](#CleanUp) - Delete the endpoint and model"
649+
]
650+
},
651+
{
652+
"cell_type": "markdown",
653+
"metadata": {},
654+
"source": [
655+
"#### Create endpoint configuration\n",
517656
"At launch, we will support configuring REST endpoints in hosting with multiple models, e.g. for A/B testing purposes. In order to support this, customers create an endpoint configuration, that describes the distribution of traffic across the models, whether split, shadowed, or sampled in some way.\n",
518657
"\n",
519658
"In addition, the endpoint configuration describes the instance type required for model deployment, and at launch will describe the autoscaling configuration."
@@ -547,8 +686,8 @@
547686
"cell_type": "markdown",
548687
"metadata": {},
549688
"source": [
550-
"### Create endpoint\n",
551-
"Lastly, the customer creates the endpoint that serves up the model, through specifying the name and configuration defined above. The end result is an endpoint that can be validated and incorporated into production applications. This takes 9-11 minutes to complete."
689+
"#### Create endpoint\n",
690+
"Next, the customer creates the endpoint that serves up the model, through specifying the name and configuration defined above. The end result is an endpoint that can be validated and incorporated into production applications. This takes 9-11 minutes to complete."
552691
]
553692
},
554693
{
@@ -625,7 +764,7 @@
625764
"cell_type": "markdown",
626765
"metadata": {},
627766
"source": [
628-
"### Perform inference\n",
767+
"#### Perform inference\n",
629768
"Finally, the customer can now validate the model for use. They can obtain the endpoint from the client library using the result from previous operations, and generate classifications from the trained model using that endpoint.\n"
630769
]
631770
},
@@ -645,7 +784,7 @@
645784
"cell_type": "markdown",
646785
"metadata": {},
647786
"source": [
648-
"#### Download test image"
787+
"##### Download test image"
649788
]
650789
},
651790
{

0 commit comments

Comments
 (0)