|
11 | 11 | "cell_type": "markdown",
|
12 | 12 | "metadata": {},
|
13 | 13 | "source": [
|
14 |
| - "Amazon SageMaker Neo is API to compile machine learning models to optimize them for our choice of hardward targets. Currently, Neo supports pre-trained PyTorch models from [TorchVision](https://pytorch.org/docs/stable/torchvision/models.html). General support for other PyTorch models is forthcoming." |
| 14 | + "Amazon SageMaker Neo is an API to compile machine learning models to optimize them for our choice of hardware targets. Currently, Neo supports pre-trained PyTorch models from [TorchVision](https://pytorch.org/docs/stable/torchvision/models.html). General support for other PyTorch models is forthcoming." |
15 | 15 | ]
|
16 | 16 | },
|
17 | 17 | {
|
|
20 | 20 | "metadata": {},
|
21 | 21 | "outputs": [],
|
22 | 22 | "source": [
|
23 |
| - "!~/anaconda3/envs/pytorch_p36/bin/pip install torch==1.2.0 torchvision==0.4.0" |
| 23 | + "!~/anaconda3/envs/pytorch_p36/bin/pip install torch==1.4.0 torchvision==0.5.0" |
| 24 | + ] |
| 25 | + }, |
| 26 | + { |
| 27 | + "cell_type": "code", |
| 28 | + "execution_count": null, |
| 29 | + "metadata": {}, |
| 30 | + "outputs": [], |
| 31 | + "source": [ |
| 32 | + "!~/anaconda3/envs/pytorch_p36/bin/pip install --upgrade sagemaker" |
24 | 33 | ]
|
25 | 34 | },
|
26 | 35 | {
|
|
34 | 43 | "cell_type": "markdown",
|
35 | 44 | "metadata": {},
|
36 | 45 | "source": [
|
37 |
| - "We'll import [ResNet18](https://arxiv.org/abs/1512.03385) model from TorchVision and create a model artifact `model.tar.gz`:" |
| 46 | + "We'll import the [ResNet18](https://arxiv.org/abs/1512.03385) model from TorchVision and create a model artifact `model.tar.gz`." |
38 | 47 | ]
|
39 | 48 | },
|
40 | 49 | {
|
|
60 | 69 | "cell_type": "markdown",
|
61 | 70 | "metadata": {},
|
62 | 71 | "source": [
|
63 |
| - "## Invoke Neo Compilation API" |
64 |
| - ] |
65 |
| - }, |
66 |
| - { |
67 |
| - "cell_type": "markdown", |
68 |
| - "metadata": {}, |
69 |
| - "source": [ |
70 |
| - "We then forward the model artifact to Neo Compilation API:" |
| 72 | + "### Upload the model archive to S3" |
71 | 73 | ]
|
72 | 74 | },
|
73 | 75 | {
|
|
87 | 89 | "bucket = sess.default_bucket()\n",
|
88 | 90 | "\n",
|
89 | 91 | "compilation_job_name = name_from_base('TorchVision-ResNet18-Neo')\n",
|
| 92 | + "prefix = compilation_job_name+'/model'\n", |
90 | 93 | "\n",
|
91 |
| - "model_key = '{}/model/model.tar.gz'.format(compilation_job_name)\n", |
92 |
| - "model_path = 's3://{}/{}'.format(bucket, model_key)\n", |
93 |
| - "boto3.resource('s3').Bucket(bucket).upload_file('model.tar.gz', model_key)\n", |
| 94 | + "model_path = sess.upload_data(path='model.tar.gz', key_prefix=prefix)\n", |
94 | 95 | "\n",
|
95 |
| - "sm_client = boto3.client('sagemaker')\n", |
96 | 96 | "data_shape = '{\"input0\":[1,3,224,224]}'\n",
|
97 | 97 | "target_device = 'ml_c5'\n",
|
98 | 98 | "framework = 'PYTORCH'\n",
|
99 |
| - "framework_version = '1.2.0'\n", |
| 99 | + "framework_version = '1.4.0'\n", |
100 | 100 | "compiled_model_path = 's3://{}/{}/output'.format(bucket, compilation_job_name)"
|
101 | 101 | ]
|
102 | 102 | },
|
103 |
| - { |
104 |
| - "cell_type": "code", |
105 |
| - "execution_count": null, |
106 |
| - "metadata": {}, |
107 |
| - "outputs": [], |
108 |
| - "source": [ |
109 |
| - "response = sm_client.create_compilation_job(\n", |
110 |
| - " CompilationJobName=compilation_job_name,\n", |
111 |
| - " RoleArn=role,\n", |
112 |
| - " InputConfig={\n", |
113 |
| - " 'S3Uri': model_path,\n", |
114 |
| - " 'DataInputConfig': data_shape,\n", |
115 |
| - " 'Framework': framework\n", |
116 |
| - " },\n", |
117 |
| - " OutputConfig={\n", |
118 |
| - " 'S3OutputLocation': compiled_model_path,\n", |
119 |
| - " 'TargetDevice': target_device\n", |
120 |
| - " },\n", |
121 |
| - " StoppingCondition={\n", |
122 |
| - " 'MaxRuntimeInSeconds': 300\n", |
123 |
| - " }\n", |
124 |
| - ")\n", |
125 |
| - "print(response)\n", |
126 |
| - "\n", |
127 |
| - "# Poll every 30 sec\n", |
128 |
| - "while True:\n", |
129 |
| - " response = sm_client.describe_compilation_job(CompilationJobName=compilation_job_name)\n", |
130 |
| - " if response['CompilationJobStatus'] == 'COMPLETED':\n", |
131 |
| - " break\n", |
132 |
| - " elif response['CompilationJobStatus'] == 'FAILED':\n", |
133 |
| - " raise RuntimeError('Compilation failed')\n", |
134 |
| - " print('Compiling ...')\n", |
135 |
| - " time.sleep(30)\n", |
136 |
| - "print('Done!')\n", |
137 |
| - "\n", |
138 |
| - "# Extract compiled model artifact\n", |
139 |
| - "compiled_model_path = response['ModelArtifacts']['S3ModelArtifacts']" |
140 |
| - ] |
141 |
| - }, |
142 |
| - { |
143 |
| - "cell_type": "markdown", |
144 |
| - "metadata": {}, |
145 |
| - "source": [ |
146 |
| - "## Create prediction endpoint" |
147 |
| - ] |
148 |
| - }, |
149 |
| - { |
150 |
| - "cell_type": "markdown", |
151 |
| - "metadata": {}, |
152 |
| - "source": [ |
153 |
| - "To create a prediction endpoint, we first specify two additional functions, to be used with Neo Deep Learning Runtime:\n", |
154 |
| - "\n", |
155 |
| - "* `neo_preprocess(payload, content_type)`: Function that takes in the payload and Content-Type of each incoming request and returns a NumPy array. Here, the payload is byte-encoded NumPy array, so the function simply decodes the bytes to obtain the NumPy array.\n", |
156 |
| - "* `neo_postprocess(result)`: Function that takes the prediction results produced by Deep Learining Runtime and returns the response body" |
157 |
| - ] |
158 |
| - }, |
159 |
| - { |
160 |
| - "cell_type": "code", |
161 |
| - "execution_count": null, |
162 |
| - "metadata": {}, |
163 |
| - "outputs": [], |
164 |
| - "source": [ |
165 |
| - "!pygmentize resnet18.py" |
166 |
| - ] |
167 |
| - }, |
168 | 103 | {
|
169 | 104 | "cell_type": "markdown",
|
170 | 105 | "metadata": {},
|
171 | 106 | "source": [
|
172 |
| - "Upload the Python script containing the two functions to S3:" |
173 |
| - ] |
174 |
| - }, |
175 |
| - { |
176 |
| - "cell_type": "code", |
177 |
| - "execution_count": null, |
178 |
| - "metadata": {}, |
179 |
| - "outputs": [], |
180 |
| - "source": [ |
181 |
| - "source_key = '{}/source/sourcedir.tar.gz'.format(compilation_job_name)\n", |
182 |
| - "source_path = 's3://{}/{}'.format(bucket, source_key)\n", |
183 |
| - "\n", |
184 |
| - "with tarfile.open('sourcedir.tar.gz', 'w:gz') as f:\n", |
185 |
| - " f.add('resnet18.py')\n", |
186 |
| - "\n", |
187 |
| - "boto3.resource('s3').Bucket(bucket).upload_file('sourcedir.tar.gz', source_key)" |
| 107 | + "## Invoke Neo Compilation API" |
188 | 108 | ]
|
189 | 109 | },
|
190 | 110 | {
|
191 | 111 | "cell_type": "markdown",
|
192 | 112 | "metadata": {},
|
193 | 113 | "source": [
|
194 |
| - "We then create a SageMaker model record:" |
| 114 | + "### Create a PyTorch SageMaker model" |
195 | 115 | ]
|
196 | 116 | },
|
197 | 117 | {
|
|
200 | 120 | "metadata": {},
|
201 | 121 | "outputs": [],
|
202 | 122 | "source": [
|
203 |
| - "from sagemaker.model import NEO_IMAGE_ACCOUNT\n", |
204 |
| - "from sagemaker.fw_utils import create_image_uri\n", |
205 |
| - "\n", |
206 |
| - "model_name = name_from_base('TorchVision-ResNet18-Neo')\n", |
| 123 | + "from sagemaker.pytorch.model import PyTorchModel\n", |
| 124 | + "from sagemaker.predictor import Predictor\n", |
207 | 125 | "\n",
|
208 |
| - "image_uri = create_image_uri(region, 'neo-' + framework.lower(), target_device.replace('_', '.'),\n", |
209 |
| - " framework_version, py_version='py3', account=NEO_IMAGE_ACCOUNT[region])\n", |
210 |
| - "\n", |
211 |
| - "response = sm_client.create_model(\n", |
212 |
| - " ModelName=model_name,\n", |
213 |
| - " PrimaryContainer={\n", |
214 |
| - " 'Image': image_uri,\n", |
215 |
| - " 'ModelDataUrl': compiled_model_path,\n", |
216 |
| - " 'Environment': { 'SAGEMAKER_SUBMIT_DIRECTORY': source_path }\n", |
217 |
| - " },\n", |
218 |
| - " ExecutionRoleArn=role\n", |
219 |
| - ")\n", |
220 |
| - "print(response)" |
| 126 | + "sagemaker_model = PyTorchModel(model_data=model_path,\n", |
| 127 | + " predictor_cls=Predictor,\n", |
| 128 | + " framework_version = framework_version,\n", |
| 129 | + " role=role,\n", |
| 130 | + " sagemaker_session=sess,\n", |
| 131 | + " entry_point='resnet18.py',\n", |
| 132 | + " source_dir='code',\n", |
| 133 | + " py_version='py3',\n", |
| 134 | + " env={'MMS_DEFAULT_RESPONSE_TIMEOUT': '500'}\n", |
| 135 | + " )" |
221 | 136 | ]
|
222 | 137 | },
|
223 | 138 | {
|
224 | 139 | "cell_type": "markdown",
|
225 | 140 | "metadata": {},
|
226 | 141 | "source": [
|
227 |
| - "Then we create an Endpoint Configuration:" |
| 142 | + "### Use the Neo compiler to compile the model" |
228 | 143 | ]
|
229 | 144 | },
|
230 | 145 | {
|
|
233 | 148 | "metadata": {},
|
234 | 149 | "outputs": [],
|
235 | 150 | "source": [
|
236 |
| - "config_name = model_name\n", |
237 |
| - "\n", |
238 |
| - "response = sm_client.create_endpoint_config(\n", |
239 |
| - " EndpointConfigName=config_name,\n", |
240 |
| - " ProductionVariants=[\n", |
241 |
| - " {\n", |
242 |
| - " 'VariantName': 'default-variant-name',\n", |
243 |
| - " 'ModelName': model_name,\n", |
244 |
| - " 'InitialInstanceCount': 1,\n", |
245 |
| - " 'InstanceType': 'ml.c5.xlarge',\n", |
246 |
| - " 'InitialVariantWeight': 1.0\n", |
247 |
| - " },\n", |
248 |
| - " ],\n", |
249 |
| - ")\n", |
250 |
| - "print(response)" |
| 151 | + "compiled_model = sagemaker_model.compile(target_instance_family=target_device, \n", |
| 152 | + " input_shape=data_shape,\n", |
| 153 | + " job_name=compilation_job_name,\n", |
| 154 | + " role=role,\n", |
| 155 | + " framework=framework.lower(),\n", |
| 156 | + " framework_version=framework_version,\n", |
| 157 | + " output_path=compiled_model_path\n", |
| 158 | + " )" |
251 | 159 | ]
|
252 | 160 | },
|
253 | 161 | {
|
254 | 162 | "cell_type": "markdown",
|
255 | 163 | "metadata": {},
|
256 | 164 | "source": [
|
257 |
| - "Finally, we create an Endpoint:" |
| 165 | + "## Deploy the model" |
258 | 166 | ]
|
259 | 167 | },
|
260 | 168 | {
|
|
263 | 171 | "metadata": {},
|
264 | 172 | "outputs": [],
|
265 | 173 | "source": [
|
266 |
| - "endpoint_name = model_name + '-Endpoint'\n", |
267 |
| - "\n", |
268 |
| - "response = sm_client.create_endpoint(\n", |
269 |
| - " EndpointName=endpoint_name,\n", |
270 |
| - " EndpointConfigName=config_name,\n", |
271 |
| - ")\n", |
272 |
| - "print(response)\n", |
273 |
| - "\n", |
274 |
| - "print('Creating endpoint ...')\n", |
275 |
| - "waiter = sm_client.get_waiter('endpoint_in_service')\n", |
276 |
| - "waiter.wait(EndpointName=endpoint_name)\n", |
277 |
| - "\n", |
278 |
| - "response = sm_client.describe_endpoint(EndpointName=endpoint_name)\n", |
279 |
| - "print(response)" |
| 174 | + "predictor = compiled_model.deploy(initial_instance_count = 1,\n", |
| 175 | + " instance_type = 'ml.c5.9xlarge'\n", |
| 176 | + " )" |
280 | 177 | ]
|
281 | 178 | },
|
282 | 179 | {
|
|
301 | 198 | "metadata": {},
|
302 | 199 | "outputs": [],
|
303 | 200 | "source": [
|
304 |
| - "import json\n", |
305 | 201 | "import numpy as np\n",
|
306 |
| - "\n", |
307 |
| - "sm_runtime = boto3.Session().client('sagemaker-runtime')\n", |
| 202 | + "import json\n", |
308 | 203 | "\n",
|
309 | 204 | "with open('cat.jpg', 'rb') as f:\n",
|
310 | 205 | " payload = f.read()\n",
|
| 206 | + " payload = bytearray(payload) \n", |
311 | 207 | "\n",
|
312 |
| - "response = sm_runtime.invoke_endpoint(EndpointName=endpoint_name,\n", |
313 |
| - " ContentType='application/x-image',\n", |
314 |
| - " Body=payload)\n", |
315 |
| - "print(response)\n", |
316 |
| - "result = json.loads(response['Body'].read().decode())\n", |
| 208 | + "response = predictor.predict(payload)\n", |
| 209 | + "result = json.loads(response.decode())\n", |
317 | 210 | "print('Most likely class: {}'.format(np.argmax(result)))"
|
318 | 211 | ]
|
319 | 212 | },
|
|
346 | 239 | "metadata": {},
|
347 | 240 | "outputs": [],
|
348 | 241 | "source": [
|
349 |
| - "sess.delete_endpoint(endpoint_name)" |
| 242 | + "sess.delete_endpoint(predictor.endpoint_name)" |
350 | 243 | ]
|
351 | 244 | }
|
352 | 245 | ],
|
|
366 | 259 | "name": "python",
|
367 | 260 | "nbconvert_exporter": "python",
|
368 | 261 | "pygments_lexer": "ipython3",
|
369 |
| - "version": "3.6.5" |
| 262 | + "version": "3.6.10" |
370 | 263 | }
|
371 | 264 | },
|
372 | 265 | "nbformat": 4,
|
373 |
| - "nbformat_minor": 2 |
| 266 | + "nbformat_minor": 4 |
374 | 267 | }
|
0 commit comments