Skip to content

Commit 99eaf6b

Browse files
authored
Add integration tests to run training jobs with sagemaker (#81)
* Add mnist sagemaker tests * Use account-id instead of ecr-image * Merge gpu and cpu sagemaker tests * remove _run_mnist_training
1 parent 3763697 commit 99eaf6b

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

test/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def pytest_addoption(parser):
3636
parser.addoption('--framework-version', default='1.10.0')
3737
parser.addoption('--processor', default='cpu', choices=['gpu', 'cpu'])
3838
parser.addoption('--py-version', default='3', choices=['2', '3'])
39+
parser.addoption('--account-id', default='142577830533')
40+
parser.addoption('--instance-type', default=None)
3941

4042

4143
@pytest.fixture(scope='session')
@@ -80,6 +82,17 @@ def sagemaker_local_session(region):
8082
return LocalSession(boto_session=boto3.Session(region_name=region))
8183

8284

85+
@pytest.fixture(scope='session')
86+
def account_id(request):
87+
return request.config.getoption('--account-id')
88+
89+
90+
@pytest.fixture(scope='session')
91+
def instance_type(request, processor):
92+
return request.config.getoption('--instance-type') or \
93+
'ml.c4.xlarge' if processor == 'cpu' else 'ml.p2.xlarge'
94+
95+
8396
@pytest.fixture(autouse=True)
8497
def skip_by_device_type(request, processor):
8598
is_gpu = (processor == 'gpu')
@@ -91,3 +104,9 @@ def skip_by_device_type(request, processor):
91104
@pytest.fixture(scope='session')
92105
def docker_image(docker_base_name, tag):
93106
return '{}:{}'.format(docker_base_name, tag)
107+
108+
109+
@pytest.fixture(scope='session')
110+
def ecr_image(account_id, docker_base_name, tag, region):
111+
return '{}.dkr.ecr.{}.amazonaws.com/{}:{}'.format(
112+
account_id, region, docker_base_name, tag)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
from sagemaker.tensorflow import TensorFlow
18+
19+
20+
def test_mnist(sagemaker_session, ecr_image, instance_type):
21+
resource_path = os.path.join(os.path.dirname(__file__), '../..', 'resources')
22+
script = os.path.join(resource_path, 'mnist', 'mnist.py')
23+
estimator = TensorFlow(entry_point=script,
24+
role='SageMakerRole',
25+
training_steps=1,
26+
evaluation_steps=1,
27+
train_instance_count=1,
28+
train_instance_type=instance_type,
29+
sagemaker_session=sagemaker_session,
30+
image_name=ecr_image,
31+
base_job_name='test-sagemaker-mnist')
32+
inputs = estimator.sagemaker_session.upload_data(
33+
path=os.path.join(resource_path, 'mnist', 'data'),
34+
key_prefix='scriptmode/mnist')
35+
estimator.fit(inputs)

0 commit comments

Comments
 (0)