Skip to content

Commit 4153e41

Browse files
authored
Add airflow module that could export training config (#480)
1 parent e0944e0 commit 4153e41

File tree

6 files changed

+649
-29
lines changed

6 files changed

+649
-29
lines changed

src/sagemaker/job.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def wait(self):
5151
pass
5252

5353
@staticmethod
54-
def _load_config(inputs, estimator):
55-
input_config = _Job._format_inputs_to_input_config(inputs)
56-
role = estimator.sagemaker_session.expand_role(estimator.role)
54+
def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
55+
input_config = _Job._format_inputs_to_input_config(inputs, validate_uri)
56+
role = estimator.sagemaker_session.expand_role(estimator.role) if expand_role else estimator.role
5757
output_config = _Job._prepare_output_config(estimator.output_path, estimator.output_kms_key)
5858
resource_config = _Job._prepare_resource_config(estimator.train_instance_count,
5959
estimator.train_instance_type,
@@ -62,7 +62,8 @@ def _load_config(inputs, estimator):
6262
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
6363
vpc_config = estimator.get_vpc_config()
6464

65-
model_channel = _Job._prepare_model_channel(input_config, estimator.model_uri, estimator.model_channel_name)
65+
model_channel = _Job._prepare_model_channel(input_config, estimator.model_uri, estimator.model_channel_name,
66+
validate_uri)
6667
if model_channel:
6768
input_config = [] if input_config is None else input_config
6869
input_config.append(model_channel)
@@ -75,7 +76,7 @@ def _load_config(inputs, estimator):
7576
'vpc_config': vpc_config}
7677

7778
@staticmethod
78-
def _format_inputs_to_input_config(inputs):
79+
def _format_inputs_to_input_config(inputs, validate_uri=True):
7980
if inputs is None:
8081
return None
8182

@@ -86,14 +87,14 @@ def _format_inputs_to_input_config(inputs):
8687

8788
input_dict = {}
8889
if isinstance(inputs, string_types):
89-
input_dict['training'] = _Job._format_string_uri_input(inputs)
90+
input_dict['training'] = _Job._format_string_uri_input(inputs, validate_uri)
9091
elif isinstance(inputs, s3_input):
9192
input_dict['training'] = inputs
9293
elif isinstance(inputs, file_input):
9394
input_dict['training'] = inputs
9495
elif isinstance(inputs, dict):
9596
for k, v in inputs.items():
96-
input_dict[k] = _Job._format_string_uri_input(v)
97+
input_dict[k] = _Job._format_string_uri_input(v, validate_uri)
9798
elif isinstance(inputs, list):
9899
input_dict = _Job._format_record_set_list_input(inputs)
99100
else:
@@ -111,15 +112,16 @@ def _convert_input_to_channel(channel_name, channel_s3_input):
111112
return channel_config
112113

113114
@staticmethod
114-
def _format_string_uri_input(uri_input):
115-
if isinstance(uri_input, str):
116-
if uri_input.startswith('s3://'):
117-
return s3_input(uri_input)
118-
elif uri_input.startswith('file://'):
119-
return file_input(uri_input)
120-
else:
121-
raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
122-
'"file://"')
115+
def _format_string_uri_input(uri_input, validate_uri=True):
116+
if isinstance(uri_input, str) and validate_uri and uri_input.startswith('s3://'):
117+
return s3_input(uri_input)
118+
elif isinstance(uri_input, str) and validate_uri and uri_input.startswith('file://'):
119+
return file_input(uri_input)
120+
elif isinstance(uri_input, str) and validate_uri:
121+
raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
122+
'"file://"')
123+
elif isinstance(uri_input, str):
124+
return s3_input(uri_input)
123125
elif isinstance(uri_input, s3_input):
124126
return uri_input
125127
elif isinstance(uri_input, file_input):
@@ -128,7 +130,7 @@ def _format_string_uri_input(uri_input):
128130
raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(uri_input))
129131

130132
@staticmethod
131-
def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None):
133+
def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None, validate_uri=True):
132134
if not model_uri:
133135
return
134136
elif not model_channel_name:
@@ -139,22 +141,24 @@ def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None
139141
if channel['ChannelName'] == model_channel_name:
140142
raise ValueError('Duplicate channels not allowed.')
141143

142-
model_input = _Job._format_model_uri_input(model_uri)
144+
model_input = _Job._format_model_uri_input(model_uri, validate_uri)
143145
model_channel = _Job._convert_input_to_channel(model_channel_name, model_input)
144146

145147
return model_channel
146148

