Skip to content

Commit 4dc9123

Browse files
author
Dan Choi
committed
Create job base class
1 parent 9b9272b commit 4dc9123

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed

src/sagemaker/job.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from abc import abstractmethod
14+
from six import string_types
15+
16+
from sagemaker.session import s3_input
17+
18+
19+
class _Job(object):
20+
def __init__(self, sagemaker_session, training_job_name):
21+
self.sagemaker_session = sagemaker_session
22+
self.job_name = training_job_name
23+
24+
@abstractmethod
25+
def start_new(cls, estimator, inputs):
26+
"""Create a new Amazon SageMaker job from the estimator.
27+
28+
Args:
29+
estimator (sagemaker.estimator.EstimatorBase): Estimator object created by the user.
30+
inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
31+
32+
Returns:
33+
sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
34+
"""
35+
36+
pass
37+
38+
@staticmethod
39+
def _load_config(inputs, estimator):
40+
input_config = _Job._format_inputs_to_input_config(inputs)
41+
role = estimator.sagemaker_session.expand_role(estimator.role)
42+
output_config = _Job._prepare_output_config(estimator.output_path,
43+
estimator.output_kms_key)
44+
resource_config = _Job._prepare_resource_config(estimator.train_instance_count,
45+
estimator.train_instance_type,
46+
estimator.train_volume_size)
47+
stop_condition = _Job._prepare_stopping_condition(estimator.train_max_run)
48+
49+
return input_config, role, output_config, resource_config, stop_condition
50+
51+
@staticmethod
52+
def _format_inputs_to_input_config(inputs):
53+
input_dict = {}
54+
if isinstance(inputs, string_types):
55+
input_dict['training'] = _Job._format_s3_uri_input(inputs)
56+
elif isinstance(inputs, s3_input):
57+
input_dict['training'] = inputs
58+
elif isinstance(inputs, dict):
59+
for k, v in inputs.items():
60+
input_dict[k] = _Job._format_s3_uri_input(v)
61+
else:
62+
raise ValueError('Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs))
63+
64+
channels = []
65+
for channel_name, channel_s3_input in input_dict.items():
66+
channel_config = channel_s3_input.config.copy()
67+
channel_config['ChannelName'] = channel_name
68+
channels.append(channel_config)
69+
return channels
70+
71+
@staticmethod
72+
def _format_s3_uri_input(input):
73+
if isinstance(input, str):
74+
if not input.startswith('s3://'):
75+
raise ValueError('Training input data must be a valid S3 URI and must start with "s3://"')
76+
return s3_input(input)
77+
if isinstance(input, s3_input):
78+
return input
79+
else:
80+
raise ValueError('Cannot format input {}. Expecting one of str or s3_input'.format(input))
81+
82+
@staticmethod
83+
def _prepare_output_config(s3_path, kms_key_id):
84+
config = {'S3OutputPath': s3_path}
85+
if kms_key_id is not None:
86+
config['KmsKeyId'] = kms_key_id
87+
return config
88+
89+
@staticmethod
90+
def _prepare_resource_config(instance_count, instance_type, volume_size):
91+
resource_config = {'InstanceCount': instance_count,
92+
'InstanceType': instance_type,
93+
'VolumeSizeInGB': volume_size}
94+
return resource_config
95+
96+
@staticmethod
97+
def _prepare_stopping_condition(max_run):
98+
stop_condition = {'MaxRuntimeInSeconds': max_run}
99+
return stop_condition
100+
101+
@property
102+
def name(self):
103+
return self.job_name
104+
105+
def wait(self, logs=True):
106+
if logs:
107+
self.sagemaker_session.logs_for_job(self.job_name, wait=True)
108+
else:
109+
self.sagemaker_session.wait_for_job(self.job_name)

0 commit comments

Comments
 (0)