Commit 1d8f732

add sagemaker spark kmeans mnist notebook

1 parent 0636049 commit 1d8f732

1 file changed: +260 -0

@@ -0,0 +1,260 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\").\n",
"You may not use this file except in compliance with the License.\n",
"A copy of the License is located at\n",
" \n",
" http://aws.amazon.com/apache2.0/\n",
"\n",
"or in the \"license\" file accompanying this file. This file is distributed\n",
"on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either\n",
"express or implied. See the License for the specific language governing\n",
"permissions and limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SageMaker PySpark MNIST Example\n",
"\n",
"1. [Introduction](#Introduction)\n",
"2. [Data Inspection](#Data-Inspection)\n",
"3. [Training the K-Means Model](#Training-the-K-Means-Model)\n",
"4. [Validate the Model for use](#Validate-the-Model-for-use)\n",
"5. [Bring your Own Algorithm](#Bring-your-Own-Algorithm)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Introduction\n",
"This notebook shows how to cluster handwritten digits with the K-Means algorithm using the SageMaker PySpark SDK.\n",
"\n",
"You can visit the SageMaker Spark GitHub repository at https://github.com/aws/sagemaker-spark for more information about SageMaker Spark.\n",
"\n",
"We will train a K-Means model on Amazon SageMaker using the MNIST dataset, host the trained model on Amazon SageMaker, and then make predictions against that hosted model.\n",
"\n",
"First, we load the MNIST dataset into a Spark DataFrame. The dataset is available in LibSVM format at\n",
"\n",
"s3://sagemaker-sample-data-[region, such as us-east-1]/spark/mnist/train/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyspark import SparkContext, SparkConf\n",
"from pyspark.sql import SparkSession\n",
"import os\n",
"import sagemaker_pyspark\n",
"import sagemaker\n",
"from sagemaker import get_execution_role\n",
"\n",
"sagemaker_session = sagemaker.Session()\n",
"\n",
"role = get_execution_role()\n",
"\n",
"# Configure Spark to use the SageMaker Spark dependency jars\n",
"jars = sagemaker_pyspark.classpath_jars()\n",
"\n",
"classpath = \":\".join(jars)\n",
"\n",
"# See the SageMaker Spark Github repo under sagemaker-pyspark-sdk\n",
"# to learn how to connect to a remote EMR cluster running Spark from a Notebook Instance.\n",
"spark = SparkSession.builder.config(\"spark.driver.extraClassPath\", classpath)\\\n",
"    .master(\"local[*]\").getOrCreate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# replace this with your own region, such as us-east-1\n",
"region = 'us-east-1'\n",
"trainingData = spark.read.format('libsvm')\\\n",
"    .option('numFeatures', '784')\\\n",
"    .load('s3a://sagemaker-sample-data-{}/spark/mnist/train/'.format(region))\n",
"\n",
"testData = spark.read.format('libsvm')\\\n",
"    .option('numFeatures', '784')\\\n",
"    .load('s3a://sagemaker-sample-data-{}/spark/mnist/test/'.format(region))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Inspection\n",
"In order to train and make inferences, our input DataFrame must have a column of Doubles (named \"label\" by default) and a column of Vectors of Doubles (named \"features\" by default).\n",
"\n",
"Spark's LibSVM DataFrameReader loads a DataFrame already suitable for training and inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainingData.show()"
]
},
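{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional check (added sketch, not part of the original walkthrough): printSchema() confirms the\n",
"# DataFrame has the \"label\" (double) and \"features\" (vector) columns described above.\n",
"trainingData.printSchema()"
]
},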
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training the K-Means Model\n",
"Now we create a KMeansSageMakerEstimator, which uses the Amazon SageMaker KMeans algorithm to train on our input data and uses the KMeans Amazon SageMaker model image to host our model.\n",
"\n",
"Calling fit() on this estimator will train our model on Amazon SageMaker, and then create an Amazon SageMaker Endpoint to host our model.\n",
"\n",
"We can then use the SageMakerModel returned by this call to fit() to transform DataFrames using our hosted model.\n",
"\n",
"The following cell runs a training job and creates an endpoint to host the resulting model, so it can take up to twenty minutes to complete."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"from sagemaker_pyspark import IAMRole, S3DataPath\n",
"from sagemaker_pyspark.algorithms import KMeansSageMakerEstimator\n",
"\n",
"# the IAM role obtained above from get_execution_role() is used for both training and hosting\n",
"kmeans_estimator = KMeansSageMakerEstimator(\n",
"    sagemakerRole=IAMRole(role),\n",
"    trainingInstanceType='ml.p2.xlarge',\n",
"    trainingInstanceCount=1,\n",
"    endpointInstanceType='ml.c4.xlarge',\n",
"    endpointInitialInstanceCount=1)\n",
"\n",
"kmeans_estimator.setK(10)\n",
"kmeans_estimator.setFeatureDim(784)\n",
"\n",
"# train\n",
"model = kmeans_estimator.fit(trainingData)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Validate the Model for use\n",
"Now we transform our DataFrame.\n",
"To do this, we serialize each row's \"features\" Vector of Doubles into a Protobuf format for inference against the Amazon SageMaker Endpoint. We deserialize the Protobuf responses back into our DataFrame:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"transformedData = model.transform(testData)\n",
"\n",
"transformedData.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql.types import DoubleType\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# helper function to display a digit\n",
"def show_digit(img, caption='', xlabel='', subplot=None):\n",
"    if subplot is None:\n",
"        _, (subplot) = plt.subplots(1, 1)\n",
"    imgr = img.reshape((28, 28))\n",
"    subplot.axes.get_xaxis().set_ticks([])\n",
"    subplot.axes.get_yaxis().set_ticks([])\n",
"    plt.title(caption)\n",
"    plt.xlabel(xlabel)\n",
"    subplot.imshow(imgr, cmap='gray')\n",
"\n",
"images = np.array(transformedData.select(\"features\").cache().take(250))\n",
"clusters = transformedData.select(\"closest_cluster\").cache().take(250)\n",
"\n",
"for cluster in range(10):\n",
"    print('\\n\\n\\nCluster {}:'.format(int(cluster)))\n",
"    digits = [img for l, img in zip(clusters, images) if int(l.closest_cluster) == cluster]\n",
"    height = ((len(digits) - 1) // 5) + 1\n",
"    width = 5\n",
"    plt.rcParams[\"figure.figsize\"] = (width, height)\n",
"    _, subplots = plt.subplots(height, width)\n",
"    subplots = np.ndarray.flatten(subplots)\n",
"    for subplot, image in zip(subplots, digits):\n",
"        show_digit(image, subplot=subplot)\n",
"    for subplot in subplots[len(digits):]:\n",
"        subplot.axis('off')\n",
"\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Delete the endpoint and the other SageMaker resources created by fit()\n",
"\n",
"from sagemaker_pyspark import SageMakerResourceCleanup\n",
"\n",
"resource_cleanup = SageMakerResourceCleanup(model.sagemakerClient)\n",
"resource_cleanup.deleteResources(model.getCreatedResources())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bring your Own Algorithm\n",
"\n",
"The SageMaker Spark GitHub repository has more information about SageMaker Spark, including how to use SageMaker Spark with your own algorithms on Amazon SageMaker: https://github.com/aws/sagemaker-spark\n"
]
},
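{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is an illustrative sketch only (added for context, not part of the original example): the generic SageMakerEstimator can train and host a custom algorithm container much like the KMeansSageMakerEstimator above. The ECR image URI is a placeholder, and the serializer, deserializer, and hyperparameters shown are assumptions borrowed from the built-in KMeans setup; adjust them to match what your container actually accepts and returns, and see the repository for the exact API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: training and hosting your own algorithm container with the generic\n",
"# SageMakerEstimator. The image URI below is a placeholder, and the protobuf serializer and\n",
"# deserializer assume a container that speaks the same recordIO-protobuf format as the\n",
"# built-in KMeans algorithm.\n",
"from sagemaker_pyspark import SageMakerEstimator, IAMRole\n",
"from sagemaker_pyspark.transformation.serializers import ProtobufRequestRowSerializer\n",
"from sagemaker_pyspark.transformation.deserializers import KMeansProtobufResponseRowDeserializer\n",
"\n",
"custom_estimator = SageMakerEstimator(\n",
"    trainingImage='123456789012.dkr.ecr.us-east-1.amazonaws.com/my-algorithm:latest',  # placeholder\n",
"    modelImage='123456789012.dkr.ecr.us-east-1.amazonaws.com/my-algorithm:latest',  # placeholder\n",
"    requestRowSerializer=ProtobufRequestRowSerializer(),\n",
"    responseRowDeserializer=KMeansProtobufResponseRowDeserializer(),\n",
"    hyperParameters={'k': '10', 'feature_dim': '784'},\n",
"    sagemakerRole=IAMRole(role),\n",
"    trainingInstanceType='ml.p2.xlarge',\n",
"    trainingInstanceCount=1,\n",
"    endpointInstanceType='ml.c4.xlarge',\n",
"    endpointInitialInstanceCount=1,\n",
"    trainingSparkDataFormat='sagemaker')\n",
"\n",
"# custom_model = custom_estimator.fit(trainingData)"
]
}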
],
"metadata": {
"kernelspec": {
"display_name": "conda_python3",
"language": "python",
"name": "conda_python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
