Skip to content

Commit 4f54ab9

Browse files
authored
breaking: make instance_type optional for Airflow model configs (#1627)
1 parent 39c33a2 commit 4f54ab9

File tree

7 files changed

+191
-19
lines changed

7 files changed

+191
-19
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
2424
modifiers.deprecated_params.TensorFlowScriptModeParameterRemover(),
2525
modifiers.tfs.TensorFlowServingConstructorRenamer(),
26+
modifiers.airflow.ModelConfigArgModifier(),
2627
]
2728

2829
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
17+
airflow,
1718
deprecated_params,
1819
framework_version,
1920
tf_legacy_mode,
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""A class to handle argument changes for Airflow functions."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
19+
20+
21+
class ModelConfigArgModifier(Modifier):
    """A class to handle argument changes for Airflow model config functions.

    Calls to ``model_config`` and ``model_config_from_estimator`` that pass
    the instance type as the first positional argument are rewritten so that
    the instance type becomes the ``instance_type`` keyword argument.
    """

    FUNCTION_NAMES = ("model_config", "model_config_from_estimator")

    def node_should_be_modified(self, node):
        """Checks if the ``ast.Call`` node creates an Airflow model config and
        has a leading positional argument that can be turned into a keyword.

        This looks for the following formats:

        - ``model_config``
        - ``airflow.model_config``
        - ``workflow.airflow.model_config``
        - ``sagemaker.workflow.airflow.model_config``

        where ``model_config`` is either ``model_config`` or ``model_config_from_estimator``.

        Calls are skipped when rewriting them would produce invalid code:

        - ``instance_type`` is already passed as a keyword (moving the first
          positional argument would create a duplicate keyword).
        - the first positional argument is starred (``*args``), which cannot
          be used as a keyword argument value.

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` is either a ``model_config`` call or
                a ``model_config_from_estimator`` call and has a positional
                argument that can safely become ``instance_type``.
        """
        if not (self._is_model_config_call(node) and node.args):
            return False

        # ``f(*args)`` cannot become ``f(instance_type=*args)``.
        if isinstance(node.args[0], ast.Starred):
            return False

        # ``**kwargs`` entries have ``arg`` set to None, so they never match here.
        return not any(keyword.arg == "instance_type" for keyword in node.keywords)

    def _is_model_config_call(self, node):
        """Checks if the node is a ``model_config`` or ``model_config_from_estimator`` call."""
        if isinstance(node.func, ast.Name):
            return node.func.id in self.FUNCTION_NAMES

        if not (isinstance(node.func, ast.Attribute) and node.func.attr in self.FUNCTION_NAMES):
            return False

        return self._is_in_module(node.func, "sagemaker.workflow.airflow".split("."))

    def _is_in_module(self, node, module):
        """Checks if the node is in the module, including partial matches to the module path.

        Args:
            node (ast.Attribute): the attribute node whose ``value`` is compared
                against the last remaining component of ``module``.
            module (list[str]): the components of the module path that have
                not been matched yet.

        Returns:
            bool: True if the qualifier of ``node`` matches a suffix of the
                module path, False otherwise.
        """
        # The whole module path has already been matched but the expression has
        # yet another qualifier (e.g. ``foo.sagemaker.workflow.airflow.model_config``),
        # so this is not the function we are looking for. Without this guard,
        # ``module[-1]`` below raises IndexError.
        if not module:
            return False

        if isinstance(node.value, ast.Name):
            return node.value.id == module[-1]

        if isinstance(node.value, ast.Attribute) and node.value.attr == module[-1]:
            return self._is_in_module(node.value, module[:-1])

        return False

    def modify_node(self, node):
        """Modifies the ``ast.Call`` node's arguments.

        The first argument, the instance type, is turned into a keyword arg,
        leaving the second argument, the model, to be the first argument.

        Args:
            node (ast.Call): a node that represents either a ``model_config`` call or
                a ``model_config_from_estimator`` call.
        """
        instance_type = node.args.pop(0)
        node.keywords.append(ast.keyword(arg="instance_type", value=instance_type))

src/sagemaker/workflow/airflow.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -557,29 +557,28 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
557557
return sagemaker.container_def(deploy_image, model.model_data, deploy_env)
558558

559559

560-
def model_config(instance_type, model, role=None, image=None):
560+
def model_config(model, instance_type=None, role=None, image=None):
561561
"""Export Airflow model config from a SageMaker model
562562
563563
Args:
564+
model (sagemaker.model.Model): The Model object from which to export the Airflow config
564565
instance_type (str): The EC2 instance type to deploy this Model to. For
565566
example, 'ml.p2.xlarge'
566-
model (sagemaker.model.FrameworkModel): The SageMaker model to export
567-
Airflow config from
568567
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
569568
image (str): A container image to use for deploying the model
570569
571570
Returns:
572571
dict: Model config that can be directly used by SageMakerModelOperator
573-
in Airflow. It can also be part of the config used by
574-
SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
572+
in Airflow. It can also be part of the config used by
573+
SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
575574
"""
576575
s3_operations = {}
577576
model.image = image or model.image
578577

