68 | 68 | "metadata": {},
69 | 69 | "outputs": [],
70 | 70 | "source": [
71 |    | - "import os\n",
72 |    | - "import numpy as np\n",
73 |    | - "import tensorflow as tf\n",
74 |    | - "\n",
75 |    | - "INPUT_TENSOR_NAME = 'x'\n",
76 |    | - "\n",
77 |    | - "\n",
78 |    | - "def estimator_fn(run_config, params):\n",
79 |    | - "    feature_columns = [tf.feature_column.numeric_column(INPUT_TENSOR_NAME, shape=[4])]\n",
80 |    | - "    return tf.estimator.DNNClassifier(feature_columns=feature_columns,\n",
81 |    | - "                                      hidden_units=[10, 20, 10],\n",
82 |    | - "                                      n_classes=3,\n",
83 |    | - "                                      config=run_config)\n",
84 |    | - "\n",
85 |    | - "\n",
86 |    | - "def serving_input_fn():\n",
87 |    | - "    feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=[4])}\n",
88 |    | - "    return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()\n",
89 |    | - "\n",
90 |    | - "\n",
91 |    | - "def train_input_fn(training_dir, params):\n",
92 |    | - "    \"\"\"Returns input function that would feed the model during training\"\"\"\n",
93 |    | - "    return _generate_input_fn(training_dir, 'iris_training.csv')\n",
94 |    | - "\n",
95 |    | - "\n",
96 |    | - "def eval_input_fn(training_dir, params):\n",
97 |    | - "    \"\"\"Returns input function that would feed the model during evaluation\"\"\"\n",
98 |    | - "    return _generate_input_fn(training_dir, 'iris_test.csv')\n",
99 |    | - "\n",
100 |   | - "\n",
101 |   | - "def _generate_input_fn(training_dir, training_filename):\n",
102 |   | - "    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(\n",
103 |   | - "        filename=os.path.join(training_dir, training_filename),\n",
104 |   | - "        target_dtype=np.int,\n",
105 |   | - "        features_dtype=np.float32)\n",
106 |   | - "\n",
107 |   | - "    return tf.estimator.inputs.numpy_input_fn(\n",
108 |   | - "        x={INPUT_TENSOR_NAME: np.array(training_set.data)},\n",
109 |   | - "        y=np.array(training_set.target),\n",
110 |   | - "        num_epochs=None,\n",
111 |   | - "        shuffle=True)"
    | 71 | + "!cat iris_dnn_classifier.py"
112 | 72 | ]
113 | 73 | },
114 | 74 | {
|
124 | 84 | "metadata": {},
125 | 85 | "outputs": [],
126 | 86 | "source": [
    | 87 | + "from iris_dnn_classifier import estimator_fn\n",
127 | 88 | "classifier = estimator_fn(run_config = None, params = None)"
128 | 89 | ]
129 | 90 | },
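For context, estimator_fn is called here with run_config=None, but it accepts a real tf.estimator.RunConfig in the same way. A minimal sketch, assuming a TF 1.x release whose RunConfig takes a model_dir argument; the directory name is illustrative and not part of the notebook:

    import tensorflow as tf
    from iris_dnn_classifier import estimator_fn

    # Assumption: write checkpoints to a fixed directory instead of a temp dir.
    run_config = tf.estimator.RunConfig(model_dir='iris_model')
    classifier = estimator_fn(run_config=run_config, params=None)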
|
|
180 | 141 | },
181 | 142 | "outputs": [],
182 | 143 | "source": [
    | 144 | + "from iris_dnn_classifier import train_input_fn\n",
183 | 145 | "train_func = train_input_fn('.', params = None)"
184 | 146 | ]
185 | 147 | },
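With the classifier returned by estimator_fn and the train_func built above, local training follows the standard TF 1.x Estimator pattern. A minimal sketch, assuming iris_training.csv is present in the working directory; the step count is arbitrary:

    # Feed the numpy_input_fn returned by train_input_fn('.', params=None)
    # into the DNNClassifier created from estimator_fn.
    classifier.train(input_fn=train_func, steps=1000)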
|
|
220 | 182 | "metadata": {},
221 | 183 | "outputs": [],
222 | 184 | "source": [
    | 185 | + "from iris_dnn_classifier import serving_input_fn\n",
    | 186 | + "\n",
    | 187 | + "\"\"\"from tensorflow.contrib.learn import make_export_strategy\n",
    | 188 | + "export_strategy = make_export_strategy( serving_input_fn = serving_input_fn,\n",
    | 189 | + "                                        exports_to_keep=1)\n",
    | 190 | + "exported_model = export_strategy.export(estimator = classifier, export_path = 'export/Servo/')\"\"\"\n",
223 | 191 | "exported_model = classifier.export_savedmodel(export_dir_base = 'export/Servo/', \n",
224 | 192 | "                                              serving_input_receiver_fn = serving_input_fn)\n",
    | 193 | + "\n",
225 | 194 | "print (exported_model)\n",
226 | 195 | "import tarfile\n",
227 | 196 | "with tarfile.open('model.tar.gz', mode='w:gz') as archive:\n",
|
341 | 310 | {
342 | 311 | "cell_type": "code",
343 | 312 | "execution_count": null,
344 |     | - "metadata": {
345 |     | - "collapsed": true
346 |     | - },
    | 313 | + "metadata": {},
347 | 314 | "outputs": [],
348 | 315 | "source": [
349 |     | - "# sagemaker.Session().delete_endpoint(predictor.endpoint)"
    | 316 | + "sagemaker.Session().delete_endpoint(predictor.endpoint)"
350 | 317 | ]
351 | 318 | }
352 | 319 | ],
353 | 320 | "metadata": {
354 |     | - "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.",
355 | 321 | "kernelspec": {
356 |     | - "display_name": "Environment (conda_tensorflow_p36)",
    | 322 | + "display_name": "conda_tensorflow_p27",
357 | 323 | "language": "python",
358 |     | - "name": "conda_tensorflow_p36"
    | 324 | + "name": "conda_tensorflow_p27"
359 | 325 | },
360 | 326 | "language_info": {
361 | 327 | "codemirror_mode": {
362 | 328 | "name": "ipython",
363 |     | - "version": 3
    | 329 | + "version": 2
364 | 330 | },
365 | 331 | "file_extension": ".py",
366 | 332 | "mimetype": "text/x-python",
367 | 333 | "name": "python",
368 | 334 | "nbconvert_exporter": "python",
369 |     | - "pygments_lexer": "ipython3",
370 |     | - "version": "3.6.3"
371 |     | - }
    | 335 | + "pygments_lexer": "ipython2",
    | 336 | + "version": "2.7.11"
    | 337 | + },
    | 338 | + "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
372 | 339 | },
373 | 340 | "nbformat": 4,
374 | 341 | "nbformat_minor": 2
|
|