
Commit 5feab2c

Updated xgboost_customer_churn.ipynb for SageMaker SDK v2 (#1803)
* Updated xgboost_customer_churn.ipynb for SageMaker SDK v2

  The notebook needed updates to be compatible with v2 of the SageMaker SDK, which is the new default SDK version in SageMaker.

* Fixed xgboost_customer_churn.ipynb after code review

  1. Removed the old SDK v1 import.
  2. Added imports for TrainingInput and CSVSerializer to keep the code cleaner.
  3. Removed the check_neo_region call: it is not supported in SDK v2, and as of Sep 26, 2019, Amazon SageMaker Neo is available in 12 additional Regions.
  4. Removed get_image_uri from the Neo compilation step, since it is no longer needed.
  5. Moved the print of cutoffs and costs to a different cell, since it used to throw an error on the first run.
1 parent 8659689 commit 5feab2c
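
Taken together, the v1 → v2 changes in the diff below follow a single migration pattern. A minimal sketch of that pattern (the `bucket` and `prefix` values here are placeholders, not the notebook's exact setup):

```python
import boto3
import sagemaker
from sagemaker.inputs import TrainingInput       # v2 replacement for sagemaker.s3_input
from sagemaker.serializers import CSVSerializer  # v2 replacement for sagemaker.predictor.csv_serializer

# Placeholders standing in for the notebook's own setup cell.
bucket = sagemaker.Session().default_bucket()
prefix = 'sagemaker/xgboost-churn'

# v2 image lookup: sagemaker.image_uris.retrieve replaces get_image_uri.
container = sagemaker.image_uris.retrieve('xgboost', boto3.Session().region_name, '1')

# v2 input channels: TrainingInput replaces sagemaker.s3_input.
s3_input_train = TrainingInput(
    s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')
s3_input_validation = TrainingInput(
    s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')
```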

File tree

1 file changed: +62 -104 lines changed

introduction_to_applying_machine_learning/xgboost_customer_churn/xgboost_customer_churn.ipynb

Lines changed: 62 additions & 104 deletions
@@ -49,7 +49,6 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "isConfigCell": true,
     "tags": [
      "parameters"
@@ -80,9 +79,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "import pandas as pd\n",
@@ -95,8 +92,8 @@
     "import json\n",
     "from IPython.display import display\n",
     "from time import strftime, gmtime\n",
-    "import sagemaker\n",
-    "from sagemaker.predictor import csv_serializer"
+    "from sagemaker.inputs import TrainingInput\n",
+    "from sagemaker.serializers import CSVSerializer"
    ]
   },
   {
@@ -114,9 +111,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "!wget http://dataminingconsultant.com/DKD2e_data_sets.zip\n",
@@ -126,9 +121,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "churn = pd.read_csv('./Data sets/churn.txt')\n",
@@ -166,9 +159,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "# Frequency tables for each categorical feature\n",
@@ -195,9 +186,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "churn = churn.drop('Phone', axis=1)\n",
@@ -214,9 +203,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "for column in churn.select_dtypes(include=['object']).columns:\n",
@@ -246,9 +233,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "display(churn.corr())\n",
@@ -266,9 +251,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "churn = churn.drop(['Day Charge', 'Eve Charge', 'Night Charge', 'Intl Charge'], axis=1)"
@@ -290,9 +273,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "model_data = pd.get_dummies(churn)\n",
@@ -309,9 +290,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "train_data, validation_data, test_data = np.split(model_data.sample(frac=1, random_state=1729), [int(0.7 * len(model_data)), int(0.9 * len(model_data))])\n",
@@ -329,9 +308,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'train/train.csv')).upload_file('train.csv')\n",
@@ -351,32 +328,28 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
-    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
-    "container = get_image_uri(boto3.Session().region_name, 'xgboost')"
+    "container = sagemaker.image_uris.retrieve('xgboost', boto3.Session().region_name, '1')\n",
+    "display(container)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Then, because we're training with the CSV file format, we'll create `s3_input`s that our training function can use as a pointer to the files in S3."
+    "Then, because we're training with the CSV file format, we'll create `TrainingInput`s that our training function can use as a pointer to the files in S3."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
-    "s3_input_train = sagemaker.s3_input(s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')\n",
-    "s3_input_validation = sagemaker.s3_input(s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')"
+    "s3_input_train = TrainingInput(s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')\n",
+    "s3_input_validation = TrainingInput(s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')"
    ]
   },
   {
@@ -396,17 +369,15 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "sess = sagemaker.Session()\n",
     "\n",
     "xgb = sagemaker.estimator.Estimator(container,\n",
     "                                    role, \n",
-    "                                    train_instance_count=1, \n",
-    "                                    train_instance_type='ml.m4.xlarge',\n",
+    "                                    instance_count=1, \n",
+    "                                    instance_type='ml.m4.xlarge',\n",
     "                                    output_path='s3://{}/{}/output'.format(bucket, prefix),\n",
     "                                    sagemaker_session=sess)\n",
     "xgb.set_hyperparameters(max_depth=5,\n",
@@ -437,18 +408,14 @@
    "outputs": [],
    "source": [
     "compiled_model = xgb\n",
-    "if xgb.create_model().check_neo_region(boto3.Session().region_name) is False:\n",
-    "    print('Neo is not currently supported in', boto3.Session().region_name)\n",
-    "else:\n",
-    "    output_path = '/'.join(xgb.output_path.split('/')[:-1])\n",
-    "    compiled_model = xgb.compile_model(target_instance_family='ml_m4', \n",
-    "                                       input_shape={'data': [1, 69]},\n",
-    "                                       role=role,\n",
-    "                                       framework='xgboost',\n",
-    "                                       framework_version='0.7',\n",
-    "                                       output_path=output_path)\n",
-    "    compiled_model.name = 'deployed-xgboost-customer-churn'\n",
-    "    compiled_model.image = get_image_uri(sess.boto_region_name, 'xgboost-neo', repo_version='latest')"
+    "output_path = '/'.join(xgb.output_path.split('/')[:-1])\n",
+    "compiled_model = xgb.compile_model(target_instance_family='ml_m4', \n",
+    "                                   input_shape={'data': [1, 69]},\n",
+    "                                   role=role,\n",
+    "                                   framework='xgboost',\n",
+    "                                   framework_version='latest',\n",
+    "                                   output_path=output_path)\n",
+    "compiled_model.name = 'deployed-xgboost-customer-churn'"
    ]
   },
   {
@@ -464,12 +431,13 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
-    "xgb_predictor = compiled_model.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')"
+    "xgb_predictor = compiled_model.deploy(\n",
+    "    initial_instance_count = 1, \n",
+    "    instance_type = 'ml.m4.xlarge',\n",
+    "    serializer=CSVSerializer())"
    ]
   },
   {
@@ -481,19 +449,6 @@
     "Now that we have a hosted endpoint running, we can make real-time predictions from our model very easily, simply by making an http POST request. But first, we'll need to setup serializers and deserializers for passing our `test_data` NumPy arrays to the model behind the endpoint."
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
-   "outputs": [],
-   "source": [
-    "xgb_predictor.content_type = 'text/csv'\n",
-    "xgb_predictor.serializer = csv_serializer\n",
-    "xgb_predictor.deserializer = None"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -509,9 +464,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "def predict(data, rows=500):\n",
@@ -535,9 +488,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "pd.crosstab(index=test_data.iloc[:, 0], columns=np.round(predictions), rownames=['actual'], colnames=['predictions'])"
@@ -559,9 +510,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "plt.hist(predictions)\n",
@@ -578,9 +527,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "pd.crosstab(index=test_data.iloc[:, 0], columns=np.where(predictions > 0.3, 1, 0))"
@@ -629,9 +576,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "cutoffs = np.arange(0.01, 1, 0.01)\n",
@@ -643,7 +588,15 @@
     "\n",
     "costs = np.array(costs)\n",
     "plt.plot(cutoffs, costs)\n",
-    "plt.show()\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "print('Cost is minimized near a cutoff of:', cutoffs[np.argmin(costs)], 'for a cost of:', np.min(costs))"
    ]
   },
@@ -683,21 +636,26 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
-    "sagemaker.Session().delete_endpoint(xgb_predictor.endpoint)"
+    "xgb_predictor.delete_endpoint()"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
   "celltoolbar": "Tags",
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "conda_python3",
    "language": "python",
-   "name": "python3"
+   "name": "conda_python3"
   },
   "language_info": {
    "codemirror_mode": {
@@ -709,7 +667,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.3"
+   "version": "3.6.10"
   },
   "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
  },
