Skip to content

Commit 6ee0d78

Browse files
Refactor training job class (aws#10)
1 parent 61e44ad commit 6ee0d78

File tree

5 files changed

+223
-231
lines changed

5 files changed

+223
-231
lines changed

src/sagemaker/estimator.py

Lines changed: 13 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,14 @@
1717
import os
1818
from abc import ABCMeta
1919
from abc import abstractmethod
20-
from six import with_metaclass, string_types
20+
from six import with_metaclass
2121

2222
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
23-
from sagemaker.local import LocalSession, file_input
24-
23+
from sagemaker.job import _Job
24+
from sagemaker.local import LocalSession
2525
from sagemaker.model import Model
2626
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
2727
CONTAINER_LOG_LEVEL_PARAM_NAME, JOB_NAME_PARAM_NAME, SAGEMAKER_REGION_PARAM_NAME)
28-
2928
from sagemaker.predictor import RealTimePredictor
3029
from sagemaker.session import Session
3130
from sagemaker.session import s3_input
@@ -310,10 +309,9 @@ def delete_endpoint(self):
310309
self.sagemaker_session.delete_endpoint(self.latest_training_job.name)
311310

312311

313-
class _TrainingJob(object):
312+
class _TrainingJob(_Job):
314313
def __init__(self, sagemaker_session, training_job_name):
315-
self.sagemaker_session = sagemaker_session
316-
self.job_name = training_job_name
314+
super(_TrainingJob, self).__init__(sagemaker_session, training_job_name)
317315

