Specify task of training/tuning job in Airflow transform/deploy operator #590
CHANGELOG.rst
Outdated
@@ -5,6 +5,7 @@ CHANGELOG
1.17.0.dev
==========

* enhancement: Specify tasks from which training/tuning operator to transform/deploy in related operators
nit: and something to denote this is for Airflow/workflows
Updated! Also updated other entries of Airflow in CHANGELOG.
src/sagemaker/workflow/airflow.py
Outdated
if estimator.uploaded_code is None:
    return

postfix = os.path.join('/', 'source', 'sourcedir.tar.gz')
If we're actually concerned about using the correct separator for the OS, then you should use os.path.sep instead of '/'. However, if this is something where we actually just need slashes, then you can just use a format string here.
That line is removed since I used a regex for the replace logic.
But it should be just '/' since it's in an S3 URI, which is OS independent.
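To illustrate the separator point in this thread, here is a minimal sketch comparing three ways of building the postfix. The constant names are hypothetical and only the '/source/sourcedir.tar.gz' value comes from the diff; `ntpath` stands in for what `os.path` would do on Windows.

```python
import ntpath
import posixpath

# Hypothetical names for illustration; not taken from the SDK.
SOURCE_DIR = 'source'
ARCHIVE_NAME = 'sourcedir.tar.gz'

# S3 keys always use forward slashes, so a format string is OS independent.
postfix_from_format = '/{}/{}'.format(SOURCE_DIR, ARCHIVE_NAME)

# posixpath.join also always joins with '/', whatever the host OS is.
postfix_from_posixpath = posixpath.join('/', SOURCE_DIR, ARCHIVE_NAME)

# ntpath is the module os.path resolves to on Windows; it joins with a backslash.
postfix_from_ntpath = ntpath.join('/', SOURCE_DIR, ARCHIVE_NAME)

print(postfix_from_format)
print(postfix_from_posixpath)
print(repr(postfix_from_ntpath))
```

The first two values are identical on every platform, while the `ntpath` result contains a backslash and would not be a valid S3 key suffix.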
src/sagemaker/workflow/airflow.py
Outdated
# s3://path/old_job/source/sourcedir.tar.gz will become s3://path/new_job/source/sourcedir.tar.gz
submit_uri = estimator.uploaded_code.s3_prefix
submit_uri = submit_uri[:len(submit_uri) - len(postfix)]
submit_uri = submit_uri[:submit_uri.rfind('/') + 1] + job_name + postfix
would this be potentially cleaner with a regex?
Good idea! Updated using regex
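The updated regex is not shown in this thread, so the following is only a sketch of what a regex-based replacement could look like, assuming the URI always ends in '/<job>/source/sourcedir.tar.gz' as in the diff comment. The function name `rewrite_submit_uri` is hypothetical.

```python
import re

# Matches the job-name path segment followed by the fixed code-archive suffix.
_SUBMIT_URI_PATTERN = re.compile(r'/[^/]+/source/sourcedir\.tar\.gz$')


def rewrite_submit_uri(s3_uri, job_name):
    """Swap the job-name segment that precedes /source/sourcedir.tar.gz."""
    # A callable replacement avoids backslash/group interpretation in job_name.
    return _SUBMIT_URI_PATTERN.sub(
        lambda match: '/' + job_name + '/source/sourcedir.tar.gz',
        s3_uri,
    )


print(rewrite_submit_uri('s3://path/old_job/source/sourcedir.tar.gz', 'new_job'))
```

A URI that does not end with the expected suffix is returned unchanged, which is easier to reason about than the slicing arithmetic in the outdated diff.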
src/sagemaker/workflow/airflow.py
Outdated
training_job = "{{ ti.xcom_pull(task_ids='%s')['Tuning']['BestTrainingJob']['TrainingJobName'] }}" % task_id
# need to strip the double quotes in json to get the string
job_name = "{{ ti.xcom_pull(task_ids='%s')['Tuning']['TrainingJobDefinition']['StaticHyperParameters']" \
           "['sagemaker_job_name'].replace('%s', '') }}" % (task_id, '"')
Would it make more sense to use strip instead of replace?
Good idea! Updated
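For reference, a small sketch of the difference between the two string methods on a JSON-quoted value. The sample values are made up for illustration.

```python
# A JSON-encoded hyperparameter value arrives wrapped in double quotes.
quoted_job_name = '"my-training-job"'

# Both calls give the same result when quotes appear only at the ends.
stripped = quoted_job_name.strip('"')
replaced = quoted_job_name.replace('"', '')

# They differ when a quote appears inside the value:
# strip only trims the ends, replace removes every occurrence.
quoted_with_inner = '"my"job"'
stripped_inner = quoted_with_inner.strip('"')
replaced_inner = quoted_with_inner.replace('"', '')

print(stripped, replaced, stripped_inner, replaced_inner)
```

`strip` states the intent (remove the enclosing quotes) more precisely, which is presumably why the reviewer suggested it.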
@@ -369,16 +421,22 @@ def model_config(instance_type, model, role=None, image=None):
    return config


def model_config_from_estimator(instance_type, estimator, role=None, image=None, model_server_workers=None,
                                vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT):
def model_config_from_estimator(instance_type, estimator, task_id, task_type, role=None, image=None, name=None,
is this a breaking change?
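The thread does not record an answer, but one way to reason about the question is to look at what happens to existing positional callers when required parameters are inserted before optional ones. The two functions below are simplified, hypothetical stand-ins for the old and new signatures, not the SDK's code.

```python
def old_config(instance_type, estimator, role=None, image=None):
    # Stand-in for the signature before the change.
    return {'role': role, 'image': image}


def new_config(instance_type, estimator, task_id, task_type, role=None, image=None):
    # Stand-in for the signature after the change.
    return {'task_id': task_id, 'task_type': task_type, 'role': role, 'image': image}


# A caller that passed role positionally against the old signature.
before = old_config('ml.m4.xlarge', 'estimator', 'MyRole')

# The same three arguments no longer satisfy the new signature.
try:
    new_config('ml.m4.xlarge', 'estimator', 'MyRole')
    raised_type_error = False
except TypeError:
    raised_type_error = True

# With four positional arguments the call succeeds, but 'MyRole' is bound to task_id.
shifted = new_config('ml.m4.xlarge', 'estimator', 'MyRole', 'image-uri')

print(before, raised_type_error, shifted)
```

Under these assumptions, callers that used keyword arguments for `role` and `image` only need to add the two new arguments, while positional callers either fail with a `TypeError` or bind values to the wrong parameters.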
Addressed most of the comments. Need to rethink the attach one and will provide updates tomorrow.
Codecov Report
@@ Coverage Diff @@
## master #590 +/- ##
==========================================
- Coverage 92.78% 92.75% -0.04%
==========================================
Files 71 71
Lines 5367 5386 +19
==========================================
+ Hits 4980 4996 +16
- Misses 387 390 +3
Issue #, if available:
Description of changes:
Specify the task id and type of the training/tuning job in the Airflow transform/endpoint operator. The API then knows which training job to transform/deploy when multiple jobs are created in the DAG. To make this happen, additional changes are:
Merge Checklist
Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your pull request.
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.