Skip to content

Commit 63b852b

Browse files
authored
Specify task of training/tuning job in Airflow transform/deploy operator (#590)
1 parent 9fd7e2b commit 63b852b

File tree

5 files changed

+246
-154
lines changed

5 files changed

+246
-154
lines changed

CHANGELOG.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@
22
CHANGELOG
33
=========
44

5+
1.17.1.dev
6+
==========
7+
8+
* enhancement: Workflow: Specify tasks from which training/tuning operator to transform/deploy in related operators
9+
510
1.17.0
611
======
712

8-
* bug-fix: Revert appending Airflow retry id to default job name
13+
* bug-fix: Workflow: Revert appending Airflow retry id to default job name
914
* bug-fix: Session: don't allow get_execution_role() to return an ARN that's not a role but has "role" in the name
1015
* bug-fix: Remove ``__all__`` from ``__init__.py`` files
1116
* doc-fix: Add TFRecord split type to docs
@@ -22,7 +27,7 @@ CHANGELOG
2227
======
2328

2429
* bug-fix: Local Mode: Allow support for SSH in local mode
25-
* bug-fix: Append retry id to default Airflow job name to avoid name collisions in retry
30+
* bug-fix: Workflow: Append retry id to default Airflow job name to avoid name collisions in retry
2631
* bug-fix: Local Mode: No longer requires s3 permissions to run local entry point file
2732
* feature: Estimators: add support for PyTorch 1.0.0
2833
* bug-fix: Local Mode: Move dependency on sagemaker_s3_output from rl.estimator to model

src/sagemaker/utils.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,6 @@
2727
import six
2828

2929

30-
AIRFLOW_TIME_MACRO = "{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}"
31-
AIRFLOW_TIME_MACRO_LEN = 19
32-
AIRFLOW_TIME_MACRO_SHORT = "{{ execution_date.strftime('%y%m%d-%H%M') }}"
33-
AIRFLOW_TIME_MACRO_SHORT_LEN = 11
34-
35-
3630
# Use the base name of the image as the job name if the user doesn't give us one
3731
def name_from_image(image):
3832
"""Create a training job name based on the image name and a timestamp.
@@ -73,25 +67,6 @@ def unique_name_from_base(base, max_length=63):
7367
return '{}-{}-{}'.format(trimmed, ts, unique)
7468

7569

76-
def airflow_name_from_base(base, max_length=63, short=False):
77-
"""Append airflow execution_date macro (https://airflow.apache.org/code.html?#macros)
78-
to the provided string. The macro will beevaluated in Airflow operator runtime.
79-
This guarantees that different operators will have same name returned by this function.
80-
81-
Args:
82-
base (str): String used as prefix to generate the unique name.
83-
max_length (int): Maximum length for the resulting string.
84-
short (bool): Whether or not to use a truncated timestamp.
85-
86-
Returns:
87-
str: Input parameter with appended macro.
88-
"""
89-
macro = AIRFLOW_TIME_MACRO_SHORT if short else AIRFLOW_TIME_MACRO
90-
length = AIRFLOW_TIME_MACRO_SHORT_LEN if short else AIRFLOW_TIME_MACRO_LEN
91-
trimmed_base = base[:max_length - length - 1]
92-
return "{}-{}".format(trimmed_base, macro)
93-
94-
9570
def base_name_from_image(image):
9671
"""Extract the base name of the image to use as the 'algorithm name' for the job.
9772

src/sagemaker/workflow/README.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ dictionary that can be generated by the SageMaker Python SDK. For example:
5757
inputs=your_training_data_s3_uri)
5858
5959
# trans_config specifies SageMaker batch transform configuration
60+
# task_id specifies which operator the training job associatd with; task_type specifies whether the operator is a
61+
# training operator or tuning operator
6062
trans_config = transform_config_from_estimator(estimator=estimator,
63+
task_id='tf_training',
64+
task_type='training',
6165
instance_count=1,
6266
instance_type='ml.m4.xlarge',
6367
data=your_transform_data_s3_uri,
@@ -82,13 +86,13 @@ Now you can pass these configurations to the corresponding SageMaker operators a
8286
schedule_interval='@once')
8387
8488
train_op = SageMakerTrainingOperator(
85-
task_id='training',
89+
task_id='tf_training',
8690
config=train_config,
8791
wait_for_completion=True,
8892
dag=dag)
8993
9094
transform_op = SageMakerTransformOperator(
91-
task_id='transform',
95+
task_id='tf_transform',
9296
config=trans_config,
9397
wait_for_completion=True,
9498
dag=dag)

0 commit comments

Comments
 (0)