Skip to content

Commit a4059a7

Browse files
authored
Upgrade to TF 1.12 (#31)
* Upgrade to TF 1.12
1 parent 8085512 commit a4059a7

File tree

6 files changed

+24
-6
lines changed

6 files changed

+24
-6
lines changed

create_integ_test_docker_images.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import glob
1010
import sys
1111

12-
TF_VERSION = "1.11.0"
12+
TF_VERSION = "1.12.0"
1313
REGION = "us-west-2"
1414

1515
if __name__ == '__main__':

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def build_extension(self, ext):
9494

9595
setup(
9696
name='sagemaker_tensorflow',
97-
version='1.11.0.1.0.0',
97+
version='1.12.0.1.0.0',
9898
description='Amazon Sagemaker specific TensorFlow extensions.',
9999
packages=find_packages(where='src', exclude=('test',)),
100100
package_dir={'': 'src'},

src/sagemaker_tensorflow/pipemode.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def __init__(self, channel, record_format='RecordIO',
7171
def _as_variant_tensor(self):
7272
return self._tf_plugin.pipe_mode_dataset(self.benchmark, self.record_format, self.state_dir, self.channel,
7373
self.pipe_dir)
74+
75+
def _inputs(self):
76+
return []
77+
7478
def _validate_input_data_config(self):
7579
if self.channel not in self.input_data_config:
7680
raise PipeModeDatasetException("Channel {} not found in Training Job InputDataConfig".format(self.channel))

test/integ/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
FROM ubuntu:16.04
22

33
ARG device=cpu
4-
ARG tensorflow_version=1.9.0
4+
ARG tensorflow_version=1.12.0
55
ARG script
66
ARG python
77

test/integ/scripts/estimator_script.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import tensorflow as tf
77
from sagemaker_tensorflow import PipeModeDataset
88

9-
ds = PipeModeDataset("elizabeth")
9+
print("Starting estimator script")
1010

11+
ds = PipeModeDataset("elizabeth")
1112

1213
class BenchmarkConfig(object):
1314

@@ -74,24 +75,37 @@ def parse(record):
7475
model_dir = tempfile.mkdtemp()
7576
estimator = tf.estimator.LinearClassifier(feature_columns=[column])
7677

78+
print("About to call train")
7779
estimator.train(input_fn=input_fn)
7880

7981
# Confirm that we have read the correct number of pipes
8082
assert os.path.exists('/opt/ml/input/data/{}_{}'.format(config.channel, config.epochs + 1))
8183

84+
print("About to call evaluate")
85+
result = estimator.evaluate(input_fn=input_fn)
86+
for key,value in sorted(result.items()):
87+
print('%s: %s' % (key, value))
88+
89+
8290
# Test that we can create a new PipeModeDataset after training has run
91+
print("Validate that new PipeModeDataset on existing channel can be created")
92+
8393
ds = PipeModeDataset(config.channel)
8494

8595
with tf.Session() as sess:
8696
it = ds.make_one_shot_iterator()
8797
next = it.get_next()
8898
sess.run(next)
8999

100+
print("Validate create, read, discard, recreate")
101+
90102
# Test that we can create a PipeModeDataset, discard it, and read from a new one
91103
ds = PipeModeDataset(config.channel)
92104
with tf.Session() as sess:
93105
it = ds.make_one_shot_iterator()
94106
next = it.get_next()
107+
108+
95109
ds = PipeModeDataset(config.channel)
96110
with tf.Session() as sess:
97111
it = ds.make_one_shot_iterator()

tox.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ deps =
4949
mock
5050
contextlib2
5151
teamcity-messages
52-
tensorflow==1.11
52+
tensorflow==1.12
5353
awslogs
5454
docker
5555
cmake
@@ -60,5 +60,5 @@ deps =
6060
cmake
6161
flake8
6262
flake8-future-import
63-
tensorflow==1.11
63+
tensorflow==1.12
6464
commands = flake8

0 commit comments

Comments
 (0)