Skip to content

Commit d6202d5

Browse files
committed
Updated xgboost_customer_churn.ipynb for SageMaker SDK v2
The notebook needed updates to be compatible with v2 of the SageMaker SDK, which is now the default SDK version in SageMaker.
1 parent aa22e25 commit d6202d5

File tree

1 file changed

+50
-110
lines changed

1 file changed

+50
-110
lines changed

introduction_to_applying_machine_learning/xgboost_customer_churn/xgboost_customer_churn.ipynb

Lines changed: 50 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
"cell_type": "code",
5050
"execution_count": null,
5151
"metadata": {
52-
"collapsed": true,
5352
"isConfigCell": true,
5453
"tags": [
5554
"parameters"
@@ -80,9 +79,7 @@
8079
{
8180
"cell_type": "code",
8281
"execution_count": null,
83-
"metadata": {
84-
"collapsed": true
85-
},
82+
"metadata": {},
8683
"outputs": [],
8784
"source": [
8885
"import pandas as pd\n",
@@ -114,9 +111,7 @@
114111
{
115112
"cell_type": "code",
116113
"execution_count": null,
117-
"metadata": {
118-
"collapsed": true
119-
},
114+
"metadata": {},
120115
"outputs": [],
121116
"source": [
122117
"!wget http://dataminingconsultant.com/DKD2e_data_sets.zip\n",
@@ -126,9 +121,7 @@
126121
{
127122
"cell_type": "code",
128123
"execution_count": null,
129-
"metadata": {
130-
"collapsed": true
131-
},
124+
"metadata": {},
132125
"outputs": [],
133126
"source": [
134127
"churn = pd.read_csv('./Data sets/churn.txt')\n",
@@ -166,9 +159,7 @@
166159
{
167160
"cell_type": "code",
168161
"execution_count": null,
169-
"metadata": {
170-
"collapsed": true
171-
},
162+
"metadata": {},
172163
"outputs": [],
173164
"source": [
174165
"# Frequency tables for each categorical feature\n",
@@ -195,9 +186,7 @@
195186
{
196187
"cell_type": "code",
197188
"execution_count": null,
198-
"metadata": {
199-
"collapsed": true
200-
},
189+
"metadata": {},
201190
"outputs": [],
202191
"source": [
203192
"churn = churn.drop('Phone', axis=1)\n",
@@ -214,9 +203,7 @@
214203
{
215204
"cell_type": "code",
216205
"execution_count": null,
217-
"metadata": {
218-
"collapsed": true
219-
},
206+
"metadata": {},
220207
"outputs": [],
221208
"source": [
222209
"for column in churn.select_dtypes(include=['object']).columns:\n",
@@ -246,9 +233,7 @@
246233
{
247234
"cell_type": "code",
248235
"execution_count": null,
249-
"metadata": {
250-
"collapsed": true
251-
},
236+
"metadata": {},
252237
"outputs": [],
253238
"source": [
254239
"display(churn.corr())\n",
@@ -266,9 +251,7 @@
266251
{
267252
"cell_type": "code",
268253
"execution_count": null,
269-
"metadata": {
270-
"collapsed": true
271-
},
254+
"metadata": {},
272255
"outputs": [],
273256
"source": [
274257
"churn = churn.drop(['Day Charge', 'Eve Charge', 'Night Charge', 'Intl Charge'], axis=1)"
@@ -290,9 +273,7 @@
290273
{
291274
"cell_type": "code",
292275
"execution_count": null,
293-
"metadata": {
294-
"collapsed": true
295-
},
276+
"metadata": {},
296277
"outputs": [],
297278
"source": [
298279
"model_data = pd.get_dummies(churn)\n",
@@ -309,9 +290,7 @@
309290
{
310291
"cell_type": "code",
311292
"execution_count": null,
312-
"metadata": {
313-
"collapsed": true
314-
},
293+
"metadata": {},
315294
"outputs": [],
316295
"source": [
317296
"train_data, validation_data, test_data = np.split(model_data.sample(frac=1, random_state=1729), [int(0.7 * len(model_data)), int(0.9 * len(model_data))])\n",
@@ -329,9 +308,7 @@
329308
{
330309
"cell_type": "code",
331310
"execution_count": null,
332-
"metadata": {
333-
"collapsed": true
334-
},
311+
"metadata": {},
335312
"outputs": [],
336313
"source": [
337314
"boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'train/train.csv')).upload_file('train.csv')\n",
@@ -351,13 +328,12 @@
351328
{
352329
"cell_type": "code",
353330
"execution_count": null,
354-
"metadata": {
355-
"collapsed": true
356-
},
331+
"metadata": {},
357332
"outputs": [],
358333
"source": [
359334
"from sagemaker.amazon.amazon_estimator import get_image_uri\n",
360-
"container = get_image_uri(boto3.Session().region_name, 'xgboost')"
335+
"container = sagemaker.image_uris.retrieve('xgboost', boto3.Session().region_name, '1')\n",
336+
"display(container)"
361337
]
362338
},
363339
{
@@ -370,13 +346,11 @@
370346
{
371347
"cell_type": "code",
372348
"execution_count": null,
373-
"metadata": {
374-
"collapsed": true
375-
},
349+
"metadata": {},
376350
"outputs": [],
377351
"source": [
378-
"s3_input_train = sagemaker.s3_input(s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')\n",
379-
"s3_input_validation = sagemaker.s3_input(s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')"
352+
"s3_input_train = sagemaker.inputs.TrainingInput(s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')\n",
353+
"s3_input_validation = sagemaker.inputs.TrainingInput(s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')"
380354
]
381355
},
382356
{
@@ -396,17 +370,15 @@
396370
{
397371
"cell_type": "code",
398372
"execution_count": null,
399-
"metadata": {
400-
"collapsed": true
401-
},
373+
"metadata": {},
402374
"outputs": [],
403375
"source": [
404376
"sess = sagemaker.Session()\n",
405377
"\n",
406378
"xgb = sagemaker.estimator.Estimator(container,\n",
407379
" role, \n",
408-
" train_instance_count=1, \n",
409-
" train_instance_type='ml.m4.xlarge',\n",
380+
" instance_count=1, \n",
381+
" instance_type='ml.m4.xlarge',\n",
410382
" output_path='s3://{}/{}/output'.format(bucket, prefix),\n",
411383
" sagemaker_session=sess)\n",
412384
"xgb.set_hyperparameters(max_depth=5,\n",
@@ -437,18 +409,15 @@
437409
"outputs": [],
438410
"source": [
439411
"compiled_model = xgb\n",
440-
"if xgb.create_model().check_neo_region(boto3.Session().region_name) is False:\n",
441-
" print('Neo is not currently supported in', boto3.Session().region_name)\n",
442-
"else:\n",
443-
" output_path = '/'.join(xgb.output_path.split('/')[:-1])\n",
444-
" compiled_model = xgb.compile_model(target_instance_family='ml_m4', \n",
445-
" input_shape={'data': [1, 69]},\n",
446-
" role=role,\n",
447-
" framework='xgboost',\n",
448-
" framework_version='0.7',\n",
449-
" output_path=output_path)\n",
450-
" compiled_model.name = 'deployed-xgboost-customer-churn'\n",
451-
" compiled_model.image = get_image_uri(sess.boto_region_name, 'xgboost-neo', repo_version='latest')"
412+
"output_path = '/'.join(xgb.output_path.split('/')[:-1])\n",
413+
"compiled_model = xgb.compile_model(target_instance_family='ml_m4', \n",
414+
" input_shape={'data': [1, 69]},\n",
415+
" role=role,\n",
416+
" framework='xgboost',\n",
417+
" framework_version='latest',\n",
418+
" output_path=output_path)\n",
419+
"compiled_model.name = 'deployed-xgboost-customer-churn'\n",
420+
"compiled_model.image = get_image_uri(sess.boto_region_name, 'xgboost-neo', repo_version='latest')"
452421
]
453422
},
454423
{
@@ -464,12 +433,13 @@
464433
{
465434
"cell_type": "code",
466435
"execution_count": null,
467-
"metadata": {
468-
"collapsed": true
469-
},
436+
"metadata": {},
470437
"outputs": [],
471438
"source": [
472-
"xgb_predictor = compiled_model.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')"
439+
"xgb_predictor = compiled_model.deploy(\n",
440+
" initial_instance_count = 1, \n",
441+
" instance_type = 'ml.m4.xlarge',\n",
442+
" serializer=sagemaker.serializers.CSVSerializer())"
473443
]
474444
},
475445
{
@@ -481,19 +451,6 @@
481451
"Now that we have a hosted endpoint running, we can make real-time predictions from our model very easily, simply by making an HTTP POST request. But first, we'll need to set up serializers and deserializers for passing our `test_data` NumPy arrays to the model behind the endpoint."
482452
]
483453
},
484-
{
485-
"cell_type": "code",
486-
"execution_count": null,
487-
"metadata": {
488-
"collapsed": true
489-
},
490-
"outputs": [],
491-
"source": [
492-
"xgb_predictor.content_type = 'text/csv'\n",
493-
"xgb_predictor.serializer = csv_serializer\n",
494-
"xgb_predictor.deserializer = None"
495-
]
496-
},
497454
{
498455
"cell_type": "markdown",
499456
"metadata": {},
@@ -509,9 +466,7 @@
509466
{
510467
"cell_type": "code",
511468
"execution_count": null,
512-
"metadata": {
513-
"collapsed": true
514-
},
469+
"metadata": {},
515470
"outputs": [],
516471
"source": [
517472
"def predict(data, rows=500):\n",
@@ -535,9 +490,7 @@
535490
{
536491
"cell_type": "code",
537492
"execution_count": null,
538-
"metadata": {
539-
"collapsed": true
540-
},
493+
"metadata": {},
541494
"outputs": [],
542495
"source": [
543496
"pd.crosstab(index=test_data.iloc[:, 0], columns=np.round(predictions), rownames=['actual'], colnames=['predictions'])"
@@ -559,9 +512,7 @@
559512
{
560513
"cell_type": "code",
561514
"execution_count": null,
562-
"metadata": {
563-
"collapsed": true
564-
},
515+
"metadata": {},
565516
"outputs": [],
566517
"source": [
567518
"plt.hist(predictions)\n",
@@ -578,9 +529,7 @@
578529
{
579530
"cell_type": "code",
580531
"execution_count": null,
581-
"metadata": {
582-
"collapsed": true
583-
},
532+
"metadata": {},
584533
"outputs": [],
585534
"source": [
586535
"pd.crosstab(index=test_data.iloc[:, 0], columns=np.where(predictions > 0.3, 1, 0))"
@@ -629,9 +578,7 @@
629578
{
630579
"cell_type": "code",
631580
"execution_count": null,
632-
"metadata": {
633-
"collapsed": true
634-
},
581+
"metadata": {},
635582
"outputs": [],
636583
"source": [
637584
"cutoffs = np.arange(0.01, 1, 0.01)\n",
@@ -683,33 +630,26 @@
683630
{
684631
"cell_type": "code",
685632
"execution_count": null,
686-
"metadata": {
687-
"collapsed": true
688-
},
633+
"metadata": {},
689634
"outputs": [],
690635
"source": [
691-
"sagemaker.Session().delete_endpoint(xgb_predictor.endpoint)"
636+
"xgb_predictor.delete_endpoint()"
692637
]
638+
},
639+
{
640+
"cell_type": "code",
641+
"execution_count": null,
642+
"metadata": {},
643+
"outputs": [],
644+
"source": []
693645
}
694646
],
695647
"metadata": {
696648
"celltoolbar": "Tags",
697649
"kernelspec": {
698-
"display_name": "Python 3",
650+
"display_name": "conda_python3",
699651
"language": "python",
700-
"name": "python3"
701-
},
702-
"language_info": {
703-
"codemirror_mode": {
704-
"name": "ipython",
705-
"version": 3
706-
},
707-
"file_extension": ".py",
708-
"mimetype": "text/x-python",
709-
"name": "python",
710-
"nbconvert_exporter": "python",
711-
"pygments_lexer": "ipython3",
712-
"version": "3.7.3"
652+
"name": "conda_python3"
713653
},
714654
"notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
715655
},

0 commit comments

Comments
 (0)