Skip to content

Commit 8841a27

Browse files
authored
Merge pull request aws#362 from icywang86rui/fix-tf-resnet-cifar-10
Fix ResNet CIFAR-10 with tensorboard notebook
2 parents 77e1138 + c3ff6db commit 8841a27

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

sagemaker-python-sdk/tensorflow_resnet_cifar10_with_tensorboard/source_dir/resnet_cifar_10.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,13 @@ def _input_from_files(mode, batch_size, data_dir):
133133
if mode == tf.estimator.ModeKeys.TRAIN:
134134
dataset = dataset.repeat()
135135

136-
dataset = dataset.map(_dataset_parser, num_threads=1,
137-
output_buffer_size=2 * batch_size)
136+
dataset = dataset.map(_dataset_parser, num_parallel_calls=1)
137+
dataset.prefetch(2 * batch_size)
138138

139139
# For training, preprocess the image and shuffle.
140140
if mode == tf.estimator.ModeKeys.TRAIN:
141-
dataset = dataset.map(_train_preprocess_fn, num_threads=1,
142-
output_buffer_size=2 * batch_size)
141+
dataset = dataset.map(_train_preprocess_fn, num_parallel_calls=1)
142+
dataset.prefetch(2 * batch_size)
143143

144144
# Ensure that the capacity is sufficiently large to provide good random
145145
# shuffling.
@@ -149,8 +149,8 @@ def _input_from_files(mode, batch_size, data_dir):
149149
# Subtract off the mean and divide by the variance of the pixels.
150150
dataset = dataset.map(
151151
lambda image, label: (tf.image.per_image_standardization(image), label),
152-
num_threads=1,
153-
output_buffer_size=2 * batch_size)
152+
num_parallel_calls=1)
153+
dataset.prefetch(2 * batch_size)
154154

155155
# Batch results by up to batch_size, and then fetch the tuple from the
156156
# iterator.
@@ -203,7 +203,7 @@ def _dataset_parser(value):
203203
def _record_dataset(filenames):
204204
"""Returns an input pipeline Dataset from `filenames`."""
205205
record_bytes = HEIGHT * WIDTH * DEPTH + 1
206-
return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
206+
return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
207207

208208

209209
def _filenames(mode, data_dir):

sagemaker-python-sdk/tensorflow_resnet_cifar10_with_tensorboard/tensorflow_resnet_cifar10_with_tensorboard.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@
118118
"estimator = TensorFlow(entry_point='resnet_cifar_10.py',\n",
119119
" source_dir=source_dir,\n",
120120
" role=role,\n",
121-
" framework_version='1.6',\n",
121+
" framework_version='1.8',\n",
122122
" hyperparameters={'throttle_secs': 30},\n",
123123
" training_steps=1000, evaluation_steps=100,\n",
124124
" train_instance_count=2, train_instance_type='ml.c4.xlarge', \n",

0 commit comments

Comments
 (0)