Skip to content

Commit ccaddfe

Browse files
author
Dan Choi
committed
Add unit tests
1 parent 7f8f64e commit ccaddfe

File tree

2 files changed

+171
-6
lines changed

2 files changed

+171
-6
lines changed

src/sagemaker/job.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,13 @@ def wait(self):
5151
def _load_config(inputs, estimator):
5252
input_config = _Job._format_inputs_to_input_config(inputs)
5353
role = estimator.sagemaker_session.expand_role(estimator.role)
54-
output_config = _Job._prepare_output_config(estimator.output_path,
55-
estimator.output_kms_key)
54+
output_config = _Job._prepare_output_config(estimator.output_path, estimator.output_kms_key)
5655
resource_config = _Job._prepare_resource_config(estimator.train_instance_count,
57-
estimator.train_instance_type,
58-
estimator.train_volume_size)
59-
stop_condition = _Job._prepare_stopping_condition(estimator.train_max_run)
56+
estimator.train_instance_type,
57+
estimator.train_volume_size)
58+
stopping_condition = _Job._prepare_stopping_condition(estimator.train_max_run)
6059

61-
return input_config, role, output_config, resource_config, stop_condition
60+
return input_config, role, output_config, resource_config, stopping_condition
6261

6362
@staticmethod
6463
def _format_inputs_to_input_config(inputs):

tests/unit/test_job.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import pytest
14+
from mock import Mock
15+
16+
from sagemaker.estimator import Estimator
17+
from sagemaker.job import _Job
18+
from sagemaker.session import s3_input
19+
20+
BUCKET_NAME = 's3://mybucket/train'
21+
S3_OUTPUT_PATH = 's3://bucket/prefix'
22+
INSTANCE_COUNT = 1
23+
INSTANCE_TYPE = 'c4.4xlarge'
24+
VOLUME_SIZE = 1
25+
MAX_RUNTIME = 1
26+
ROLE = 'DummyRole'
27+
IMAGE_NAME = 'fakeimage'
28+
JOB_NAME = 'fakejob'
29+
30+
31+
@pytest.fixture()
32+
def estimator(sagemaker_session):
33+
return Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, MAX_RUNTIME,
34+
output_path=S3_OUTPUT_PATH, sagemaker_session=sagemaker_session)
35+
36+
37+
@pytest.fixture()
38+
def job(sagemaker_session):
39+
return _Job(sagemaker_session, JOB_NAME)
40+
41+
42+
@pytest.fixture()
43+
def sagemaker_session():
44+
boto_mock = Mock(name='boto_session')
45+
mock_session = Mock(name='sagemaker_session', boto_session=boto_mock)
46+
mock_session.expand_role = Mock(name='expand_role', return_value=ROLE)
47+
48+
return mock_session
49+
50+
51+
def test_load_config(job, estimator):
52+
inputs = s3_input(BUCKET_NAME)
53+
54+
input_config, role, output_config, resource_config, stopping_condition = \
55+
job._load_config(inputs, estimator)
56+
57+
assert input_config[0]['DataSource']['S3DataSource']['S3Uri'] == BUCKET_NAME
58+
assert role == ROLE
59+
assert output_config['S3OutputPath'] == S3_OUTPUT_PATH
60+
assert 'KmsKeyId' not in output_config
61+
assert resource_config['InstanceCount'] == INSTANCE_COUNT
62+
assert resource_config['InstanceType'] == INSTANCE_TYPE
63+
assert resource_config['VolumeSizeInGB'] == VOLUME_SIZE
64+
assert stopping_condition['MaxRuntimeInSeconds'] == MAX_RUNTIME
65+
66+
67+
def test_format_inputs_to_input_config_string(job):
68+
inputs = BUCKET_NAME
69+
70+
channels = job._format_inputs_to_input_config(inputs)
71+
72+
assert channels[0]['DataSource']['S3DataSource']['S3Uri'] == inputs
73+
74+
75+
def test_format_inputs_to_input_config_s3_input(job):
76+
inputs = s3_input(BUCKET_NAME)
77+
78+
channels = job._format_inputs_to_input_config(inputs)
79+
80+
assert channels[0]['DataSource']['S3DataSource']['S3Uri'] == inputs.config['DataSource'][
81+
'S3DataSource']['S3Uri']
82+
83+
84+
def test_format_inputs_to_input_config_dict(job):
85+
inputs = {'train': BUCKET_NAME}
86+
87+
channels = job._format_inputs_to_input_config(inputs)
88+
89+
assert channels[0]['DataSource']['S3DataSource']['S3Uri'] == inputs['train']
90+
91+
92+
def test_format_inputs_to_input_config_exception(job):
93+
inputs = 1
94+
95+
with pytest.raises(ValueError):
96+
job._format_inputs_to_input_config(inputs)
97+
98+
99+
def test_format_s3_uri_input_string(job):
100+
inputs = BUCKET_NAME
101+
102+
s3_uri_input = job._format_s3_uri_input(inputs)
103+
104+
assert s3_uri_input.config['DataSource']['S3DataSource']['S3Uri'] == inputs
105+
106+
107+
def test_format_s3_uri_input_string_exception(job):
108+
inputs = 'mybucket/train'
109+
110+
with pytest.raises(ValueError):
111+
job._format_s3_uri_input(inputs)
112+
113+
114+
def test_format_s3_uri_input(job):
115+
inputs = s3_input(BUCKET_NAME)
116+
117+
s3_uri_input = job._format_s3_uri_input(inputs)
118+
119+
assert s3_uri_input.config['DataSource']['S3DataSource']['S3Uri'] == inputs.config[
120+
'DataSource']['S3DataSource']['S3Uri']
121+
122+
123+
def test_format_s3_uri_input_exception(job):
124+
inputs = 1
125+
126+
with pytest.raises(ValueError):
127+
job._format_s3_uri_input(inputs)
128+
129+
130+
def test_prepare_output_config(job):
131+
kms_key_id = 'kms_key'
132+
133+
config = job._prepare_output_config(BUCKET_NAME, kms_key_id)
134+
135+
assert config['S3OutputPath'] == BUCKET_NAME
136+
assert config['KmsKeyId'] == kms_key_id
137+
138+
139+
def test_prepare_output_config_kms_key_none(job):
140+
s3_path = BUCKET_NAME
141+
kms_key_id = None
142+
143+
config = job._prepare_output_config(s3_path, kms_key_id)
144+
145+
assert config['S3OutputPath'] == s3_path
146+
assert 'KmsKeyId' not in config
147+
148+
149+
def test_prepare_resource_config(job):
150+
resource_config = job._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE)
151+
152+
assert resource_config['InstanceCount'] == INSTANCE_COUNT
153+
assert resource_config['InstanceType'] == INSTANCE_TYPE
154+
assert resource_config['VolumeSizeInGB'] == VOLUME_SIZE
155+
156+
157+
def test_prepare_stopping_condition(job):
158+
max_run = 1
159+
160+
stopping_condition = job._prepare_stopping_condition(max_run)
161+
162+
assert stopping_condition['MaxRuntimeInSeconds'] == max_run
163+
164+
165+
def test_name(job):
166+
assert job.name == JOB_NAME

0 commit comments

Comments
 (0)