Skip to content

Commit 1aa7659

Browse files
authored
fix: add hyperparameter tuning test (#216)
1 parent 30ac4fd commit 1aa7659

File tree

3 files changed

+89
-3
lines changed

3 files changed

+89
-3
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[flake8]
2-
application_import_names = sagemaker_tensorflow_container, test, utils
2+
application_import_names = sagemaker_tensorflow_container, test, timeout, utils
33
import-order-style = google

test/integration/sagemaker/test_mnist.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@
1717
import boto3
1818
import pytest
1919
from sagemaker.tensorflow import TensorFlow
20+
from sagemaker.tuner import HyperparameterTuner, IntegerParameter
2021
from six.moves.urllib.parse import urlparse
2122

2223
from test.integration.utils import processor, py_version, unique_name_from_base # noqa: F401
24+
from timeout import timeout
2325

2426

2527
@pytest.mark.deploy_test
2628
def test_mnist(sagemaker_session, ecr_image, instance_type, framework_version):
27-
resource_path = os.path.join(os.path.dirname(__file__), '../..', 'resources')
29+
resource_path = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
2830
script = os.path.join(resource_path, 'mnist', 'mnist.py')
2931
estimator = TensorFlow(entry_point=script,
3032
role='SageMakerRole',
@@ -42,7 +44,7 @@ def test_mnist(sagemaker_session, ecr_image, instance_type, framework_version):
4244

4345

4446
def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type, framework_version):
45-
resource_path = os.path.join(os.path.dirname(__file__), '../..', 'resources')
47+
resource_path = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
4648
script = os.path.join(resource_path, 'mnist', 'mnist.py')
4749
estimator = TensorFlow(entry_point=script,
4850
role='SageMakerRole',
@@ -110,6 +112,40 @@ def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region, framewor
110112
_assert_checkpoint_exists(region, estimator.model_dir, 200)
111113

112114

115+
def test_tuning(sagemaker_session, ecr_image, instance_type, framework_version):
    """Launch a two-job hyperparameter tuning run over the MNIST script and wait on it.

    The tuner searches over the ``epochs`` hyperparameter and reads the ``accuracy``
    objective from the training logs using the regex given in the metric definition.
    The data upload, ``fit`` and ``wait`` are all bounded by a 20 minute timeout.
    """
    resources_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
    entry_script = os.path.join(resources_dir, 'mnist', 'mnist.py')

    tf_estimator = TensorFlow(
        entry_point=entry_script,
        role='SageMakerRole',
        train_instance_type=instance_type,
        train_instance_count=1,
        sagemaker_session=sagemaker_session,
        image_name=ecr_image,
        framework_version=framework_version,
        script_mode=True,
    )

    objective = 'accuracy'
    search_space = {'epochs': IntegerParameter(1, 2)}
    metrics = [{'Name': objective, 'Regex': 'accuracy = ([0-9\\.]+)'}]

    tuner = HyperparameterTuner(
        tf_estimator,
        objective,
        search_space,
        metrics,
        max_jobs=2,
        max_parallel_jobs=2,
    )

    with timeout(minutes=20):
        # Upload happens inside the timed section, so it counts against the limit.
        data_dir = os.path.join(resources_dir, 'mnist', 'data')
        s3_inputs = tf_estimator.sagemaker_session.upload_data(
            path=data_dir,
            key_prefix='scriptmode/mnist',
        )

        job_name = unique_name_from_base('test-tf-sm-tuning', max_length=32)
        tuner.fit(s3_inputs, job_name=job_name)
        tuner.wait()
148+
113149
def _assert_checkpoint_exists(region, model_dir, checkpoint_number):
114150
_assert_s3_file_exists(region, os.path.join(model_dir, 'graph.pbtxt'))
115151
_assert_s3_file_exists(region,

test/integration/sagemaker/timeout.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from contextlib import contextmanager
16+
import logging
17+
import signal
18+
19+
LOGGER = logging.getLogger('timeout')
20+
21+
22+
class TimeoutError(Exception):
    """Error raised by the SIGALRM handler when a timed block runs past its limit."""
24+
25+
26+
@contextmanager
def timeout(seconds=0, minutes=0, hours=0):
    """Add a signal-based timeout to any block of code.

    If multiple time units are specified, they will be added together to determine time limit.
    A total of zero schedules no alarm, so the block then runs without a time limit.

    The SIGALRM handler that was installed before entering the block is put back on exit,
    so using this context manager does not leave its own handler registered.

    Usage:
        with timeout(seconds=5):
            my_slow_function(...)

    Args:
        - seconds: The time limit, in seconds.
        - minutes: The time limit, in minutes.
        - hours: The time limit, in hours.

    Raises:
        TimeoutError: inside the guarded block, if SIGALRM is delivered before it exits.
    """
    limit = seconds + 60 * minutes + 3600 * hours

    def handler(signum, frame):
        raise TimeoutError('timed out after {} seconds'.format(limit))

    # signal.signal returns the handler it replaces; keep it so it can be restored.
    previous_handler = signal.signal(signal.SIGALRM, handler)
    try:
        signal.alarm(limit)

        yield
    finally:
        # Cancel any pending alarm first so it cannot fire while handlers are swapped.
        signal.alarm(0)
        # None means the previous handler was not installed from Python and
        # cannot be passed back to signal.signal.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

0 commit comments

Comments
 (0)