
Commit b7d40e9

Merge branch 'master' into master
2 parents: 9230dc1 + c32dec0

File tree: 10 files changed, +1038 −69 lines


CHANGELOG.rst

Lines changed: 3 additions & 1 deletion
@@ -7,9 +7,11 @@ CHANGELOG
 
 * bug-fix: Changes to use correct S3 bucket and time range for dataframes in TrainingJobAnalytics.
 * bug-fix: Local Mode: correctly handle the case where the model output folder doesn't exist yet
+* feature: Add APIs to export Airflow training and tuning config
 * doc-fix: Fix typos in tensorflow serving documentation
 * doc-fix: Add estimator base classes to API docs
-* feature: add support for MetricDefinitions
+* feature: HyperparameterTuner: add support for Automatic Model Tuning's Warm Start Jobs
+* feature: Estimator: add support for MetricDefinitions
 
 1.14.2
 ======
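
Two of the entries above add estimator-facing APIs. A minimal sketch of the new `metric_definitions` support on `Estimator` (the image URI, IAM role, and regex below are placeholder values, not from this commit):

    from sagemaker.estimator import Estimator

    # Placeholder image and role values for illustration only.
    estimator = Estimator(
        image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algo:latest',
        role='arn:aws:iam::123456789012:role/SageMakerRole',
        train_instance_count=1,
        train_instance_type='ml.m4.xlarge',
        # New in this release: regexes applied to the training logs to emit
        # named metrics, which can also serve as tuning objectives.
        metric_definitions=[
            {'Name': 'validation:accuracy', 'Regex': 'val_acc=([0-9\\.]+)'},
        ],
    )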

src/sagemaker/session.py

Lines changed: 7 additions & 1 deletion
@@ -285,7 +285,8 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
     def tune(self, job_name, strategy, objective_type, objective_metric_name,
              max_jobs, max_parallel_jobs, parameter_ranges,
              static_hyperparameters, image, input_mode, metric_definitions,
-             role, input_config, output_config, resource_config, stop_condition, tags):
+             role, input_config, output_config, resource_config, stop_condition, tags,
+             warm_start_config):
         """Create an Amazon SageMaker hyperparameter tuning job
 
         Args:
@@ -329,6 +330,8 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
             stop_condition (dict): When training should finish, e.g. ``MaxRuntimeInSeconds``.
             tags (list[dict]): List of tags for labeling the tuning job. For more, see
                 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+            warm_start_config (dict): Configuration defining the type of warm start and
+                other required configurations.
         """
         tune_request = {
             'HyperParameterTuningJobName': job_name,
@@ -358,6 +361,9 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
             }
         }
 
+        if warm_start_config:
+            tune_request['WarmStartConfig'] = warm_start_config
+
         if metric_definitions is not None:
             tune_request['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = metric_definitions
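
For reference, the `warm_start_config` dict accepted here mirrors the `WarmStartConfig` structure of the underlying CreateHyperParameterTuningJob API. A minimal sketch (the parent job name is a placeholder):

    warm_start_config = {
        'WarmStartType': 'IdenticalDataAndAlgorithm',  # or 'TransferLearning'
        'ParentHyperParameterTuningJobs': [
            {'HyperParameterTuningJobName': 'parent-tuning-job-name'},
        ],
    }
    # Passed through unchanged: tune_request['WarmStartConfig'] = warm_start_config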

src/sagemaker/tuner.py

Lines changed: 245 additions & 2 deletions
Large diffs are not rendered by default.
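
The tuner.py diff (245 additions) is where the changelog's warm start feature lands, but it is not rendered here. Below is a hedged sketch of the user-facing API based on the published `sagemaker` package: the `WarmStartConfig`/`WarmStartTypes` names come from the released SDK rather than this diff, and the parent job name, ranges, and S3 path are placeholders.

    from sagemaker.estimator import Estimator
    from sagemaker.tuner import (ContinuousParameter, HyperparameterTuner,
                                 WarmStartConfig, WarmStartTypes)

    # Placeholder estimator; see the Estimator sketch above.
    estimator = Estimator(
        image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algo:latest',
        role='arn:aws:iam::123456789012:role/SageMakerRole',
        train_instance_count=1,
        train_instance_type='ml.m4.xlarge',
    )

    # Seed the new tuning job with results from a completed parent job.
    warm_start_config = WarmStartConfig(
        warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM,
        parents={'parent-tuning-job-name'},  # placeholder parent job name
    )

    tuner = HyperparameterTuner(
        estimator=estimator,
        objective_metric_name='validation:accuracy',
        hyperparameter_ranges={'learning_rate': ContinuousParameter(0.01, 0.2)},
        metric_definitions=[{'Name': 'validation:accuracy',
                             'Regex': 'val_acc=([0-9\\.]+)'}],
        max_jobs=10,
        max_parallel_jobs=2,
        warm_start_config=warm_start_config,
    )
    tuner.fit({'train': 's3://my-bucket/train'})  # placeholder S3 path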

src/sagemaker/utils.py

Lines changed: 10 additions & 3 deletions
@@ -27,6 +27,9 @@
 
 
 AIRFLOW_TIME_MACRO = "{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}"
+AIRFLOW_TIME_MACRO_LEN = 19
+AIRFLOW_TIME_MACRO_SHORT = "{{ execution_date.strftime('%y%m%d-%H%M') }}"
+AIRFLOW_TIME_MACRO_SHORT_LEN = 11
 
 
 # Use the base name of the image as the job name if the user doesn't give us one
@@ -61,19 +64,23 @@ def name_from_base(base, max_length=63, short=False):
     return '{}-{}'.format(trimmed_base, timestamp)
 
 
-def airflow_name_from_base(base):
+def airflow_name_from_base(base, max_length=63, short=False):
     """Append the Airflow execution_date macro (https://airflow.apache.org/code.html?#macros)
     to the provided string. The macro will be evaluated in Airflow operator runtime.
     This guarantees that different operators will have the same name returned by this function.
 
     Args:
         base (str): String used as prefix to generate the unique name.
+        max_length (int): Maximum length for the resulting string.
+        short (bool): Whether or not to use a truncated timestamp.
 
     Returns:
        str: Input parameter with appended macro.
    """
-
-    return "{}-{}".format(base, AIRFLOW_TIME_MACRO)
+    macro = AIRFLOW_TIME_MACRO_SHORT if short else AIRFLOW_TIME_MACRO
+    length = AIRFLOW_TIME_MACRO_SHORT_LEN if short else AIRFLOW_TIME_MACRO_LEN
+    trimmed_base = base[:max_length - length - 1]
+    return "{}-{}".format(trimmed_base, macro)
 
 
 def base_name_from_image(image):
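
To make the trimming arithmetic concrete: the function reserves room for the *rendered* timestamp (19 characters for the long form, 11 for the short form) plus a hyphen, and truncates the base so the final rendered name fits `max_length`. A small usage sketch (the base string is arbitrary; note the returned string itself is longer than `max_length` because the macro only renders at Airflow runtime):

    from sagemaker import utils

    # short=True appends the 11-character rendered form ('%y%m%d-%H%M'), so
    # the base is trimmed to max_length - 11 - 1 = 32 - 11 - 1 = 20 characters
    # to leave room for the hyphen and the timestamp once Airflow renders it.
    name = utils.airflow_name_from_base('my-training-job-base-name', max_length=32, short=True)
    print(name)
    # my-training-job-base-{{ execution_date.strftime('%y%m%d-%H%M') }}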

src/sagemaker/workflow/airflow.py

Lines changed: 153 additions & 18 deletions
@@ -15,7 +15,7 @@
 import os
 
 import sagemaker
-from sagemaker import job, utils, model
+from sagemaker import job, model, utils
 from sagemaker.amazon import amazon_estimator
 
 
@@ -48,14 +48,19 @@ def prepare_framework(estimator, s3_operations):
     estimator._hyperparameters[model.SAGEMAKER_REGION_PARAM_NAME] = estimator.sagemaker_session.boto_region_name
 
 
-def prepare_amazon_algorithm_estimator(estimator, inputs):
+def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None):
     """ Set up amazon algorithm estimator, adding the required `feature_dim` hyperparameter from training data.
 
     Args:
         estimator (sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
             An estimator for a built-in Amazon algorithm to get information from and update.
-        inputs (single or list of sagemaker.amazon.amazon_estimator.RecordSet):
-            The training data, must be in RecordSet format.
+        inputs: The training data.
+            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
+                Amazon :class:~`Record` objects serialized and stored in S3.
+                For use with an estimator for an Amazon algorithm.
+            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
+                :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
+                a different channel of training data.
     """
     if isinstance(inputs, list):
         for record in inputs:
@@ -66,22 +71,39 @@ def prepare_amazon_algorithm_estimator(estimator, inputs):
             estimator.feature_dim = inputs.feature_dim
     else:
         raise TypeError('Training data must be represented in RecordSet or list of RecordSets')
+    estimator.mini_batch_size = mini_batch_size
 
 
-def training_config(estimator, inputs=None, job_name=None):  # noqa: C901 - suppress complexity warning for this method
-    """Export Airflow training config from an estimator
+def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
+    """Export Airflow base training config from an estimator
 
     Args:
-        estimator (sagemaker.estimator.EstimatroBase):
+        estimator (sagemaker.estimator.EstimatorBase):
            The estimator to export training config from. Can be a BYO estimator,
            Framework estimator or Amazon algorithm estimator.
-        inputs (str, dict, single or list of sagemaker.amazon.amazon_estimator.RecordSet):
-            The training data.
+        inputs: Information about the training data. Please refer to the ``fit()`` method of
+            the associated estimator, as this can take any of the following forms:
+
+            * (str) - The S3 location where training data is saved.
+            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
+                training data, you can specify a dict mapping channel names
+                to strings or :func:`~sagemaker.session.s3_input` objects.
+            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can provide
+                additional information about the training dataset. See :func:`sagemaker.session.s3_input`
+                for full details.
+            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
+                Amazon :class:~`Record` objects serialized and stored in S3.
+                For use with an estimator for an Amazon algorithm.
+            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
+                :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
+                a different channel of training data.
+
         job_name (str): Specify a training job name if needed.
+        mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
+            Amazon algorithm. For other estimators, batch size should be specified in the estimator.
 
-    Returns:
-        A dict of training config that can be directly used by SageMakerTrainingOperator
-        in Airflow.
+    Returns (dict):
+        Training config that can be directly used by SageMakerTrainingOperator in Airflow.
     """
     default_bucket = estimator.sagemaker_session.default_bucket()
     s3_operations = {}
@@ -99,8 +121,7 @@ def training_config(estimator, inputs=None, job_name=None):  # noqa: C901 - supp
         prepare_framework(estimator, s3_operations)
 
     elif isinstance(estimator, amazon_estimator.AmazonAlgorithmEstimatorBase):
-        prepare_amazon_algorithm_estimator(estimator, inputs)
-
+        prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size)
     job_config = job._Job._load_config(inputs, estimator, expand_role=False, validate_uri=False)
 
     train_config = {
@@ -109,7 +130,6 @@ def training_config(estimator, inputs=None, job_name=None):  # noqa: C901 - supp
             'TrainingInputMode': estimator.input_mode
         },
         'OutputDataConfig': job_config['output_config'],
-        'TrainingJobName': estimator._current_job_name,
         'StoppingCondition': job_config['stop_condition'],
         'ResourceConfig': job_config['resource_config'],
         'RoleArn': job_config['role'],
@@ -127,10 +147,125 @@ def training_config(estimator, inputs=None, job_name=None):  # noqa: C901 - supp
     if hyperparameters and len(hyperparameters) > 0:
         train_config['HyperParameters'] = hyperparameters
 
-    if estimator.tags is not None:
-        train_config['Tags'] = estimator.tags
-
     if s3_operations:
         train_config['S3Operations'] = s3_operations
 
     return train_config
+
+
+def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
+    """Export Airflow training config from an estimator
+
+    Args:
+        estimator (sagemaker.estimator.EstimatorBase):
+            The estimator to export training config from. Can be a BYO estimator,
+            Framework estimator or Amazon algorithm estimator.
+        inputs: Information about the training data. Please refer to the ``fit()`` method of
+            the associated estimator, as this can take any of the following forms:
+
+            * (str) - The S3 location where training data is saved.
+            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
+                training data, you can specify a dict mapping channel names
+                to strings or :func:`~sagemaker.session.s3_input` objects.
+            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can provide
+                additional information about the training dataset. See :func:`sagemaker.session.s3_input`
+                for full details.
+            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
+                Amazon :class:~`Record` objects serialized and stored in S3.
+                For use with an estimator for an Amazon algorithm.
+            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
+                :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
+                a different channel of training data.
+
+        job_name (str): Specify a training job name if needed.
+        mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
+            Amazon algorithm. For other estimators, batch size should be specified in the estimator.
+
+    Returns (dict):
+        Training config that can be directly used by SageMakerTrainingOperator in Airflow.
+    """
+
+    train_config = training_base_config(estimator, inputs, job_name, mini_batch_size)
+
+    train_config['TrainingJobName'] = estimator._current_job_name
+
+    if estimator.tags is not None:
+        train_config['Tags'] = estimator.tags
+
+    return train_config
+
+
+def tuning_config(tuner, inputs, job_name=None):
+    """Export Airflow tuning config from a HyperparameterTuner
+
+    Args:
+        tuner (sagemaker.tuner.HyperparameterTuner): The tuner to export tuning config from.
+        inputs: Information about the training data. Please refer to the ``fit()`` method of
+            the associated estimator in the tuner, as this can take any of the following forms:
+
+            * (str) - The S3 location where training data is saved.
+            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
+                training data, you can specify a dict mapping channel names
+                to strings or :func:`~sagemaker.session.s3_input` objects.
+            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can provide
+                additional information about the training dataset. See :func:`sagemaker.session.s3_input`
+                for full details.
+            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
+                Amazon :class:~`Record` objects serialized and stored in S3.
+                For use with an estimator for an Amazon algorithm.
+            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
+                :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
+                a different channel of training data.
+
+        job_name (str): Specify a tuning job name if needed.
+
+    Returns (dict):
+        Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
+    """
+    train_config = training_base_config(tuner.estimator, inputs)
+    hyperparameters = train_config.pop('HyperParameters', None)
+    s3_operations = train_config.pop('S3Operations', None)
+
+    if hyperparameters and len(hyperparameters) > 0:
+        tuner.static_hyperparameters = \
+            {utils.to_str(k): utils.to_str(v) for (k, v) in hyperparameters.items()}
+
+    if job_name is not None:
+        tuner._current_job_name = job_name
+    else:
+        base_name = tuner.base_tuning_job_name or utils.base_name_from_image(tuner.estimator.train_image())
+        tuner._current_job_name = utils.airflow_name_from_base(base_name, tuner.TUNING_JOB_NAME_MAX_LENGTH, True)
+
+    for hyperparameter_name in tuner._hyperparameter_ranges.keys():
+        tuner.static_hyperparameters.pop(hyperparameter_name, None)
+
+    train_config['StaticHyperParameters'] = tuner.static_hyperparameters
+
+    tune_config = {
+        'HyperParameterTuningJobName': tuner._current_job_name,
+        'HyperParameterTuningJobConfig': {
+            'Strategy': tuner.strategy,
+            'HyperParameterTuningJobObjective': {
+                'Type': tuner.objective_type,
+                'MetricName': tuner.objective_metric_name,
+            },
+            'ResourceLimits': {
+                'MaxNumberOfTrainingJobs': tuner.max_jobs,
+                'MaxParallelTrainingJobs': tuner.max_parallel_jobs,
+            },
+            'ParameterRanges': tuner.hyperparameter_ranges(),
+        },
+        'TrainingJobDefinition': train_config
+    }
+
+    if tuner.metric_definitions is not None:
+        tune_config['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = \
+            tuner.metric_definitions
+
+    if tuner.tags is not None:
+        tune_config['Tags'] = tuner.tags
+
+    if s3_operations is not None:
+        tune_config['S3Operations'] = s3_operations
+
+    return tune_config
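
Putting the new export APIs together: the returned dicts are meant to be handed to Airflow's SageMaker operators. A minimal sketch, assuming the `SageMakerTrainingOperator` shipped in Airflow's contrib package at the time (the module path may differ across Airflow versions), an already-configured `estimator`, and placeholder DAG settings and S3 paths:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator

    from sagemaker.workflow.airflow import training_config

    # `estimator` is assumed to be an already-configured sagemaker Estimator
    # (see the earlier sketches); the S3 path is a placeholder.
    train_config = training_config(estimator=estimator, inputs='s3://my-bucket/train')

    dag = DAG('sagemaker_training', start_date=datetime(2018, 11, 1), schedule_interval='@daily')

    train_op = SageMakerTrainingOperator(
        task_id='sagemaker_training',
        config=train_config,        # dict exported above; macros render at runtime
        wait_for_completion=True,
        dag=dag,
    )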

tests/data/local_mode_lock

Whitespace-only changes.
