fix: add hyperparameter tuning test #216

Merged
merged 2 commits on Jun 14, 2019

2 changes: 1 addition & 1 deletion .flake8
@@ -1,3 +1,3 @@
[flake8]
-application_import_names = sagemaker_tensorflow_container, test, utils
+application_import_names = sagemaker_tensorflow_container, test, timeout, utils
import-order-style = google
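
For context (not part of this PR's diff): application_import_names tells the flake8-import-order plugin which modules to treat as first-party, so adding timeout here lets `from timeout import timeout` sit in the application-import group that import-order-style = google expects. A minimal sketch of that grouping, using imports that actually appear in this repository:

# Standard library imports come first.
import os

# Third-party imports form the next group.
import pytest

# Application imports (anything listed in application_import_names) come last.
from test.integration.utils import unique_name_from_base
from timeout import timeout
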
40 changes: 38 additions & 2 deletions test/integration/sagemaker/test_mnist.py
@@ -17,14 +17,16 @@
import boto3
import pytest
from sagemaker.tensorflow import TensorFlow
from sagemaker.tuner import HyperparameterTuner, IntegerParameter
from six.moves.urllib.parse import urlparse

from test.integration.utils import processor, py_version, unique_name_from_base # noqa: F401
from timeout import timeout


@pytest.mark.deploy_test
def test_mnist(sagemaker_session, ecr_image, instance_type, framework_version):
-    resource_path = os.path.join(os.path.dirname(__file__), '../..', 'resources')
+    resource_path = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
    script = os.path.join(resource_path, 'mnist', 'mnist.py')
    estimator = TensorFlow(entry_point=script,
                           role='SageMakerRole',
@@ -42,7 +44,7 @@ def test_mnist(sagemaker_session, ecr_image, instance_type, framework_version):


def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type, framework_version):
-    resource_path = os.path.join(os.path.dirname(__file__), '../..', 'resources')
+    resource_path = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
    script = os.path.join(resource_path, 'mnist', 'mnist.py')
    estimator = TensorFlow(entry_point=script,
                           role='SageMakerRole',
@@ -110,6 +112,40 @@ def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region, framewor
    _assert_checkpoint_exists(region, estimator.model_dir, 200)


def test_tuning(sagemaker_session, ecr_image, instance_type, framework_version):
    resource_path = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
    script = os.path.join(resource_path, 'mnist', 'mnist.py')

    estimator = TensorFlow(entry_point=script,
                           role='SageMakerRole',
                           train_instance_type=instance_type,
                           train_instance_count=1,
                           sagemaker_session=sagemaker_session,
                           image_name=ecr_image,
                           framework_version=framework_version,
                           script_mode=True)

    hyperparameter_ranges = {'epochs': IntegerParameter(1, 2)}
    objective_metric_name = 'accuracy'
    metric_definitions = [{'Name': objective_metric_name, 'Regex': 'accuracy = ([0-9\\.]+)'}]

    tuner = HyperparameterTuner(estimator,
                                objective_metric_name,
                                hyperparameter_ranges,
                                metric_definitions,
                                max_jobs=2,
                                max_parallel_jobs=2)

    with timeout(minutes=20):
        inputs = estimator.sagemaker_session.upload_data(
            path=os.path.join(resource_path, 'mnist', 'data'),
            key_prefix='scriptmode/mnist')

        tuning_job_name = unique_name_from_base('test-tf-sm-tuning', max_length=32)
        tuner.fit(inputs, job_name=tuning_job_name)
        tuner.wait()


def _assert_checkpoint_exists(region, model_dir, checkpoint_number):
    _assert_s3_file_exists(region, os.path.join(model_dir, 'graph.pbtxt'))
    _assert_s3_file_exists(region,
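
For context (not part of this PR's diff): test_tuning launches the tuning job and waits for it to finish, but does not inspect the outcome. If one wanted to check the result afterwards, the SageMaker Python SDK's HyperparameterTuner exposes the winning training job and a per-job summary, roughly as sketched below (tuner is the object created in the test, the tuning job is assumed to have completed, and analytics().dataframe() needs pandas installed):

# Name of the training job that achieved the best value of the objective metric ('accuracy' here).
best_job = tuner.best_training_job()
print(best_job)

# Tabular summary of every training job launched by the tuning job.
summary = tuner.analytics().dataframe()
print(summary.head())
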
50 changes: 50 additions & 0 deletions test/integration/sagemaker/timeout.py
@@ -0,0 +1,50 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

from contextlib import contextmanager
import logging
import signal

LOGGER = logging.getLogger('timeout')


class TimeoutError(Exception):
    pass


@contextmanager
def timeout(seconds=0, minutes=0, hours=0):
    """Add a signal-based timeout to any block of code.
    If multiple time units are specified, they will be added together to determine time limit.
    Usage:
    with timeout(seconds=5):
        my_slow_function(...)
    Args:
        - seconds: The time limit, in seconds.
        - minutes: The time limit, in minutes.
        - hours: The time limit, in hours.
    """

    limit = seconds + 60 * minutes + 3600 * hours

    def handler(signum, frame):
        raise TimeoutError('timed out after {} seconds'.format(limit))

    try:
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(limit)

        yield
    finally:
        signal.alarm(0)
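
A note on this helper (not part of this PR's diff): signal.SIGALRM exists only on Unix, signal handlers can only be installed in the main thread, and signal.alarm() takes whole seconds, all of which is fine for these integration tests. Because the module defines its own TimeoutError (which shadows the Python 3 builtin of the same name), callers that want to handle the timeout should catch the class exported by this module, along the lines of the sketch below (time.sleep stands in for a slow SageMaker call):

import time

from timeout import TimeoutError, timeout

try:
    with timeout(seconds=2):
        time.sleep(10)  # stand-in for a slow operation that exceeds the limit
except TimeoutError as e:
    # Raised by the SIGALRM handler installed in timeout(); the alarm is cleared in its finally block.
    print(e)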