Skip to content

Commit 85cded3

Browse files
laurenyuicywang86rui
authored andcommitted
Update integ test for checking Python version (#189)
* Update integ test for checking Python version * remove emacs backup file * remove extraneous code * Python 2 compatibility
1 parent c097ca1 commit 85cded3

File tree

3 files changed

+40
-126
lines changed

3 files changed

+40
-126
lines changed

test/integration/docker_utils.py

Lines changed: 0 additions & 119 deletions
This file was deleted.

test/integration/local/test_training.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -18,8 +18,6 @@
1818
import pytest
1919
from sagemaker.tensorflow import TensorFlow
2020

21-
from test.integration.docker_utils import Container
22-
2321
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
2422
TF_CHECKPOINT_FILES = ['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta']
2523

@@ -32,10 +30,23 @@ def py_full_version(py_version):
3230
return '3.6'
3331

3432

35-
def test_py_versions(docker_image, processor, py_full_version):
36-
with Container(docker_image, processor) as c:
37-
output = c.execute_command(['python', '--version'])
38-
assert output.strip().startswith('Python {}'.format(py_full_version))
33+
def test_py_versions(sagemaker_local_session, docker_image, py_full_version, framework_version, tmpdir):
34+
output_path = 'file://{}'.format(tmpdir)
35+
run_tf_training(script=os.path.join(RESOURCE_PATH, 'test_py_version', 'entry.py'),
36+
instance_type='local',
37+
instance_count=1,
38+
sagemaker_local_session=sagemaker_local_session,
39+
docker_image=docker_image,
40+
framework_version=framework_version,
41+
output_path=output_path,
42+
training_data_path=None)
43+
44+
with tarfile.open(os.path.join(str(tmpdir), 'output.tar.gz')) as tar:
45+
output_file = tar.getmember('py_version')
46+
tar.extractall(path=str(tmpdir), members=[output_file])
47+
48+
with open(os.path.join(str(tmpdir), 'py_version')) as f:
49+
assert f.read().strip() == py_full_version
3950

4051

4152
@pytest.mark.skip_gpu
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import sys
17+
18+
19+
py_version = '%s.%s' % (sys.version_info.major, sys.version_info.minor)
20+
21+
with open(os.path.join(os.environ['SM_OUTPUT_DIR'], 'py_version'), 'a') as f:
22+
f.write(py_version)

0 commit comments

Comments
 (0)