Skip to content

Commit 1b60209

Browse files
authored
Add SageMaker integ test for hyperparameter tuning model_dir logic (#183)
1 parent 215179b commit 1b60209

File tree

4 files changed

+74
-5
lines changed

4 files changed

+74
-5
lines changed

docker/1.12.0/Dockerfile.gpu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,11 @@ WORKDIR /
107107
ARG framework_installable
108108
ARG framework_support_installable=sagemaker_tensorflow_container-2.0.0.tar.gz
109109

110-
COPY $framework_installable tensorflow-1.12.0-py2.py3-none-any.whl
110+
COPY $framework_installable tensorflow-1.12.0-py2.py3-none-any.whl
111111
COPY $framework_support_installable .
112112

113113
RUN pip install --no-cache-dir -U \
114114
keras==2.2.4 \
115-
sagemaker-containers==2.4.2 \
116115
$framework_support_installable \
117116
"sagemaker-tensorflow>=1.12,<1.13" \
118117
# Let's install TensorFlow separately in the end to avoid
@@ -129,5 +128,5 @@ RUN pip install --no-cache-dir -U \
129128
RUN ldconfig /usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs && \
130129
HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_WITH_TENSORFLOW=1 pip install --no-cache-dir horovod && \
131130
ldconfig
132-
133-
ENV SAGEMAKER_TRAINING_MODULE sagemaker_tensorflow_container.training:main
131+
132+
ENV SAGEMAKER_TRAINING_MODULE sagemaker_tensorflow_container.training:main

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def read(fname):
4949
'Programming Language :: Python :: 3.6',
5050
],
5151

52-
install_requires=['sagemaker-containers>=2.4.4', 'numpy', 'scipy', 'sklearn',
52+
install_requires=['sagemaker-containers>=2.4.6', 'numpy', 'scipy', 'sklearn',
5353
'pandas', 'Pillow', 'h5py'],
5454
extras_require={
5555
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock',
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
from sagemaker.tensorflow import TensorFlow
18+
from sagemaker.tuner import HyperparameterTuner, IntegerParameter
19+
20+
21+
def test_model_dir_with_training_job_name(sagemaker_session, ecr_image, instance_type,
                                          framework_version):
    """Launch a single-job hyperparameter tuning run and wait for it to finish.

    The entry point is ``resources/tuning_model_dir/entry.py``; per the note
    below, that script is what actually checks ``model_dir``, so this test only
    needs to start the tuning job and block until it is done.

    Args:
        sagemaker_session: session object handed to the ``TensorFlow`` estimator.
        ecr_image: image name the training job should run.
        instance_type: training instance type.
        framework_version: TensorFlow framework version given to the estimator.

    NOTE(review): the four arguments are presumably pytest fixtures supplied by
    a conftest outside this file -- confirm.
    """
    entry_script = os.path.join(os.path.dirname(__file__), '../..', 'resources',
                                'tuning_model_dir', 'entry.py')

    estimator_kwargs = dict(
        entry_point=entry_script,
        role='SageMakerRole',
        train_instance_type=instance_type,
        train_instance_count=1,
        image_name=ecr_image,
        framework_version=framework_version,
        py_version='py3',
        sagemaker_session=sagemaker_session,
    )
    tf_estimator = TensorFlow(**estimator_kwargs)

    # The objective metric and the metric definition must share one name.
    metric_name = 'accuracy'
    hp_tuner = HyperparameterTuner(
        estimator=tf_estimator,
        objective_metric_name=metric_name,
        hyperparameter_ranges={'arbitrary_value': IntegerParameter(0, 1)},
        metric_definitions=[{'Name': metric_name, 'Regex': 'accuracy=([01])'}],
        max_jobs=1,
        max_parallel_jobs=1,
        base_tuning_job_name='test-tf-tuning-model-dir',
    )

    # User script has logic to check for the correct model_dir
    hp_tuner.fit()
    hp_tuner.wait()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import argparse
16+
import os
17+
18+
def _parse_args(argv=None):
    """Parse the command-line arguments this script understands.

    Args:
        argv: list of argument strings; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The parsed namespace with ``model_dir`` (str or None) and
        ``arbitrary_value`` (int, default 0).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--arbitrary_value', type=int, default=0)
    return parser.parse_args(argv)


def main(argv=None):
    """Fail unless ``--model_dir`` contains the current training job's name.

    The job name is read from the ``TRAINING_JOB_NAME`` environment variable.
    On success a fixed ``accuracy=1`` line is printed.

    Args:
        argv: list of argument strings; ``None`` makes argparse read ``sys.argv``.

    Raises:
        AssertionError: if ``--model_dir`` is missing or does not contain the
            training job name.
        KeyError: if ``TRAINING_JOB_NAME`` is not set in the environment.
    """
    args = _parse_args(argv)
    job_name = os.environ['TRAINING_JOB_NAME']

    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # removed when Python runs with -O / PYTHONOPTIMIZE, which would let this
    # check pass without ever running. A missing --model_dir is reported with
    # the same message rather than a TypeError from ``in None``.
    if not args.model_dir or job_name not in args.model_dir:
        raise AssertionError('model_dir not unique to training job: %s' % args.model_dir)

    # For the "hyperparameter tuning" to work
    print('accuracy=1')


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)