318316
@classmethod
319317
def start_new(cls, estimator, inputs):
@@ -324,7 +322,8 @@ def start_new(cls, estimator, inputs):
324322
inputs (str): Parameters used when calling :meth:`~sagemaker.estimator.EstimatorBase.fit`.
325323
326324
Returns:
327-
sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
325+
sagemaker.estimator._TrainingJob: Constructed object that captures all information about the started
326+
training job.
328327
"""
329328

330329
local_mode = estimator.sagemaker_session.local_mode
@@ -334,86 +333,19 @@ def start_new(cls, estimator, inputs):
334333
if not local_mode:
335334
raise ValueError('File URIs are supported in local mode only. Please use a S3 URI instead.')
336335

337-
input_config = _TrainingJob._format_inputs_to_input_config(inputs)
338-
role = estimator.sagemaker_session.expand_role(estimator.role)
339-
output_config = _TrainingJob._prepare_output_config(estimator.output_path, estimator.output_kms_key)
340-
resource_config = _TrainingJob._prepare_resource_config(estimator.train_instance_count,
341-
estimator.train_instance_type,
342-
estimator.train_volume_size)
343-
stop_condition = _TrainingJob._prepare_stopping_condition(estimator.train_max_run)
336+
config = _Job._load_config(inputs, estimator)
344337

345338
if estimator.hyperparameters() is not None:
346339
hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}
347340

348341
estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode,
349-
input_config=input_config, role=role, job_name=estimator._current_job_name,
350-
output_config=output_config, resource_config=resource_config,
351-
hyperparameters=hyperparameters, stop_condition=stop_condition)
342+
input_config=config['input_config'], role=config['role'],
343+
job_name=estimator._current_job_name, output_config=config['output_config'],
344+
resource_config=config['resource_config'], hyperparameters=hyperparameters,
345+
stop_condition=config['stop_condition'])
352346

353347
return cls(estimator.sagemaker_session, estimator._current_job_name)
354348

355-
@staticmethod
356-
def _format_inputs_to_input_config(inputs):
357-
input_dict = {}
358-
if isinstance(inputs, string_types):
359-
input_dict['training'] = _TrainingJob._format_string_uri_input(inputs)
360-
elif isinstance(inputs, s3_input):
361-
input_dict['training'] = inputs
362-
elif isinstance(input, file_input):
363-
input_dict['training'] = inputs
364-
elif isinstance(inputs, dict):
365-
for k, v in inputs.items():
366-
input_dict[k] = _TrainingJob._format_string_uri_input(v)
367-
else:
368-
raise ValueError('Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs))
369-
370-
channels = []
371-
for channel_name, channel_s3_input in input_dict.items():
372-
channel_config = channel_s3_input.config.copy()
373-
channel_config['ChannelName'] = channel_name
374-
channels.append(channel_config)
375-
return channels
376-
377-
@staticmethod
378-
def _format_string_uri_input(input):
379-
if isinstance(input, str):
380-
if input.startswith('s3://'):
381-
return s3_input(input)
382-
elif input.startswith('file://'):
383-
return file_input(input)
384-
else:
385-
raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
386-
'"file://"')
387-
elif isinstance(input, s3_input):
388-
return input
389-
elif isinstance(input, file_input):
390-
return input
391-
else:
392-
raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(input))
393-
394-
@staticmethod
395-
def _prepare_output_config(s3_path, kms_key_id):
396-
config = {'S3OutputPath': s3_path}
397-
if kms_key_id is not None:
398-
config['KmsKeyId'] = kms_key_id
399-
return config
400-
401-
@staticmethod
402-
def _prepare_resource_config(instance_count, instance_type, volume_size):
403-
resource_config = {'InstanceCount': instance_count,
404-
'InstanceType': instance_type,
405-
'VolumeSizeInGB': volume_size}
406-
return resource_config
407-
408-
@staticmethod
409-
def _prepare_stopping_condition(max_run):
410-
stop_condition = {'MaxRuntimeInSeconds': max_run}
411-
return stop_condition
412-
413-
@property
414-
def name(self):
415-
return self.job_name
416-
417349
def wait(self, logs=True):
418350
if logs:
419351
self.sagemaker_session.logs_for_job(self.job_name, wait=True)
@@ -474,8 +406,7 @@ def train_image(self):
474406
"""
475407
Returns the docker image to use for training.
476408
477-
The fit() method, that does the model training, calls this method to find the image to use
478-
for model training.
409+
The fit() method, which does the model training, calls this method to find the image to use for model training.
479410
"""
480411
return self.image_name
481412

src/sagemaker/hpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def start_new(cls, tuner, inputs):
9898
resource_config = _TrainingJob._prepare_resource_config(tuner.estimator.train_instance_count,
9999
tuner.estimator.train_instance_type,
100100
tuner.estimator.train_volume_size)
101-
stop_condition = _TrainingJob._prepare_stopping_condition(tuner.estimator.train_max_run)
101+
stop_condition = _TrainingJob._prepare_stop_condition(tuner.estimator.train_max_run)
102102

103103
if tuner.estimator.hyperparameters() is None:
104104
raise ValueError('Cannot tune estimator without hyperparameters')

src/sagemaker/job.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
1315
from abc import abstractmethod
1416
from six import string_types
1517

18+
from sagemaker.local import file_input
1619
from sagemaker.session import s3_input
1720

1821

@@ -55,22 +58,29 @@ def _load_config(inputs, estimator):
5558
resource_config = _Job._prepare_resource_config(estimator.train_instance_count,
5659
estimator.train_instance_type,
5760
estimator.train_volume_size)
58-
stopping_condition = _Job._prepare_stopping_condition(estimator.train_max_run)
61+
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
5962

60-
return input_config, role, output_config, resource_config, stopping_condition
63+
return {'input_config': input_config,
64+
'role': role,
65+
'output_config': output_config,
66+
'resource_config': resource_config,
67+
'stop_condition': stop_condition}
6168

6269
@staticmethod
6370
def _format_inputs_to_input_config(inputs):
6471
input_dict = {}
6572
if isinstance(inputs, string_types):
66-
input_dict['training'] = _Job._format_s3_uri_input(inputs)
73+
input_dict['training'] = _Job._format_string_uri_input(inputs)
6774
elif isinstance(inputs, s3_input):
6875
input_dict['training'] = inputs
76+
elif isinstance(input, file_input):
77+
input_dict['training'] = inputs
6978
elif isinstance(inputs, dict):
7079
for k, v in inputs.items():
71-
input_dict[k] = _Job._format_s3_uri_input(v)
80+
input_dict[k] = _Job._format_string_uri_input(v)
7281
else:
73-
raise ValueError('Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs))
82+
raise ValueError(
83+
'Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs))
7484

7585
channels = []
7686
for channel_name, channel_s3_input in input_dict.items():
@@ -80,15 +90,24 @@ def _format_inputs_to_input_config(inputs):
8090
return channels
8191

8292
@staticmethod
83-
def _format_s3_uri_input(input):
93+
def _format_string_uri_input(input):
8494
if isinstance(input, str):
85-
if not input.startswith('s3://'):
86-
raise ValueError('Training input data must be a valid S3 URI and must start with "s3://"')
87-
return s3_input(input)
88-
if isinstance(input, s3_input):
95+
if input.startswith('s3://'):
96+
return s3_input(input)
97+
elif input.startswith('file://'):
98+
return file_input(input)
99+
else:
100+
raise ValueError(
101+
'Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
102+
'"file://"')
103+
elif isinstance(input, s3_input):
104+
return input
105+
elif isinstance(input, file_input):
89106
return input
90107
else:
91-
raise ValueError('Cannot format input {}. Expecting one of str or s3_input'.format(input))
108+
raise ValueError(
109+
'Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(
110+
input))
92111

93112
@staticmethod
94113
def _prepare_output_config(s3_path, kms_key_id):
@@ -104,7 +123,7 @@ def _prepare_resource_config(instance_count, instance_type, volume_size):
104123
'VolumeSizeInGB': volume_size}
105124

106125
@staticmethod
107-
def _prepare_stopping_condition(max_run):
126+
def _prepare_stop_condition(max_run):
108127
return {'MaxRuntimeInSeconds': max_run}
109128

110129
@property

0 commit comments

Comments
 (0)