147149
@staticmethod
148-
def _format_model_uri_input(model_uri):
149-
if isinstance(model_uri, string_types):
150-
if model_uri.startswith('s3://'):
151-
return s3_input(model_uri, input_mode='File', distribution='FullyReplicated',
152-
content_type='application/x-sagemaker-model')
153-
elif model_uri.startswith('file://'):
154-
return file_input(model_uri)
155-
else:
156-
raise ValueError('Model URI must be a valid S3 or FILE URI: must start with "s3://" or '
157-
'"file://')
150+
def _format_model_uri_input(model_uri, validate_uri=True):
    """Convert a model artifact URI into a channel-ready input object.

    Args:
        model_uri (str or s3_input or file_input): S3 URI ('s3://...') or local
            file URI ('file://...') pointing at the model artifacts.
        validate_uri (bool): Whether to require a recognized URI scheme. When
            False, any string is treated as an S3 URI (used when the URI is a
            placeholder resolved later, e.g. an Airflow macro).

    Returns:
        s3_input or file_input: The wrapped model input.

    Raises:
        ValueError: If validation is enabled and the URI has an unrecognized
            scheme, or if model_uri is not a string.
    """
    if isinstance(model_uri, string_types):
        # Without validation, any string is assumed to be an S3 URI
        # (it may be a template placeholder that resolves later).
        if not validate_uri or model_uri.startswith('s3://'):
            return s3_input(model_uri, input_mode='File', distribution='FullyReplicated',
                            content_type='application/x-sagemaker-model')
        elif model_uri.startswith('file://'):
            return file_input(model_uri)
        else:
            # BUG FIX: the original message was missing the closing quote after file://.
            raise ValueError('Model URI must be a valid S3 or FILE URI: must start with "s3://" or '
                             '"file://"')
    else:
        raise ValueError('Cannot format model URI {}. Expecting str'.format(model_uri))
160164

src/sagemaker/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,11 @@ def default_bucket(self):
169169
if self._default_bucket:
170170
return self._default_bucket
171171

172-
s3 = self.boto_session.resource('s3')
173172
account = self.boto_session.client('sts').get_caller_identity()['Account']
174173
region = self.boto_session.region_name
175174
default_bucket = 'sagemaker-{}-{}'.format(region, account)
176175

176+
s3 = self.boto_session.resource('s3')
177177
try:
178178
# 'us-east-1' cannot be specified because it is the default region:
179179
# https://github.com/boto/boto3/issues/125

src/sagemaker/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
import six
2727

2828

