Skip to content

Commit d003ef0

Browse files
authored
change: updated for sagemaker python sdk 2.x (#1667)
1 parent ffee0c8 commit d003ef0

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

autopilot/autopilot_customer_churn_high_level_with_evaluation.ipynb

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"\n",
2323
"---\n",
2424
"\n",
25-
"This notebook works with sagemaker python sdk >= 1.65.1.\n",
25+
"This notebook works with sagemaker python sdk 2.x\n",
2626
"\n",
2727
"## Contents\n",
2828
"\n",
@@ -794,17 +794,19 @@
794794
"metadata": {},
795795
"outputs": [],
796796
"source": [
797-
"from sagemaker.predictor import RealTimePredictor\n",
798-
"from sagemaker.content_types import CONTENT_TYPE_CSV\n",
797+
"from sagemaker.predictor import Predictor\n",
798+
"from sagemaker.serializers import CSVSerializer\n",
799+
"from sagemaker.deserializers import CSVDeserializer\n",
799800
"\n",
800801
"predictor = automl.deploy(initial_instance_count=1,\n",
801802
" instance_type='ml.m5.2xlarge',\n",
802803
" candidate=candidates[best_candidate_idx],\n",
803804
" inference_response_keys=inference_response_keys,\n",
804-
" predictor_cls=RealTimePredictor)\n",
805-
"predictor.content_type = CONTENT_TYPE_CSV\n",
805+
" predictor_cls=Predictor,\n",
806+
" serializer=CSVSerializer(),\n",
807+
" deserializer=CSVDeserializer())\n",
806808
"\n",
807-
"print(\"Created endpoint: {}\".format(predictor.endpoint))\n"
809+
"print(\"Created endpoint: {}\".format(predictor.endpoint_name))\n"
808810
]
809811
},
810812
{
@@ -829,11 +831,9 @@
829831
"metadata": {},
830832
"outputs": [],
831833
"source": [
832-
"from io import StringIO\n",
833-
"\n",
834-
"prediction = predictor.predict(test_data_no_target.to_csv(sep=',', header=False, index=False)).decode('utf-8')\n",
835-
"prediction_df = pd.read_csv(StringIO(prediction), header=None, names=inference_response_keys)\n",
836-
"custom_predicted_labels = prediction_df.iloc[:,1].values >= best_candidate_threshold\n",
834+
"prediction = predictor.predict(test_data_no_target.to_csv(sep=',', header=False, index=False))\n",
835+
"prediction_df = pd.DataFrame(prediction, columns=inference_response_keys)\n",
836+
"custom_predicted_labels = prediction_df.iloc[:,1].astype(float).values >= best_candidate_threshold\n",
837837
"prediction_df['custom_predicted_label'] = custom_predicted_labels\n",
838838
"prediction_df['custom_predicted_label'] = prediction_df['custom_predicted_label'].map({False: target_attribute_values[0], True: target_attribute_values[1]})\n",
839839
"prediction_df"
@@ -906,4 +906,4 @@
906906
},
907907
"nbformat": 4,
908908
"nbformat_minor": 4
909-
}
909+
}

0 commit comments

Comments
 (0)