|
93 | 93 | "import sagemaker\n",
|
94 | 94 | "import time\n",
|
95 | 95 | "from sagemaker.utils import name_from_base\n",
|
| 96 | + "from sagemaker import image_uris\n", |
96 | 97 | "\n",
|
97 | 98 | "role = sagemaker.get_execution_role()\n",
|
98 | 99 | "sess = sagemaker.Session()\n",
|
|
106 | 107 | "\n",
|
107 | 108 | "data_shape = '{\"input0\":[1,3,224,224]}'\n",
|
108 | 109 | "target_device = 'ml_c5'\n",
|
109 |
| - "framework = 'PYTORCH'\n", |
110 |
| - "framework_version = '1.2.0'\n", |
111 |
| - "compiled_model_path = 's3://{}/{}/output'.format(bucket, compilation_job_name)" |
| 110 | + "framework = 'pytorch'\n", |
| 111 | + "framework_version = '1.4.0'\n", |
| 112 | + "compiled_model_path = 's3://{}/{}/output'.format(bucket, compilation_job_name)\n", |
| 113 | + "\n", |
| 114 | + "inference_image_uri = image_uris.retrieve(f'neo-{framework}', region, framework_version, instance_type=target_device)" |
112 | 115 | ]
|
113 | 116 | },
|
114 | 117 | {
|
|
125 | 128 | "outputs": [],
|
126 | 129 | "source": [
|
127 | 130 | "from sagemaker.pytorch.model import PyTorchModel\n",
|
| 131 | + "from sagemaker.predictor import Predictor\n", |
128 | 132 | "\n",
|
129 | 133 | "pt_vgg = PyTorchModel(model_data=model_path,\n",
|
130 | 134 | " framework_version=framework_version,\n",
|
131 |
| - " role=role, \n", |
132 |
| - " entry_point='vgg19_bn_old.py',\n", |
| 135 | + " predictor_cls=Predictor,\n", |
| 136 | + " role=role, \n", |
133 | 137 | " sagemaker_session=sess,\n",
|
134 |
| - " py_version='py3'\n", |
| 138 | + " entry_point='vgg19_bn_uncompiled.py',\n", |
| 139 | + " source_dir='code',\n", |
| 140 | + " py_version='py3',\n", |
| 141 | + " image_uri=inference_image_uri\n", |
135 | 142 | " )"
|
136 | 143 | ]
|
137 | 144 | },
|
|
176 | 183 | "cell_type": "markdown",
|
177 | 184 | "metadata": {},
|
178 | 185 | "source": [
|
179 |
| - "#### Image Pre-processing" |
| 186 | + "#### Read the image payload" |
180 | 187 | ]
|
181 | 188 | },
|
182 | 189 | {
|
|
185 | 192 | "metadata": {},
|
186 | 193 | "outputs": [],
|
187 | 194 | "source": [
|
188 |
| - "import torch\n", |
189 |
| - "from PIL import Image\n", |
190 |
| - "from torchvision import transforms\n", |
191 |
| - "import numpy as np\n", |
192 |
| - "input_image = Image.open('cat.jpg')\n", |
193 |
| - "preprocess = transforms.Compose([\n", |
194 |
| - " transforms.Resize(256),\n", |
195 |
| - " transforms.CenterCrop(224),\n", |
196 |
| - " transforms.ToTensor(),\n", |
197 |
| - " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", |
198 |
| - "])\n", |
199 |
| - "input_tensor = preprocess(input_image)\n", |
200 |
| - "input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model" |
| 195 | + "import json\n", |
| 196 | + "\n", |
| 197 | + "with open('cat.jpg', 'rb') as f:\n", |
| 198 | + " payload = f.read()\n", |
| 199 | + " payload = bytearray(payload) " |
201 | 200 | ]
|
202 | 201 | },
|
203 | 202 | {
|
|
216 | 215 | "import time\n",
|
217 | 216 | "start = time.time()\n",
|
218 | 217 | "for _ in range(1000):\n",
|
219 |
| - " output = vgg_predictor.predict(input_batch)\n", |
| 218 | + " output = vgg_predictor.predict(payload)\n", |
220 | 219 | "inference_time = (time.time()-start)\n",
|
221 | 220 | "print('Inference time is ' + str(inference_time) + 'millisecond')"
|
222 | 221 | ]
|
|
227 | 226 | "metadata": {},
|
228 | 227 | "outputs": [],
|
229 | 228 | "source": [
|
230 |
| - "_, predicted = torch.max(torch.from_numpy(np.array(output)), 1)" |
| 229 | + "import numpy as np\n", |
| 230 | + "result = json.loads(output.decode())\n", |
| 231 | + "predicted = np.argmax(result)" |
231 | 232 | ]
|
232 | 233 | },
|
233 | 234 | {
|
|
250 | 251 | "metadata": {},
|
251 | 252 | "outputs": [],
|
252 | 253 | "source": [
|
253 |
| - "print(\"Result: label - \" + object_categories[str(predicted.item())])" |
| 254 | + "print(\"Result: label - \" + object_categories[str(predicted)])" |
254 | 255 | ]
|
255 | 256 | },
|
256 | 257 | {
|
|
277 | 278 | "## Neo optimization"
|
278 | 279 | ]
|
279 | 280 | },
|
280 |
| - { |
281 |
| - "cell_type": "markdown", |
282 |
| - "metadata": {}, |
283 |
| - "source": [ |
284 |
| - "### Update framework version" |
285 |
| - ] |
286 |
| - }, |
287 |
| - { |
288 |
| - "cell_type": "code", |
289 |
| - "execution_count": null, |
290 |
| - "metadata": {}, |
291 |
| - "outputs": [], |
292 |
| - "source": [ |
293 |
| - "framework_version = '1.4.0'" |
294 |
| - ] |
295 |
| - }, |
296 |
| - { |
297 |
| - "cell_type": "markdown", |
298 |
| - "metadata": {}, |
299 |
| - "source": [ |
300 |
| - "### Re-create the model archive" |
301 |
| - ] |
302 |
| - }, |
303 |
| - { |
304 |
| - "cell_type": "code", |
305 |
| - "execution_count": null, |
306 |
| - "metadata": {}, |
307 |
| - "outputs": [], |
308 |
| - "source": [ |
309 |
| - "with tarfile.open('model.tar.gz', 'w:gz') as f:\n", |
310 |
| - " f.add('model.pth')" |
311 |
| - ] |
312 |
| - }, |
313 | 281 | {
|
314 | 282 | "cell_type": "markdown",
|
315 | 283 | "metadata": {},
|
|
331 | 299 | " framework_version = framework_version,\n",
|
332 | 300 | " role=role,\n",
|
333 | 301 | " sagemaker_session=sess,\n",
|
334 |
| - " entry_point='vgg19_bn.py',\n", |
| 302 | + " entry_point='vgg19_bn_compiled.py',\n", |
335 | 303 | " source_dir='code',\n",
|
336 | 304 | " py_version='py3',\n",
|
337 | 305 | " env={'MMS_DEFAULT_RESPONSE_TIMEOUT': '500'}\n",
|
|
361 | 329 | " )"
|
362 | 330 | ]
|
363 | 331 | },
|
364 |
| - { |
365 |
| - "cell_type": "code", |
366 |
| - "execution_count": null, |
367 |
| - "metadata": {}, |
368 |
| - "outputs": [], |
369 |
| - "source": [ |
370 |
| - "# TODO(kkoppolu): Delete after new SDK version sets the image URI correctly\n", |
371 |
| - "compiled_model.image_uri = compiled_model.image_uri.replace(\"neo\", \"inference\")" |
372 |
| - ] |
373 |
| - }, |
374 | 332 | {
|
375 | 333 | "cell_type": "code",
|
376 | 334 | "execution_count": null,
|
|
382 | 340 | " )"
|
383 | 341 | ]
|
384 | 342 | },
|
385 |
| - { |
386 |
| - "cell_type": "code", |
387 |
| - "execution_count": null, |
388 |
| - "metadata": {}, |
389 |
| - "outputs": [], |
390 |
| - "source": [ |
391 |
| - "import json\n", |
392 |
| - "\n", |
393 |
| - "with open('cat.jpg', 'rb') as f:\n", |
394 |
| - " payload = f.read()\n", |
395 |
| - " payload = bytearray(payload) " |
396 |
| - ] |
397 |
| - }, |
398 | 343 | {
|
399 | 344 | "cell_type": "markdown",
|
400 | 345 | "metadata": {},
|
|
435 | 380 | "source": [
|
436 | 381 | "sess.delete_endpoint(predictor.endpoint_name)"
|
437 | 382 | ]
|
438 |
| - }, |
439 |
| - { |
440 |
| - "cell_type": "code", |
441 |
| - "execution_count": null, |
442 |
| - "metadata": {}, |
443 |
| - "outputs": [], |
444 |
| - "source": [] |
445 | 383 | }
|
446 | 384 | ],
|
447 | 385 | "metadata": {
|
|
0 commit comments