Skip to content

Commit 0fbb6d1

Browse files
committed
Updated example to support latest TF framework
1 parent 8211a60 commit 0fbb6d1

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

sagemaker-python-sdk/tensorflow_resnet_cifar10_with_tensorboard/source_dir/resnet_cifar_10.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,13 @@ def _input_from_files(mode, batch_size, data_dir):
133133
if mode == tf.estimator.ModeKeys.TRAIN:
134134
dataset = dataset.repeat()
135135

136-
dataset = dataset.map(_dataset_parser, num_threads=1,
137-
output_buffer_size=2 * batch_size)
136+
dataset = dataset.map(_dataset_parser)
137+
dataset.prefetch(2 * batch_size)
138138

139139
# For training, preprocess the image and shuffle.
140140
if mode == tf.estimator.ModeKeys.TRAIN:
141-
dataset = dataset.map(_train_preprocess_fn, num_threads=1,
142-
output_buffer_size=2 * batch_size)
141+
dataset = dataset.map(_train_preprocess_fn)
142+
dataset.prefetch(2 * batch_size)
143143

144144
# Ensure that the capacity is sufficiently large to provide good random
145145
# shuffling.
@@ -148,9 +148,8 @@ def _input_from_files(mode, batch_size, data_dir):
148148

149149
# Subtract off the mean and divide by the variance of the pixels.
150150
dataset = dataset.map(
151-
lambda image, label: (tf.image.per_image_standardization(image), label),
152-
num_threads=1,
153-
output_buffer_size=2 * batch_size)
151+
lambda image, label: (tf.image.per_image_standardization(image), label))
152+
dataset.prefetch(2 * batch_size)
154153

155154
# Batch results by up to batch_size, and then fetch the tuple from the
156155
# iterator.
@@ -203,7 +202,7 @@ def _dataset_parser(value):
203202
def _record_dataset(filenames):
204203
"""Returns an input pipeline Dataset from `filenames`."""
205204
record_bytes = HEIGHT * WIDTH * DEPTH + 1
206-
return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
205+
return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
207206

208207

209208
def _filenames(mode, data_dir):

sagemaker-python-sdk/tensorflow_resnet_cifar10_with_tensorboard/tensorflow_resnet_cifar10_with_tensorboard.ipynb

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@
118118
"estimator = TensorFlow(entry_point='resnet_cifar_10.py',\n",
119119
" source_dir=source_dir,\n",
120120
" role=role,\n",
121-
" framework_version='1.6',\n",
122121
" hyperparameters={'throttle_secs': 30},\n",
123122
" training_steps=1000, evaluation_steps=100,\n",
124123
" train_instance_count=2, train_instance_type='ml.c4.xlarge', \n",
@@ -146,9 +145,7 @@
146145
},
147146
{
148147
"cell_type": "markdown",
149-
"metadata": {
150-
"collapsed": true
151-
},
148+
"metadata": {},
152149
"source": [
153150
"# Deploy the trained model to prepare for predictions\n",
154151
"\n",
@@ -225,7 +222,7 @@
225222
"name": "python",
226223
"nbconvert_exporter": "python",
227224
"pygments_lexer": "ipython2",
228-
"version": "2.7.11"
225+
"version": "2.7.14"
229226
},
230227
"notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
231228
},

0 commit comments

Comments
 (0)