Skip to content

Add airflow tuning config export API #486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ CHANGELOG

* bug-fix: Changes to use correct S3 bucket and time range for dataframes in TrainingJobAnalytics.
* bug-fix: Local Mode: correctly handle the case where the model output folder doesn't exist yet
* feature: Add APIs to export Airflow training and tuning config
* doc-fix: Fix typos in tensorflow serving documentation
* doc-fix: Add estimator base classes to API docs

Expand Down
13 changes: 10 additions & 3 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@


# Jinja macro appended to job names; it is a template string here, not a timestamp.
# The rendered form of '%Y-%m-%d-%H-%M-%S' is 19 characters long.
AIRFLOW_TIME_MACRO = "{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}"
AIRFLOW_TIME_MACRO_LEN = 19
# Compact variant of the macro; the rendered form of '%y%m%d-%H%M' is 11 characters long.
AIRFLOW_TIME_MACRO_SHORT = "{{ execution_date.strftime('%y%m%d-%H%M') }}"
AIRFLOW_TIME_MACRO_SHORT_LEN = 11


# Use the base name of the image as the job name if the user doesn't give us one
Expand Down Expand Up @@ -61,19 +64,23 @@ def name_from_base(base, max_length=63, short=False):
return '{}-{}'.format(trimmed_base, timestamp)


def airflow_name_from_base(base, max_length=63, short=False):
    """Append airflow execution_date macro (https://airflow.apache.org/code.html?#macros)
    to the provided string. The macro will be evaluated in Airflow operator runtime.
    This guarantees that different operators will have same name returned by this function.

    The base is trimmed so that the name fits in ``max_length`` once the macro has been
    rendered; the budget is computed from the rendered timestamp length, not from the
    length of the macro text itself.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string.
        short (bool): Whether or not to use a truncated timestamp.

    Returns:
        str: Input parameter with appended macro.
    """
    macro = AIRFLOW_TIME_MACRO_SHORT if short else AIRFLOW_TIME_MACRO
    length = AIRFLOW_TIME_MACRO_SHORT_LEN if short else AIRFLOW_TIME_MACRO_LEN
    # Reserve room for the rendered timestamp plus the '-' separator. Clamp at zero:
    # a negative slice stop would strip characters from the END of ``base`` instead of
    # trimming it to fit, producing a name unrelated to the requested limit.
    keep = max(max_length - length - 1, 0)
    trimmed_base = base[:keep]
    return "{}-{}".format(trimmed_base, macro)


def base_name_from_image(image):
Expand Down
171 changes: 153 additions & 18 deletions src/sagemaker/workflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os

import sagemaker
from sagemaker import job, utils, model
from sagemaker import job, model, utils
from sagemaker.amazon import amazon_estimator


Expand Down Expand Up @@ -48,14 +48,19 @@ def prepare_framework(estimator, s3_operations):
estimator._hyperparameters[model.SAGEMAKER_REGION_PARAM_NAME] = estimator.sagemaker_session.boto_region_name


def prepare_amazon_algorithm_estimator(estimator, inputs):
def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None):
""" Set up amazon algorithm estimator, adding the required `feature_dim` hyperparameter from training data.

Args:
estimator (sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
An estimator for a built-in Amazon algorithm to get information from and update.
inputs (single or list of sagemaker.amazon.amazon_estimator.RecordSet):
The training data, must be in RecordSet format.
inputs: The training data.
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
Amazon :class:`Record` objects serialized and stored in S3.
For use with an estimator for an Amazon algorithm.
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
:class:`~sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
a different channel of training data.
"""
if isinstance(inputs, list):
for record in inputs:
Expand All @@ -66,22 +71,39 @@ def prepare_amazon_algorithm_estimator(estimator, inputs):
estimator.feature_dim = inputs.feature_dim
else:
raise TypeError('Training data must be represented in RecordSet or list of RecordSets')
estimator.mini_batch_size = mini_batch_size


