Updated xgboost_customer_churn.ipynb for SageMaker SDK v2 #1803

Review comments (via ReviewNB):

- add an import for `sagemaker.inputs.TrainingInput`
- Why is the Neo support code removed? Is Neo available in all AWS regions now? Also, please change the …
- Can you please use a direct reference to …

```diff
@@ -49,7 +49,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"collapsed": true,
 "isConfigCell": true,
 "tags": [
 "parameters"
```

```diff
@@ -80,9 +79,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "import pandas as pd\n",
```

```diff
@@ -95,8 +92,8 @@
 "import json\n",
 "from IPython.display import display\n",
 "from time import strftime, gmtime\n",
-"import sagemaker\n",
-"from sagemaker.predictor import csv_serializer"
+"from sagemaker.inputs import TrainingInput\n",
+"from sagemaker.serializers import CSVSerializer"
 ]
 },
 {
```
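
SDK v2 replaces the module-level `csv_serializer` object with the `CSVSerializer` class, and `sagemaker.s3_input` with `sagemaker.inputs.TrainingInput`. A minimal sketch of the import cell after this change — note it keeps `import sagemaker`, which the diff drops even though later cells still call `sagemaker.Session()` and `sagemaker.image_uris.retrieve()` (the closing review comment also touches on the imports):

```python
import pandas as pd
import numpy as np
import json
from IPython.display import display
from time import strftime, gmtime

import boto3
import sagemaker                                  # still used by sagemaker.Session() below
from sagemaker.inputs import TrainingInput        # SDK v2 replacement for sagemaker.s3_input
from sagemaker.serializers import CSVSerializer   # SDK v2 replacement for csv_serializer
```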

```diff
@@ -114,9 +111,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "!wget http://dataminingconsultant.com/DKD2e_data_sets.zip\n",
```

```diff
@@ -126,9 +121,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "churn = pd.read_csv('./Data sets/churn.txt')\n",
```

```diff
@@ -166,9 +159,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "# Frequency tables for each categorical feature\n",
```

```diff
@@ -195,9 +186,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "churn = churn.drop('Phone', axis=1)\n",
```

```diff
@@ -214,9 +203,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "for column in churn.select_dtypes(include=['object']).columns:\n",
```

```diff
@@ -246,9 +233,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "display(churn.corr())\n",
```

```diff
@@ -266,9 +251,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "churn = churn.drop(['Day Charge', 'Eve Charge', 'Night Charge', 'Intl Charge'], axis=1)"
```

```diff
@@ -290,9 +273,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "model_data = pd.get_dummies(churn)\n",
```

```diff
@@ -309,9 +290,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "train_data, validation_data, test_data = np.split(model_data.sample(frac=1, random_state=1729), [int(0.7 * len(model_data)), int(0.9 * len(model_data))])\n",
```
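
The `np.split` call here cuts the shuffled frame at the 70% and 90% marks, giving a 70/20/10 train/validation/test split; a commented sketch of the same line (`model_data` comes from the `get_dummies` cell above):

```python
import numpy as np

# Shuffle once (fixed seed for reproducibility), then split at the
# 70% and 90% boundaries: 70% train, 20% validation, 10% test.
train_data, validation_data, test_data = np.split(
    model_data.sample(frac=1, random_state=1729),
    [int(0.7 * len(model_data)), int(0.9 * len(model_data))],
)
```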

```diff
@@ -329,9 +308,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'train/train.csv')).upload_file('train.csv')\n",
```

```diff
@@ -351,32 +328,28 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
-"from sagemaker.amazon.amazon_estimator import get_image_uri\n",
-"container = get_image_uri(boto3.Session().region_name, 'xgboost')"
+"container = sagemaker.image_uris.retrieve('xgboost', boto3.Session().region_name, '1')\n",
+"display(container)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Then, because we're training with the CSV file format, we'll create `s3_input`s that our training function can use as a pointer to the files in S3."
+"Then, because we're training with the CSV file format, we'll create `TrainingInput`s that our training function can use as a pointer to the files in S3."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
-"s3_input_train = sagemaker.s3_input(s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')\n",
-"s3_input_validation = sagemaker.s3_input(s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')"
+"s3_input_train = TrainingInput(s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')\n",
+"s3_input_validation = TrainingInput(s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')"
 ]
 },
 {
```
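
For context, a side-by-side sketch of the v1→v2 calls this hunk swaps. The `bucket` and `prefix` values below are hypothetical stand-ins for what the notebook's config cell defines:

```python
import boto3
import sagemaker
from sagemaker.inputs import TrainingInput

region = boto3.Session().region_name
bucket = sagemaker.Session().default_bucket()  # placeholder; set in the config cell
prefix = 'sagemaker/xgboost-churn'             # hypothetical S3 key prefix

# v1 (removed): from sagemaker.amazon.amazon_estimator import get_image_uri
#               container = get_image_uri(region, 'xgboost')
# v2 (added):   image_uris.retrieve(framework, region, version)
container = sagemaker.image_uris.retrieve('xgboost', region, '1')

# v1 sagemaker.s3_input(...) becomes sagemaker.inputs.TrainingInput(...)
# with the same arguments.
s3_input_train = TrainingInput(
    s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')
s3_input_validation = TrainingInput(
    s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')
```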

```diff
@@ -396,17 +369,15 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "sess = sagemaker.Session()\n",
 "\n",
 "xgb = sagemaker.estimator.Estimator(container,\n",
 "                                    role, \n",
-"                                    train_instance_count=1, \n",
-"                                    train_instance_type='ml.m4.xlarge',\n",
+"                                    instance_count=1, \n",
+"                                    instance_type='ml.m4.xlarge',\n",
 "                                    output_path='s3://{}/{}/output'.format(bucket, prefix),\n",
 "                                    sagemaker_session=sess)\n",
 "xgb.set_hyperparameters(max_depth=5,\n",
```
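
The only v2 changes in this cell are parameter renames: `train_instance_count` becomes `instance_count` and `train_instance_type` becomes `instance_type`. A sketch of the updated cell; the hyperparameters beyond `max_depth=5` are elided in the diff, and the `fit()` call is an assumption (it is not shown in this hunk):

```python
sess = sagemaker.Session()

xgb = sagemaker.estimator.Estimator(
    container,
    role,
    instance_count=1,               # was train_instance_count in SDK v1
    instance_type='ml.m4.xlarge',   # was train_instance_type in SDK v1
    output_path='s3://{}/{}/output'.format(bucket, prefix),
    sagemaker_session=sess,
)
xgb.set_hyperparameters(max_depth=5)  # remaining hyperparameters elided in the diff
xgb.fit({'train': s3_input_train, 'validation': s3_input_validation})  # assumed call
```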

```diff
@@ -437,18 +408,14 @@
 "outputs": [],
 "source": [
 "compiled_model = xgb\n",
-"if xgb.create_model().check_neo_region(boto3.Session().region_name) is False:\n",
-"    print('Neo is not currently supported in', boto3.Session().region_name)\n",
-"else:\n",
-"    output_path = '/'.join(xgb.output_path.split('/')[:-1])\n",
-"    compiled_model = xgb.compile_model(target_instance_family='ml_m4', \n",
-"                                       input_shape={'data': [1, 69]},\n",
-"                                       role=role,\n",
-"                                       framework='xgboost',\n",
-"                                       framework_version='0.7',\n",
-"                                       output_path=output_path)\n",
-"    compiled_model.name = 'deployed-xgboost-customer-churn'\n",
-"    compiled_model.image = get_image_uri(sess.boto_region_name, 'xgboost-neo', repo_version='latest')"
+"output_path = '/'.join(xgb.output_path.split('/')[:-1])\n",
+"compiled_model = xgb.compile_model(target_instance_family='ml_m4', \n",
+"                                   input_shape={'data': [1, 69]},\n",
+"                                   role=role,\n",
+"                                   framework='xgboost',\n",
+"                                   framework_version='latest',\n",
+"                                   output_path=output_path)\n",
+"compiled_model.name = 'deployed-xgboost-customer-churn'"
 ]
 },
 {
```
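
This hunk drops the v1 `check_neo_region()` guard entirely (the reviewer asks why) and changes `framework_version` from `'0.7'` to `'latest'`. If region availability is still a concern, one assumed way to keep the notebook from failing outright is a try/except around the compile, sketched below; this is a substitute for the removed guard, not the PR's approach:

```python
compiled_model = xgb
output_path = '/'.join(xgb.output_path.split('/')[:-1])
try:
    compiled_model = xgb.compile_model(
        target_instance_family='ml_m4',
        input_shape={'data': [1, 69]},  # 69 features after one-hot encoding
        role=role,
        framework='xgboost',
        framework_version='latest',
        output_path=output_path,
    )
    compiled_model.name = 'deployed-xgboost-customer-churn'
except Exception as e:
    # Fall back to the uncompiled estimator if Neo is unavailable in
    # this region; the v1 code used check_neo_region() for this purpose.
    print('Neo compilation failed, using uncompiled model:', e)
```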

```diff
@@ -464,12 +431,13 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
-"xgb_predictor = compiled_model.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')"
+"xgb_predictor = compiled_model.deploy(\n",
+"    initial_instance_count = 1, \n",
+"    instance_type = 'ml.m4.xlarge',\n",
+"    serializer=CSVSerializer())"
 ]
 },
 {
```
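
In SDK v2 the serializer is attached at deploy time rather than set on the predictor afterwards, which is why the next change deletes the old serializer cell. The deploy call as added, lightly reformatted:

```python
from sagemaker.serializers import CSVSerializer

xgb_predictor = compiled_model.deploy(
    initial_instance_count=1,
    instance_type='ml.m4.xlarge',
    serializer=CSVSerializer(),  # sends requests as text/csv
)
```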

```diff
@@ -481,19 +449,6 @@
 "Now that we have a hosted endpoint running, we can make real-time predictions from our model very easily, simply by making an http POST request. But first, we'll need to setup serializers and deserializers for passing our `test_data` NumPy arrays to the model behind the endpoint."
 ]
 },
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"collapsed": true
-},
-"outputs": [],
-"source": [
-"xgb_predictor.content_type = 'text/csv'\n",
-"xgb_predictor.serializer = csv_serializer\n",
-"xgb_predictor.deserializer = None"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
```
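
The three deleted v1 assignments are subsumed by the `serializer=CSVSerializer()` argument passed to `deploy()` above. For completeness, a sketch of the v2 equivalent if one needed to change serialization on an already-deployed predictor (assuming the v2 `Predictor` attribute API):

```python
# SDK v1 (deleted cell):
#   xgb_predictor.content_type = 'text/csv'
#   xgb_predictor.serializer = csv_serializer
#   xgb_predictor.deserializer = None
# SDK v2: the serializer object carries its own content type.
xgb_predictor.serializer = CSVSerializer()
```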

```diff
@@ -509,9 +464,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "def predict(data, rows=500):\n",
```
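
Only the `def predict(data, rows=500):` signature is visible in this hunk. A sketch of a batched helper consistent with it — the body below is an assumption, not taken from the diff; in SDK v2 `predict()` returns bytes, hence the `decode`:

```python
import numpy as np

def predict(data, rows=500):
    # Send the feature matrix in chunks of `rows` to stay under the
    # endpoint payload limit, then stitch the CSV responses together.
    split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))
    predictions = ''
    for array in split_array:
        predictions = ','.join([predictions,
                                xgb_predictor.predict(array).decode('utf-8')])
    return np.array(predictions[1:].split(','), dtype=float)
```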

```diff
@@ -535,9 +488,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "pd.crosstab(index=test_data.iloc[:, 0], columns=np.round(predictions), rownames=['actual'], colnames=['predictions'])"
```

```diff
@@ -559,9 +510,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "plt.hist(predictions)\n",
```

```diff
@@ -578,9 +527,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "pd.crosstab(index=test_data.iloc[:, 0], columns=np.where(predictions > 0.3, 1, 0))"
```

```diff
@@ -629,9 +576,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "cutoffs = np.arange(0.01, 1, 0.01)\n",
```

```diff
@@ -643,7 +588,15 @@
 "\n",
 "costs = np.array(costs)\n",
 "plt.plot(cutoffs, costs)\n",
-"plt.show()\n",
+"plt.show()"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
 "print('Cost is minimized near a cutoff of:', cutoffs[np.argmin(costs)], 'for a cost of:', np.min(costs))"
 ]
 },
```
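
The loop body that fills `costs` is elided from the diff; only the plot and the new `argmin` cell are visible. A sketch of the kind of cutoff sweep this computes, using a hypothetical cost matrix — the actual dollar values are not shown in this hunk, and the `crosstab` is assumed to produce a full 2x2 table at each cutoff:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

cutoffs = np.arange(0.01, 1, 0.01)
costs = []
for c in cutoffs:
    # Hypothetical per-cell costs for the 2x2 confusion matrix
    # (rows: actual, columns: predicted); replace with real values.
    cost_matrix = np.array([[0, 100],
                            [500, 100]])
    confusion = pd.crosstab(index=test_data.iloc[:, 0],
                            columns=np.where(predictions > c, 1, 0))
    costs.append(np.sum(np.sum(cost_matrix * confusion)))

costs = np.array(costs)
plt.plot(cutoffs, costs)
plt.show()
print('Cost is minimized near a cutoff of:', cutoffs[np.argmin(costs)],
      'for a cost of:', np.min(costs))
```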

```diff
@@ -683,21 +636,26 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": true
-},
+"metadata": {},
 "outputs": [],
 "source": [
-"sagemaker.Session().delete_endpoint(xgb_predictor.endpoint)"
+"xgb_predictor.delete_endpoint()"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": []
+}
 ],
 "metadata": {
 "celltoolbar": "Tags",
 "kernelspec": {
-"display_name": "Python 3",
+"display_name": "conda_python3",
 "language": "python",
-"name": "python3"
+"name": "conda_python3"
 },
 "language_info": {
 "codemirror_mode": {
```
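
Endpoint cleanup also moves in v2: instead of going through the session with the endpoint name, the predictor deletes its own endpoint. The before/after as a short sketch:

```python
# SDK v1 (removed): delete via the session and endpoint name
#   sagemaker.Session().delete_endpoint(xgb_predictor.endpoint)
# SDK v2 (added): the Predictor manages its own endpoint
xgb_predictor.delete_endpoint()
```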

```diff
@@ -709,7 +667,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.3"
+"version": "3.6.10"
 },
 "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
```

Review comment (via ReviewNB):

- Please remove the old import of SDK v1. Also, if you can, please move all the imports into the beginning cells.