Skip to content

Commit a8f75d7

Browse files
author
Ragav Venkatesan
authored
Some change in notes and writing.
Importing the methods now instead of directly on the notebook.
1 parent 2bacd29 commit a8f75d7

File tree

1 file changed

+19
-52
lines changed

1 file changed

+19
-52
lines changed

under_development/tensorflow_iris_byom/tensorflow_BYOM_iris.ipynb

Lines changed: 19 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -68,47 +68,7 @@
6868
"metadata": {},
6969
"outputs": [],
7070
"source": [
71-
"import os\n",
72-
"import numpy as np\n",
73-
"import tensorflow as tf\n",
74-
"\n",
75-
"INPUT_TENSOR_NAME = 'x'\n",
76-
"\n",
77-
"\n",
78-
"def estimator_fn(run_config, params):\n",
79-
" feature_columns = [tf.feature_column.numeric_column(INPUT_TENSOR_NAME, shape=[4])]\n",
80-
" return tf.estimator.DNNClassifier(feature_columns=feature_columns,\n",
81-
" hidden_units=[10, 20, 10],\n",
82-
" n_classes=3,\n",
83-
" config=run_config)\n",
84-
"\n",
85-
"\n",
86-
"def serving_input_fn():\n",
87-
" feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=[4])}\n",
88-
" return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()\n",
89-
"\n",
90-
"\n",
91-
"def train_input_fn(training_dir, params):\n",
92-
" \"\"\"Returns input function that would feed the model during training\"\"\"\n",
93-
" return _generate_input_fn(training_dir, 'iris_training.csv')\n",
94-
"\n",
95-
"\n",
96-
"def eval_input_fn(training_dir, params):\n",
97-
" \"\"\"Returns input function that would feed the model during evaluation\"\"\"\n",
98-
" return _generate_input_fn(training_dir, 'iris_test.csv')\n",
99-
"\n",
100-
"\n",
101-
"def _generate_input_fn(training_dir, training_filename):\n",
102-
" training_set = tf.contrib.learn.datasets.base.load_csv_with_header(\n",
103-
" filename=os.path.join(training_dir, training_filename),\n",
104-
" target_dtype=np.int,\n",
105-
" features_dtype=np.float32)\n",
106-
"\n",
107-
" return tf.estimator.inputs.numpy_input_fn(\n",
108-
" x={INPUT_TENSOR_NAME: np.array(training_set.data)},\n",
109-
" y=np.array(training_set.target),\n",
110-
" num_epochs=None,\n",
111-
" shuffle=True)"
71+
"!cat iris_dnn_classifier.py"
11272
]
11373
},
11474
{
@@ -124,6 +84,7 @@
12484
"metadata": {},
12585
"outputs": [],
12686
"source": [
87+
"from iris_dnn_classifier import estimator_fn\n",
12788
"classifier = estimator_fn(run_config = None, params = None)"
12889
]
12990
},
@@ -180,6 +141,7 @@
180141
},
181142
"outputs": [],
182143
"source": [
144+
"from iris_dnn_classifier import train_input_fn\n",
183145
"train_func = train_input_fn('.', params = None)"
184146
]
185147
},
@@ -220,8 +182,15 @@
220182
"metadata": {},
221183
"outputs": [],
222184
"source": [
185+
"from iris_dnn_classifier import serving_input_fn\n",
186+
"\n",
187+
"\"\"\"from tensorflow.contrib.learn import make_export_strategy \n",
188+
"export_strategy = make_export_strategy( serving_input_fn = serving_input_fn,\n",
189+
" exports_to_keep=1)\n",
190+
"exported_model = export_strategy.export(estimator = classifier, export_path = 'export/Servo/')\"\"\"\n",
223191
"exported_model = classifier.export_savedmodel(export_dir_base = 'export/Servo/', \n",
224192
" serving_input_receiver_fn = serving_input_fn)\n",
193+
"\n",
225194
"print (exported_model)\n",
226195
"import tarfile\n",
227196
"with tarfile.open('model.tar.gz', mode='w:gz') as archive:\n",
@@ -341,34 +310,32 @@
341310
{
342311
"cell_type": "code",
343312
"execution_count": null,
344-
"metadata": {
345-
"collapsed": true
346-
},
313+
"metadata": {},
347314
"outputs": [],
348315
"source": [
349-
"# sagemaker.Session().delete_endpoint(predictor.endpoint)"
316+
"sagemaker.Session().delete_endpoint(predictor.endpoint)"
350317
]
351318
}
352319
],
353320
"metadata": {
354-
"notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.",
355321
"kernelspec": {
356-
"display_name": "Environment (conda_tensorflow_p36)",
322+
"display_name": "conda_tensorflow_p27",
357323
"language": "python",
358-
"name": "conda_tensorflow_p36"
324+
"name": "conda_tensorflow_p27"
359325
},
360326
"language_info": {
361327
"codemirror_mode": {
362328
"name": "ipython",
363-
"version": 3
329+
"version": 2
364330
},
365331
"file_extension": ".py",
366332
"mimetype": "text/x-python",
367333
"name": "python",
368334
"nbconvert_exporter": "python",
369-
"pygments_lexer": "ipython3",
370-
"version": "3.6.3"
371-
}
335+
"pygments_lexer": "ipython2",
336+
"version": "2.7.11"
337+
},
338+
"notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
372339
},
373340
"nbformat": 4,
374341
"nbformat_minor": 2

0 commit comments

Comments
 (0)