def training_config(estimator, inputs=None, job_name=None): # noqa: C901 - suppress complexity warning for this method
"""Export Airflow training config from an estimator
def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
"""Export Airflow base training config from an estimator

Args:
estimator (sagemaker.estimator.EstimatroBase):
estimator (sagemaker.estimator.EstimatorBase):
The estimator to export training config from. Can be a BYO estimator,
Framework estimator or Amazon algorithm estimator.
inputs (str, dict, single or list of sagemaker.amazon.amazon_estimator.RecordSet):
The training data.
inputs: Information about the training data. Please refer to the ``fit()`` method of
the associated estimator, as this can take any of the following forms:

* (str) - The S3 location where training data is saved.
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
training data, you can specify a dict mapping channel names
to strings or :func:`~sagemaker.session.s3_input` objects.
* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can provide
additional information about the training dataset. See :func:`sagemaker.session.s3_input`
for full details.
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
Amazon :class:`Record` objects serialized and stored in S3.
For use with an estimator for an Amazon algorithm.
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
:class:`~sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
a different channel of training data.

job_name (str): Specify a training job name if needed.
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
Amazon algorithm. For other estimators, batch size should be specified in the estimator.

Returns:
A dict of training config that can be directly used by SageMakerTrainingOperator
in Airflow.
Returns (dict):
Training config that can be directly used by SageMakerTrainingOperator in Airflow.
"""
default_bucket = estimator.sagemaker_session.default_bucket()
s3_operations = {}
Expand All @@ -99,8 +121,7 @@ def training_config(estimator, inputs=None, job_name=None): # noqa: C901 - supp
prepare_framework(estimator, s3_operations)

elif isinstance(estimator, amazon_estimator.AmazonAlgorithmEstimatorBase):
prepare_amazon_algorithm_estimator(estimator, inputs)

prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size)
job_config = job._Job._load_config(inputs, estimator, expand_role=False, validate_uri=False)

train_config = {
Expand All @@ -109,7 +130,6 @@ def training_config(estimator, inputs=None, job_name=None): # noqa: C901 - supp
'TrainingInputMode': estimator.input_mode
},
'OutputDataConfig': job_config['output_config'],
'TrainingJobName': estimator._current_job_name,
'StoppingCondition': job_config['stop_condition'],
'ResourceConfig': job_config['resource_config'],
'RoleArn': job_config['role'],
Expand All @@ -127,10 +147,125 @@ def training_config(estimator, inputs=None, job_name=None): # noqa: C901 - supp
if hyperparameters and len(hyperparameters) > 0:
train_config['HyperParameters'] = hyperparameters

if estimator.tags is not None:
train_config['Tags'] = estimator.tags

if s3_operations:
train_config['S3Operations'] = s3_operations

return train_config


def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
    """Build the full Airflow training config for an estimator

    Starts from the dict returned by ``training_base_config`` and adds the
    training job name and, when the estimator has tags, the tags.

    Args:
        estimator (sagemaker.estimator.EstimatorBase):
            The estimator to export training config from. Can be a BYO estimator,
            Framework estimator or Amazon algorithm estimator.
        inputs: Information about the training data. Please refer to the ``fit()`` method of
            the associated estimator, as this can take any of the following forms:

            * (str) - The S3 location where training data is saved.
            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
                channels for training data, you can specify a dict mapping channel names
                to strings or :func:`~sagemaker.session.s3_input` objects.
            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that
                can provide additional information about the training dataset.
                See :func:`sagemaker.session.s3_input` for full details.
            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
                Amazon :class:`Record` objects serialized and stored in S3.
                For use with an estimator for an Amazon algorithm.
            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
                :class:`~sagemaker.amazon.amazon_estimator.RecordSet` objects, where each
                instance is a different channel of training data.

        job_name (str): Specify a training job name if needed.
        mini_batch_size (int): Specify this argument only when estimator is a built-in
            estimator of an Amazon algorithm. For other estimators, batch size should be
            specified in the estimator.

    Returns (dict):
        Training config that can be directly used by SageMakerTrainingOperator in Airflow.
    """
    config = training_base_config(estimator, inputs, job_name, mini_batch_size)

    # The job name is read from the estimator only after the base config has been built.
    config['TrainingJobName'] = estimator._current_job_name

    tags = estimator.tags
    if tags is not None:
        config['Tags'] = tags

    return config


def tuning_config(tuner, inputs, job_name=None):
    """Export Airflow tuning config from a tuner

    Note that this updates the tuner in place: ``tuner._current_job_name`` and
    ``tuner.static_hyperparameters`` are both assigned.

    Args:
        tuner (sagemaker.tuner.HyperparameterTuner): The tuner to export tuning config from.
        inputs: Information about the training data. Please refer to the ``fit()`` method of
            the associated estimator in the tuner, as this can take any of the following forms:

            * (str) - The S3 location where training data is saved.
            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
                channels for training data, you can specify a dict mapping channel names
                to strings or :func:`~sagemaker.session.s3_input` objects.
            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that
                can provide additional information about the training dataset.
                See :func:`sagemaker.session.s3_input` for full details.
            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
                Amazon :class:`Record` objects serialized and stored in S3.
                For use with an estimator for an Amazon algorithm.
            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
                :class:`~sagemaker.amazon.amazon_estimator.RecordSet` objects, where each
                instance is a different channel of training data.

        job_name (str): Specify a tuning job name if needed.

    Returns (dict):
        Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
    """
    train_config = training_base_config(tuner.estimator, inputs)
    # These two keys do not belong in the nested training job definition: hyperparameters
    # are re-attached below as 'StaticHyperParameters', and S3 operations are attached to
    # the top-level tuning config.
    hyperparameters = train_config.pop('HyperParameters', None)
    s3_operations = train_config.pop('S3Operations', None)

    if hyperparameters:
        tuner.static_hyperparameters = {
            utils.to_str(k): utils.to_str(v) for (k, v) in hyperparameters.items()
        }
    elif getattr(tuner, 'static_hyperparameters', None) is None:
        # Without hyperparameters in the base config the attribute may still be unset;
        # default it so the pruning loop below does not call .pop() on None.
        tuner.static_hyperparameters = {}

    if job_name is not None:
        tuner._current_job_name = job_name
    else:
        base_name = tuner.base_tuning_job_name or utils.base_name_from_image(tuner.estimator.train_image())
        tuner._current_job_name = utils.airflow_name_from_base(base_name, tuner.TUNING_JOB_NAME_MAX_LENGTH, True)

    # A hyperparameter that is being tuned must not also be passed as a static value.
    for hyperparameter_name in tuner._hyperparameter_ranges:
        tuner.static_hyperparameters.pop(hyperparameter_name, None)

    train_config['StaticHyperParameters'] = tuner.static_hyperparameters

    tune_config = {
        'HyperParameterTuningJobName': tuner._current_job_name,
        'HyperParameterTuningJobConfig': {
            'Strategy': tuner.strategy,
            'HyperParameterTuningJobObjective': {
                'Type': tuner.objective_type,
                'MetricName': tuner.objective_metric_name,
            },
            'ResourceLimits': {
                'MaxNumberOfTrainingJobs': tuner.max_jobs,
                'MaxParallelTrainingJobs': tuner.max_parallel_jobs,
            },
            'ParameterRanges': tuner.hyperparameter_ranges(),
        },
        'TrainingJobDefinition': train_config
    }

    if tuner.metric_definitions is not None:
        tune_config['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = \
            tuner.metric_definitions

    if tuner.tags is not None:
        tune_config['Tags'] = tuner.tags

    if s3_operations is not None:
        tune_config['S3Operations'] = s3_operations

    return tune_config
Loading