Skip to content

Commit af8a848

Browse files
authored
Merge pull request aws#96 from awslabs/xgboost-byom
Change content type to csv and other small changes
2 parents 62ff05b + 113c083 commit af8a848

File tree

1 file changed

+34
-17
lines changed

1 file changed

+34
-17
lines changed

advanced_functionality/xgboost_bring_your_own_model/xgboost_bring_your_own_model.ipynb

Lines changed: 34 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
"---\n",
2727
"## Background\n",
2828
"\n",
29-
"Amazon SageMaker includes functionality to support a hosted notebook environment, distributed, managed training, and real-time hosting. We think it works best when all three of these services are used together, but they can also be used independently. Some use cases may only require hosting. Maybe the model was trained prior to Amazon SageMaker existing, in a different service.\n",
29+
"Amazon SageMaker includes functionality to support a hosted notebook environment, distributed, serverless training, and real-time hosting. We think it works best when all three of these services are used together, but they can also be used independently. Some use cases may only require hosting. Maybe the model was trained prior to Amazon SageMaker existing, in a different service.\n",
3030
"\n",
3131
"This notebook shows how to use a pre-existing scikit-learn model with the Amazon SageMaker XGBoost Algorithm container to quickly create a hosted endpoint for that model.\n",
3232
"\n",
@@ -53,6 +53,7 @@
5353
"import os\n",
5454
"import boto3\n",
5555
"import re\n",
56+
"import json\n",
5657
"\n",
5758
"region = boto3.Session().region_name\n",
5859
"\n",
@@ -83,7 +84,7 @@
8384
"metadata": {},
8485
"outputs": [],
8586
"source": [
86-
"!conda install -y -c conda-forge xgboost scikit-learn"
87+
"!conda install -y -c conda-forge xgboost"
8788
]
8889
},
8990
{
@@ -140,7 +141,9 @@
140141
{
141142
"cell_type": "code",
142143
"execution_count": null,
143-
"metadata": {},
144+
"metadata": {
145+
"collapsed": true
146+
},
144147
"outputs": [],
145148
"source": [
146149
"train_set, valid_set, test_set = get_dataset()\n",
@@ -165,7 +168,9 @@
165168
{
166169
"cell_type": "code",
167170
"execution_count": null,
168-
"metadata": {},
171+
"metadata": {
172+
"collapsed": true
173+
},
169174
"outputs": [],
170175
"source": [
171176
"import xgboost as xgb\n",
@@ -191,7 +196,9 @@
191196
{
192197
"cell_type": "code",
193198
"execution_count": null,
194-
"metadata": {},
199+
"metadata": {
200+
"collapsed": true
201+
},
195202
"outputs": [],
196203
"source": [
197204
"model_file_name = \"locally-trained-xgboost-model\"\n",
@@ -217,7 +224,9 @@
217224
{
218225
"cell_type": "code",
219226
"execution_count": null,
220-
"metadata": {},
227+
"metadata": {
228+
"collapsed": true
229+
},
221230
"outputs": [],
222231
"source": [
223232
"fObj = open(\"model.tar.gz\", 'rb')\n",
@@ -238,7 +247,9 @@
238247
{
239248
"cell_type": "code",
240249
"execution_count": null,
241-
"metadata": {},
250+
"metadata": {
251+
"collapsed": true
252+
},
242253
"outputs": [],
243254
"source": [
244255
"containers = {'us-west-2': '433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:latest',\n",
@@ -356,7 +367,9 @@
356367
{
357368
"cell_type": "code",
358369
"execution_count": null,
359-
"metadata": {},
370+
"metadata": {
371+
"collapsed": true
372+
},
360373
"outputs": [],
361374
"source": [
362375
"runtime_client = boto3.client('sagemaker-runtime')"
@@ -372,12 +385,15 @@
372385
{
373386
"cell_type": "code",
374387
"execution_count": null,
375-
"metadata": {},
388+
"metadata": {
389+
"collapsed": true
390+
},
376391
"outputs": [],
377392
"source": [
378393
"import numpy as np\n",
379-
"point_X = test_X[:1]\n",
380-
"point_y = test_y[:1]\n",
394+
"point_X = test_X[0]\n",
395+
"point_X = np.expand_dims(point_X, axis=0)\n",
396+
"point_y = test_y[0]\n",
381397
"np.savetxt(\"test_point.csv\", point_X, delimiter=\",\")"
382398
]
383399
},
@@ -397,7 +413,7 @@
397413
" payload = f.read().strip()\n",
398414
"\n",
399415
"response = runtime_client.invoke_endpoint(EndpointName=endpoint_name, \n",
400-
" ContentType='text/x-libsvm', \n",
416+
" ContentType='text/csv', \n",
401417
" Body=payload)\n",
402418
"result = response['Body'].read().decode('ascii')\n",
403419
"print('Predicted Class Probabilities: {}.'.format(result))"
@@ -417,8 +433,7 @@
417433
"metadata": {},
418434
"outputs": [],
419435
"source": [
420-
"arr = result[1:len(result)-1].split(',')\n",
421-
"floatArr = np.array(arr).astype(np.float)\n",
436+
"floatArr = np.array(json.loads(result))\n",
422437
"predictedLabel = np.argmax(floatArr)\n",
423438
"print('Predicted Class Label: {}.'.format(predictedLabel))\n",
424439
"print('Actual Class Label: {}.'.format(point_y))"
@@ -436,7 +451,9 @@
436451
{
437452
"cell_type": "code",
438453
"execution_count": null,
439-
"metadata": {},
454+
"metadata": {
455+
"collapsed": true
456+
},
440457
"outputs": [],
441458
"source": [
442459
"# sm_client.delete_endpoint(EndpointName=endpoint_name)"
@@ -446,7 +463,7 @@
446463
"metadata": {
447464
"anaconda-cloud": {},
448465
"kernelspec": {
449-
"display_name": "Environment (conda_python3)",
466+
"display_name": "conda_python3",
450467
"language": "python",
451468
"name": "conda_python3"
452469
},
@@ -460,7 +477,7 @@
460477
"name": "python",
461478
"nbconvert_exporter": "python",
462479
"pygments_lexer": "ipython3",
463-
"version": "3.6.3"
480+
"version": "3.6.2"
464481
},
465482
"notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
466483
},

0 commit comments

Comments
 (0)