579578
if isinstance(model, sagemaker.model.FrameworkModel):
580579
container_def = prepare_framework_container_def(model, instance_type, s3_operations)
581580
else:
582-
container_def = model.prepare_container_def(instance_type)
581+
container_def = model.prepare_container_def()
583582
base_name = utils.base_name_from_image(container_def["Image"])
584583
model.name = model.name or utils.name_from_base(base_name)
585584

@@ -601,10 +600,10 @@ def model_config(instance_type, model, role=None, image=None):
601600

602601

603602
def model_config_from_estimator(
604-
instance_type,
605603
estimator,
606604
task_id,
607605
task_type,
606+
instance_type=None,
608607
role=None,
609608
image=None,
610609
name=None,
@@ -614,8 +613,6 @@ def model_config_from_estimator(
614613
"""Export Airflow model config from a SageMaker estimator
615614
616615
Args:
617-
instance_type (str): The EC2 instance type to deploy this Model to. For
618-
example, 'ml.p2.xlarge'
619616
estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to
620617
export Airflow config from. It has to be an estimator associated
621618
with a training job.
@@ -627,6 +624,8 @@ def model_config_from_estimator(
627624
task_type (str): Whether the task is from SageMakerTrainingOperator or
628625
SageMakerTuningOperator. Values can be 'training', 'tuning' or None
629626
(which means training job is not from any task).
627+
instance_type (str): The EC2 instance type to deploy this Model to. For
628+
example, 'ml.p2.xlarge'
630629
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
631630
image (str): A container image to use for deploying the model
632631
name (str): Name of the model
@@ -667,7 +666,7 @@ def model_config_from_estimator(
667666
)
668667
model.name = name
669668

670-
return model_config(instance_type, model, role, image)
669+
return model_config(model, instance_type, role, image)
671670

672671

673672
def transform_config(
@@ -914,10 +913,10 @@ def transform_config_from_estimator(
914913
SageMakerTransformOperator in Airflow.
915914
"""
916915
model_base_config = model_config_from_estimator(
917-
instance_type=instance_type,
918916
estimator=estimator,
919917
task_id=task_id,
920918
task_type=task_type,
919+
instance_type=instance_type,
921920
role=role,
922921
image=image,
923922
name=model_name,
@@ -997,7 +996,7 @@ def deploy_config(model, initial_instance_count, instance_type, endpoint_name=No
997996
dict: Deploy config that can be directly used by
998997
SageMakerEndpointOperator in Airflow.
999998
"""
1000-
model_base_config = model_config(instance_type, model)
999+
model_base_config = model_config(model, instance_type)
10011000

10021001
production_variant = sagemaker.production_variant(
10031002
model.name, instance_type, initial_instance_count

tests/integ/test_airflow_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ def _build_airflow_workflow(estimator, instance_type, inputs=None, mini_batch_si
613613
model = estimator.create_model()
614614
assert model is not None
615615

616-
model_config = sm_airflow.model_config(instance_type, model)
616+
model_config = sm_airflow.model_config(model, instance_type)
617617
assert model_config is not None
618618

619619
transform_config = sm_airflow.transform_config_from_estimator(
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
from sagemaker.cli.compatibility.v2.modifiers import airflow
18+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
19+
20+
21+
def test_node_should_be_modified_model_config_with_args():
    """Model config calls with positional args are flagged, however they are qualified."""
    function_names = ("model_config", "model_config_from_estimator")
    qualifiers = ("", "airflow.", "workflow.airflow.", "sagemaker.workflow.airflow.")

    modifier = airflow.ModelConfigArgModifier()

    # Same calls, in the same order, as spelling every combination out by hand.
    for function_name in function_names:
        for qualifier in qualifiers:
            source = "{}{}(instance_type, model)".format(qualifier, function_name)
            call_node = ast_call(source)
            assert modifier.node_should_be_modified(call_node) is True
38+
39+
40+
def test_node_should_be_modified_model_config_without_args():
    """Model config calls with no positional args are not flagged."""
    function_names = ("model_config", "model_config_from_estimator")
    qualifiers = ("", "airflow.", "workflow.airflow.", "sagemaker.workflow.airflow.")

    modifier = airflow.ModelConfigArgModifier()

    # Same calls, in the same order, as spelling every combination out by hand.
    for function_name in function_names:
        for qualifier in qualifiers:
            source = "{}{}()".format(qualifier, function_name)
            call_node = ast_call(source)
            assert modifier.node_should_be_modified(call_node) is False
57+
58+
59+
def test_node_should_be_modified_random_function_call():
    """A different function under the Airflow module is not flagged."""
    modifier = airflow.ModelConfigArgModifier()
    call_node = ast_call("sagemaker.workflow.airflow.prepare_framework_container_def()")
    assert modifier.node_should_be_modified(call_node) is False
63+
64+
65+
def test_modify_node():
    """The first positional argument is moved to the ``instance_type`` keyword."""
    # Each entry is (source before modification, expected source after modification).
    cases = [
        ("model_config(instance_type, model)", "model_config(model, instance_type=instance_type)"),
        (
            "model_config('ml.m4.xlarge', 'my-model')",
            "model_config('my-model', instance_type='ml.m4.xlarge')",
        ),
        (
            "model_config('ml.m4.xlarge', model='my-model')",
            "model_config(instance_type='ml.m4.xlarge', model='my-model')",
        ),
        (
            "model_config_from_estimator(instance_type, estimator, task_id, task_type)",
            "model_config_from_estimator(estimator, task_id, task_type, instance_type=instance_type)",
        ),
        (
            "model_config_from_estimator(instance_type, estimator, task_id=task_id, task_type=task_type)",
            "model_config_from_estimator(estimator, instance_type=instance_type, task_id=task_id, task_type=task_type)",
        ),
    ]

    modifier = airflow.ModelConfigArgModifier()

    for original_source, expected_source in cases:
        call_node = ast_call(original_source)
        modifier.modify_node(call_node)
        rendered = pasta.dump(call_node)
        assert expected_source == rendered

tests/unit/test_airflow.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -900,7 +900,7 @@ def test_byo_model_config(sagemaker_session):
900900
sagemaker_session=sagemaker_session,
901901
)
902902

903-
config = airflow.model_config(instance_type="ml.c4.xlarge", model=byo_model)
903+
config = airflow.model_config(model=byo_model)
904904
expected_config = {
905905
"ModelName": "model",
906906
"PrimaryContainer": {
@@ -926,7 +926,7 @@ def test_byo_framework_model_config(sagemaker_session):
926926
sagemaker_session=sagemaker_session,
927927
)
928928

929-
config = airflow.model_config(instance_type="ml.c4.xlarge", model=byo_model)
929+
config = airflow.model_config(model=byo_model, instance_type="ml.c4.xlarge")
930930
expected_config = {
931931
"ModelName": "model",
932932
"PrimaryContainer": {
@@ -971,7 +971,7 @@ def test_framework_model_config(sagemaker_session):
971971
sagemaker_session=sagemaker_session,
972972
)
973973

974-
config = airflow.model_config(instance_type="ml.c4.xlarge", model=chainer_model)
974+
config = airflow.model_config(model=chainer_model, instance_type="ml.c4.xlarge")
975975
expected_config = {
976976
"ModelName": "sagemaker-chainer-%s" % TIME_STAMP,
977977
"PrimaryContainer": {
@@ -1009,7 +1009,7 @@ def test_amazon_alg_model_config(sagemaker_session):
10091009
model_data="{{ model_data }}", role="{{ role }}", sagemaker_session=sagemaker_session
10101010
)
10111011

1012-
config = airflow.model_config(instance_type="ml.c4.xlarge", model=pca_model)
1012+
config = airflow.model_config(model=pca_model)
10131013
expected_config = {
10141014
"ModelName": "pca-%s" % TIME_STAMP,
10151015
"PrimaryContainer": {
@@ -1059,10 +1059,10 @@ def test_model_config_from_framework_estimator(ecr_prefix, sagemaker_session):
10591059
airflow.training_config(mxnet_estimator, data)
10601060

10611061
config = airflow.model_config_from_estimator(
1062-
instance_type="ml.c4.xlarge",
10631062
estimator=mxnet_estimator,
10641063
task_id="task_id",
10651064
task_type="training",
1065+
instance_type="ml.c4.xlarge",
10661066
)
10671067
expected_config = {
10681068
"ModelName": "mxnet-inference-%s" % TIME_STAMP,
@@ -1103,7 +1103,7 @@ def test_model_config_from_amazon_alg_estimator(sagemaker_session):
11031103
airflow.training_config(knn_estimator, record, mini_batch_size=256)
11041104

11051105
config = airflow.model_config_from_estimator(
1106-
instance_type="ml.c4.xlarge", estimator=knn_estimator, task_id="task_id", task_type="tuning"
1106+
estimator=knn_estimator, task_id="task_id", task_type="tuning"
11071107
)
11081108
expected_config = {
11091109
"ModelName": "knn-%s" % TIME_STAMP,

0 commit comments

Comments
 (0)