29+
AIRFLOW_TIME_MACRO = "{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}"
30+
31+
2932
# Use the base name of the image as the job name if the user doesn't give us one
3033
def name_from_image(image):
3134
"""Create a training job name based on the image name and a timestamp.
@@ -58,6 +61,21 @@ def name_from_base(base, max_length=63, short=False):
5861
return '{}-{}'.format(trimmed_base, timestamp)
5962

6063

64+
def airflow_name_from_base(base):
    """Append the Airflow execution_date macro (https://airflow.apache.org/code.html?#macros)
    to the provided string. The macro will be evaluated at Airflow operator runtime,
    so operators that share an execution_date (the same DAG run) get the same name
    from this function.

    Args:
        base (str): String used as prefix to generate the unique name.

    Returns:
        str: Input parameter with appended macro.
    """

    return "{}-{}".format(base, AIRFLOW_TIME_MACRO)
77+
78+
6179
def base_name_from_image(image):
6280
"""Extract the base name of the image to use as the 'algorithm name' for the job.
6381

src/sagemaker/workflow/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.

src/sagemaker/workflow/airflow.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import print_function, absolute_import
14+
15+
import os
16+
17+
import sagemaker
18+
from sagemaker import job, utils, model
19+
from sagemaker.amazon import amazon_estimator
20+
21+
22+
def prepare_framework(estimator, s3_operations):
    """Prepare S3 operations (specify where to upload `source_dir`) and hyperparameters
    related to framework.

    Args:
        estimator (sagemaker.estimator.Estimator): The framework estimator to get information from and update.
        s3_operations (dict): The dict to specify S3 operations (upload `source_dir`).
    """
    bucket = estimator.code_location if estimator.code_location else estimator.sagemaker_session._default_bucket
    key = '{}/source/sourcedir.tar.gz'.format(estimator._current_job_name)
    script = os.path.basename(estimator.entry_point)
    if estimator.source_dir and estimator.source_dir.lower().startswith('s3://'):
        # Source is already in S3; reference it directly — nothing local to upload.
        code_dir = estimator.source_dir
    else:
        code_dir = 's3://{}/{}'.format(bucket, key)
        # BUG FIX: only schedule an upload for local sources. The original
        # appended this S3Upload unconditionally, which would ask the Airflow
        # operator to "upload" an s3:// path when source_dir was already in S3.
        s3_operations['S3Upload'] = [{
            'Path': estimator.source_dir or script,
            'Bucket': bucket,
            'Key': key,
            'Tar': True
        }]
    estimator._hyperparameters[model.DIR_PARAM_NAME] = code_dir
    estimator._hyperparameters[model.SCRIPT_PARAM_NAME] = script
    estimator._hyperparameters[model.CLOUDWATCH_METRICS_PARAM_NAME] = estimator.enable_cloudwatch_metrics
    estimator._hyperparameters[model.CONTAINER_LOG_LEVEL_PARAM_NAME] = estimator.container_log_level
    estimator._hyperparameters[model.JOB_NAME_PARAM_NAME] = estimator._current_job_name
    estimator._hyperparameters[model.SAGEMAKER_REGION_PARAM_NAME] = estimator.sagemaker_session.boto_region_name
49+
50+
51+
def prepare_amazon_algorithm_estimator(estimator, inputs):
    """Set up an Amazon algorithm estimator by copying the required `feature_dim`
    hyperparameter from the training data.

    Args:
        estimator (sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
            An estimator for a built-in Amazon algorithm to get information from and update.
        inputs (single or list of sagemaker.amazon.amazon_estimator.RecordSet):
            The training data, must be in RecordSet format.
    """
    if isinstance(inputs, list):
        # Take feature_dim from the first RecordSet bound to the 'train' channel.
        for record_set in inputs:
            is_train_record = (isinstance(record_set, amazon_estimator.RecordSet)
                               and record_set.channel == 'train')
            if is_train_record:
                estimator.feature_dim = record_set.feature_dim
                break
    elif isinstance(inputs, amazon_estimator.RecordSet):
        estimator.feature_dim = inputs.feature_dim
    else:
        raise TypeError('Training data must be represented in RecordSet or list of RecordSets')
69+
70+
71+
def training_config(estimator, inputs=None, job_name=None):  # noqa: C901 - suppress complexity warning for this method
    """Export Airflow training config from an estimator

    Args:
        estimator (sagemaker.estimator.EstimatorBase):
            The estimator to export training config from. Can be a BYO estimator,
            Framework estimator or Amazon algorithm estimator.
        inputs (str, dict, single or list of sagemaker.amazon.amazon_estimator.RecordSet):
            The training data.
        job_name (str): Specify a training job name if needed.

    Returns:
        A dict of training config that can be directly used by SageMakerTrainingOperator
        in Airflow.
    """
    default_bucket = estimator.sagemaker_session.default_bucket()
    s3_operations = {}

    if job_name is not None:
        estimator._current_job_name = job_name
    else:
        base_name = estimator.base_job_name or utils.base_name_from_image(estimator.train_image())
        # Use an Airflow macro so the generated name is resolved per DAG run.
        estimator._current_job_name = utils.airflow_name_from_base(base_name)

    if estimator.output_path is None:
        estimator.output_path = 's3://{}/'.format(default_bucket)

    if isinstance(estimator, sagemaker.estimator.Framework):
        prepare_framework(estimator, s3_operations)
    elif isinstance(estimator, amazon_estimator.AmazonAlgorithmEstimatorBase):
        prepare_amazon_algorithm_estimator(estimator, inputs)

    # Role/URI values may be Airflow templates here, so skip role expansion
    # and URI validation when building the job config.
    job_config = job._Job._load_config(inputs, estimator, expand_role=False, validate_uri=False)

    train_config = {
        'AlgorithmSpecification': {
            'TrainingImage': estimator.train_image(),
            'TrainingInputMode': estimator.input_mode
        },
        'OutputDataConfig': job_config['output_config'],
        'TrainingJobName': estimator._current_job_name,
        'StoppingCondition': job_config['stop_condition'],
        'ResourceConfig': job_config['resource_config'],
        'RoleArn': job_config['role'],
    }

    if job_config['input_config'] is not None:
        train_config['InputDataConfig'] = job_config['input_config']

    if job_config['vpc_config'] is not None:
        train_config['VpcConfig'] = job_config['vpc_config']

    # BUG FIX: the original only assigned `hyperparameters` inside an
    # `is not None` check, then read it unconditionally — a NameError when
    # estimator.hyperparameters() returns None.
    hyperparameters = estimator.hyperparameters()
    if hyperparameters:
        train_config['HyperParameters'] = {str(k): str(v) for (k, v) in hyperparameters.items()}

    if estimator.tags is not None:
        train_config['Tags'] = estimator.tags

    if s3_operations:
        train_config['S3Operations'] = s3_operations

    return train_config

0 commit comments

Comments
 (0)