
Commit 767e58a

Merge pull request aws#23 from awslabs/tf_byom
Tensorflow Bring Your Own Model. Merging in preparation for tomorrow's meeting.
2 parents 826f254 + c16d216 commit 767e58a

File tree

5 files changed: +626 −0 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -18,3 +18,4 @@ This repository contains example notebooks that show how to apply machine learni
 - [Installing the R Kernel](install_r_kernel) shows how to install the R kernel into an Amazon SageMaker Notebook Instance.
 - [Bring Your Own Model for k-means](kmeans_bring_your_own_model) shows how to take a model that's been fit elsewhere and use Amazon SageMaker containers to host it.
 - [Bring Your Own Algorithm with R](r_bring_your_own) shows how to bring your own algorithm container to Amazon SageMaker using the R language.
+- [Bring Your Own TensorFlow Model](sagemaker-python-sdk/tensorflow_iris_byom) shows how to bring a model trained anywhere into Amazon SageMaker.
mnist.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
import logging
import os

import mxnet as mx


def find_file(root_path, file_name):
    """Walk root_path and return the full path of the first file matching file_name."""
    for root, dirs, files in os.walk(root_path):
        if file_name in files:
            return os.path.join(root, file_name)


def build_graph():
    """Define a simple multi-layer perceptron for 10-class digit classification."""
    data = mx.sym.var('data')
    data = mx.sym.flatten(data=data)
    fc1 = mx.sym.FullyConnected(data=data, num_hidden=128)
    act1 = mx.sym.Activation(data=fc1, act_type="relu")
    fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64)
    act2 = mx.sym.Activation(data=fc2, act_type="relu")
    fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
    return mx.sym.SoftmaxOutput(data=fc3, name='softmax')


def train(data, hyperparameters=None, num_cpus=0, num_gpus=1, **kwargs):
    """Fit the MLP on the MNIST arrays in `data` and return the trained module."""
    if hyperparameters is None:
        hyperparameters = {'learning_rate': 0.1}
    train_labels = data['train_label']
    train_images = data['train_data']
    test_labels = data['test_label']
    test_images = data['test_data']
    batch_size = 100
    train_iter = mx.io.NDArrayIter(train_images, train_labels, batch_size, shuffle=True)
    val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)
    logging.getLogger().setLevel(logging.DEBUG)
    mlp_model = mx.mod.Module(
        symbol=build_graph(),
        context=get_train_context(num_cpus, num_gpus))
    mlp_model.fit(train_iter,
                  eval_data=val_iter,
                  optimizer='sgd',
                  optimizer_params={'learning_rate': float(hyperparameters.get("learning_rate", 0.1))},
                  eval_metric='acc',
                  batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                  num_epoch=10)
    return mlp_model


def get_train_context(num_cpus, num_gpus):
    """Train on GPU if any are requested, otherwise on CPU."""
    if num_gpus > 0:
        return mx.gpu()
    return mx.cpu()
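
For orientation, here is a minimal sketch of how this entry point can be exercised locally, mirroring what the notebook below does; the ``num_gpus=0`` override and the ``score`` call are illustrative additions, not part of the commit:

import mxnet as mx
from mnist import train

# MXNet's built-in loader returns a dict with 'train_data', 'train_label',
# 'test_data', and 'test_label' keys, which is the shape train() expects.
data = mx.test_utils.get_mnist()

# Force CPU training on an instance without GPUs.
model = train(data=data, num_gpus=0)

# Evaluate the fitted module on the held-out set.
val_iter = mx.io.NDArrayIter(data['test_data'], data['test_label'], batch_size=100)
print(model.score(val_iter, 'acc'))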
Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MXNet MNIST BYOM: Train locally and deploy on Amazon SageMaker."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we train a model locally on the notebook instance, then deploy it and predict from Amazon SageMaker. This extends easily to a model trained anywhere else: all that is needed is the exported model file and the entry-point file containing the model definitions.\n",
    "\n",
    "First, download the MNIST data using the MXNet utilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "data = mx.test_utils.get_mnist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Train a typical MXNet model, using the ``train`` function from ``mnist.py`` (a simple multi-layer perceptron)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from mnist import train\n",
    "model = train(data=data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Export the model and save it to disk. As in the TensorFlow example, a particular directory structure must be followed, which the following code sets up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.mkdir('model')\n",
    "model.save_checkpoint('model/model', 0)\n",
    "import tarfile\n",
    "with tarfile.open('model.tar.gz', mode='w:gz') as archive:\n",
    "    archive.add('model', recursive=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Open a SageMaker session and upload the model to the default S3 bucket."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "\n",
    "sagemaker_session = sagemaker.Session()\n",
    "inputs = sagemaker_session.upload_data(path='model.tar.gz', key_prefix='model')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use ``sagemaker.mxnet.model.MXNetModel`` to create a new model that can be deployed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sagemaker.mxnet.model import MXNetModel\n",
    "sagemaker_model = MXNetModel(model_data='s3://' + sagemaker_session.default_bucket() + '/model/model.tar.gz',\n",
    "                             role='<<set role here>>',\n",
    "                             entry_point='mnist.py')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deploy the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "predictor = sagemaker_model.deploy(initial_instance_count=1,\n",
    "                                   instance_type='ml.c4.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now use this predictor to classify handwritten digits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "predict_sample = data['test_data'][0][0]\n",
    "response = predictor.predict(predict_sample)\n",
    "print('Raw prediction result:')\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "(Optional) Delete the endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print(predictor.endpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "\n",
    "sagemaker.Session().delete_endpoint(predictor.endpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "os.remove('model.tar.gz')\n",
    "import shutil\n",
    "shutil.rmtree('model')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
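
One detail worth spelling out from the export cell above: ``save_checkpoint('model/model', 0)`` writes the graph to ``model/model-symbol.json`` and the weights to ``model/model-0000.params``, and it is this layout, tarred from the top level, that ``MXNetModel`` consumes. A quick sanity check (illustrative, not part of the commit):

import tarfile

# Inspect the archive built by the notebook; it should contain the 'model'
# directory plus model-symbol.json (the graph) and model-0000.params (the weights).
with tarfile.open('model.tar.gz') as archive:
    print(archive.getnames())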
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
import os

import numpy as np
import tensorflow as tf

INPUT_TENSOR_NAME = 'x'


def estimator_fn(run_config, params):
    """Return a three-layer DNN classifier over the four iris features."""
    feature_columns = [tf.feature_column.numeric_column(INPUT_TENSOR_NAME, shape=[4])]
    return tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                      hidden_units=[10, 20, 10],
                                      n_classes=3,
                                      config=run_config)


def serving_input_fn():
    """Define the features the model server expects at inference time."""
    feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=[4])}
    return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()


def train_input_fn(training_dir, params):
    """Return the input function that feeds the model during training."""
    return _generate_input_fn(training_dir, 'iris_training.csv')


def eval_input_fn(training_dir, params):
    """Return the input function that feeds the model during evaluation."""
    return _generate_input_fn(training_dir, 'iris_test.csv')


def _generate_input_fn(training_dir, training_filename):
    # Load the headered iris CSV into (features, targets) arrays.
    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=os.path.join(training_dir, training_filename),
        target_dtype=np.int,
        features_dtype=np.float32)

    return tf.estimator.inputs.numpy_input_fn(
        x={INPUT_TENSOR_NAME: np.array(training_set.data)},
        y=np.array(training_set.target),
        num_epochs=None,
        shuffle=True)
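
For completeness, a minimal local sketch of how these functions fit together, assuming the file is saved as ``iris_dnn_classifier.py`` (the module name is an assumption) and that ``iris_training.csv``/``iris_test.csv`` have been downloaded into ``./data``; all calls are TensorFlow 1.x estimator APIs, matching the script:

import tensorflow as tf
from iris_dnn_classifier import estimator_fn, train_input_fn, serving_input_fn

# Build the DNNClassifier and train it on the local CSV for a fixed step budget.
estimator = estimator_fn(run_config=tf.estimator.RunConfig(), params=None)
estimator.train(input_fn=train_input_fn('data', params=None), steps=1000)

# Export a SavedModel; 'export/Servo' follows the directory layout the
# SageMaker TensorFlow container looks for when hosting a pre-trained model.
estimator.export_savedmodel('export/Servo', serving_input_fn)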
