Skip to content

Commit d81d464

Browse files
committed
change: handle image_uri rename in Airflow model config functions in v2 migration tool
This commit also adds some more utility functions for parsing and checking AST nodes.
1 parent 4d4dd1f commit d81d464

File tree

9 files changed

+223
-71
lines changed

9 files changed

+223
-71
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
modifiers.tfs.TensorFlowServingConstructorRenamer(),
2626
modifiers.predictors.PredictorConstructorRefactor(),
2727
modifiers.airflow.ModelConfigArgModifier(),
28+
modifiers.airflow.ModelConfigImageURIRenamer(),
2829
modifiers.renamed_params.DistributionParameterRenamer(),
2930
modifiers.renamed_params.S3SessionRenamer(),
3031
]

src/sagemaker/cli/compatibility/v2/modifiers/airflow.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import ast
1717

18-
from sagemaker.cli.compatibility.v2.modifiers import matching
18+
from sagemaker.cli.compatibility.v2.modifiers import matching, renamed_params
1919
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2020

2121
FUNCTION_NAMES = ("model_config", "model_config_from_estimator")
@@ -61,3 +61,32 @@ def modify_node(self, node):
6161
"""
6262
instance_type = node.args.pop(0)
6363
node.keywords.append(ast.keyword(arg="instance_type", value=instance_type))
64+
65+
66+
class ModelConfigImageURIRenamer(renamed_params.ParamRenamer):
67+
"""A class to rename the ``image`` attribute to ``image_uri`` in Airflow model config functions.
68+
69+
This looks for the following formats:
70+
71+
- ``model_config``
72+
- ``airflow.model_config``
73+
- ``workflow.airflow.model_config``
74+
- ``sagemaker.workflow.airflow.model_config``
75+
76+
where ``model_config`` is either ``model_config`` or ``model_config_from_estimator``.
77+
"""
78+
79+
@property
80+
def calls_to_modify(self):
81+
"""A dictionary mapping Airflow model config functions to their respective namespaces."""
82+
return FUNCTIONS
83+
84+
@property
85+
def old_param_name(self):
86+
"""The previous name for the image URI argument."""
87+
return "image"
88+
89+
@property
90+
def new_param_name(self):
91+
"""The new name for the image URI argument."""
92+
return "image_uri"

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import ast
1717

18-
from sagemaker.cli.compatibility.v2.modifiers import matching
18+
from sagemaker.cli.compatibility.v2.modifiers import matching, parsing
1919
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2020

2121
FRAMEWORK_ARG = "framework_version"
@@ -98,13 +98,13 @@ def modify_node(self, node):
9898
framework, is_model = _framework_from_node(node)
9999

100100
# if framework_version is not supplied, get default and append keyword
101-
framework_version = _arg_value(node, FRAMEWORK_ARG)
101+
framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
102102
if framework_version is None:
103103
framework_version = FRAMEWORK_DEFAULTS[framework]
104104
node.keywords.append(ast.keyword(arg=FRAMEWORK_ARG, value=ast.Str(s=framework_version)))
105105

106106
# if py_version is not supplied, get a conditional default, and if not None, append keyword
107-
py_version = _arg_value(node, PY_ARG)
107+
py_version = parsing.arg_value(node, PY_ARG)
108108
if py_version is None:
109109
py_version = _py_version_defaults(framework, framework_version, is_model)
110110
if py_version:
@@ -175,28 +175,20 @@ def _version_args_needed(node, image_arg):
175175
Applies similar logic as ``validate_version_or_image_args``
176176
"""
177177
# if image_arg is present, no need to supply version arguments
178-
image_name = _arg_value(node, image_arg)
178+
image_name = parsing.arg_value(node, image_arg)
179179
if image_name:
180180
return False
181181

182182
# if framework_version is None, need args
183-
framework_version = _arg_value(node, FRAMEWORK_ARG)
183+
framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
184184
if framework_version is None:
185185
return True
186186

187187
# check if we expect py_version and we don't get it -- framework and model dependent
188188
framework, is_model = _framework_from_node(node)
189189
expecting_py_version = _py_version_defaults(framework, framework_version, is_model)
190190
if expecting_py_version:
191-
py_version = _arg_value(node, PY_ARG)
191+
py_version = parsing.arg_value(node, PY_ARG)
192192
return py_version is None
193193

194194
return False
195-
196-
197-
def _arg_value(node, arg):
198-
"""Gets the value associated with the arg keyword, if present"""
199-
for kw in node.keywords:
200-
if kw.arg == arg and kw.value:
201-
return kw.value.s
202-
return None

src/sagemaker/cli/compatibility/v2/modifiers/matching.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import ast
1717

18+
from sagemaker.cli.compatibility.v2.modifiers import parsing
19+
1820

1921
def matches_any(node, name_to_namespaces_dict):
2022
"""Determines if the ``ast.Call`` node matches any of the provided names and namespaces.
@@ -101,3 +103,17 @@ def matches_namespace(node, namespace):
101103
name, value = names.pop(), value.value
102104

103105
return isinstance(value, ast.Name) and value.id == name
106+
107+
108+
def has_arg(node, arg):
109+
"""Checks if the call has the given argument.
110+
111+
Args:
112+
node (ast.Call): a node that represents a function call. For more,
113+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
114+
arg (str): the name of the argument.
115+
116+
Returns:
117+
bool: if the node has the given argument.
118+
"""
119+
return parsing.arg_value(node, arg) is not None
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Functions for parsing AST nodes."""
14+
from __future__ import absolute_import
15+
16+
17+
def arg_from_keywords(node, arg):
18+
"""Retrieves a keyword argument from the node's keywords.
19+
20+
Args:
21+
node (ast.Call): a node that represents a function call. For more,
22+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
23+
arg (str): the name of the argument.
24+
25+
Returns:
26+
ast.keyword: the keyword argument if it is present. Otherwise, this returns ``None``.
27+
"""
28+
for kw in node.keywords:
29+
if kw.arg == arg:
30+
return kw
31+
32+
return None
33+
34+
35+
def arg_value(node, arg):
36+
"""Retrieves a keyword argument's value from the node's keywords.
37+
38+
Args:
39+
node (ast.Call): a node that represents a function call. For more,
40+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
41+
arg (str): the name of the argument.
42+
43+
Returns:
44+
obj: the keyword argument's value if it is present. Otherwise, this returns ``None``.
45+
"""
46+
keyword = arg_from_keywords(node, arg)
47+
if keyword and keyword.value:
48+
return getattr(keyword.value, keyword.value._fields[0], None)
49+
50+
return None

src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import ast
1919
from abc import abstractmethod
2020

21-
from sagemaker.cli.compatibility.v2.modifiers import matching
21+
from sagemaker.cli.compatibility.v2.modifiers import matching, parsing
2222
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2323

2424

@@ -54,40 +54,20 @@ def node_should_be_modified(self, node):
5454
bool: If the ``ast.Call`` matches the relevant function calls and
5555
contains the parameter to be renamed.
5656
"""
57-
return matching.matches_any(node, self.calls_to_modify) and self._has_param_to_rename(node)
58-
59-
def _has_param_to_rename(self, node):
60-
"""Checks if the call has the argument that needs to be renamed."""
61-
return _keyword_from_keywords(node, self.old_param_name) is not None
57+
return matching.matches_any(node, self.calls_to_modify) and matching.has_arg(
58+
node, self.old_param_name
59+
)
6260

6361
def modify_node(self, node):
6462
"""Modifies the ``ast.Call`` node to rename the attribute.
6563
6664
Args:
6765
node (ast.Call): a node that represents the relevant function call.
6866
"""
69-
keyword = _keyword_from_keywords(node, self.old_param_name)
67+
keyword = parsing.arg_from_keywords(node, self.old_param_name)
7068
keyword.arg = self.new_param_name
7169

7270

73-
def _keyword_from_keywords(node, param_name):
74-
"""Retrieves a keyword argument from the node's keywords.
75-
76-
Args:
77-
node (ast.Call): a node that represents a function call. For more,
78-
see https://docs.python.org/3/library/ast.html#abstract-grammar.
79-
param_name (str): the name of the argument.
80-
81-
Returns:
82-
ast.keyword: the keyword argument if it is present. Otherwise, this returns ``None``.
83-
"""
84-
for kw in node.keywords:
85-
if kw.arg == param_name:
86-
return kw
87-
88-
return None
89-
90-
9171
class DistributionParameterRenamer(ParamRenamer):
9272
"""A class to rename the ``distributions`` attribute to ``distrbution`` in
9373
MXNet and TensorFlow estimators.

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_airflow.py

Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,52 +17,41 @@
1717
from sagemaker.cli.compatibility.v2.modifiers import airflow
1818
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
1919

20-
21-
def test_node_should_be_modified_model_config_with_args():
22-
model_config_calls = (
23-
"model_config(instance_type, model)",
24-
"airflow.model_config(instance_type, model)",
25-
"workflow.airflow.model_config(instance_type, model)",
26-
"sagemaker.workflow.airflow.model_config(instance_type, model)",
27-
"model_config_from_estimator(instance_type, model)",
28-
"airflow.model_config_from_estimator(instance_type, model)",
29-
"workflow.airflow.model_config_from_estimator(instance_type, model)",
30-
"sagemaker.workflow.airflow.model_config_from_estimator(instance_type, model)",
31-
)
32-
20+
MODEL_CONFIG_CALL_TEMPLATES = (
21+
"model_config({})",
22+
"airflow.model_config({})",
23+
"workflow.airflow.model_config({})",
24+
"sagemaker.workflow.airflow.model_config({})",
25+
"model_config_from_estimator({})",
26+
"airflow.model_config_from_estimator({})",
27+
"workflow.airflow.model_config_from_estimator({})",
28+
"sagemaker.workflow.airflow.model_config_from_estimator({})",
29+
)
30+
31+
32+
def test_arg_order_node_should_be_modified_model_config_with_args():
3333
modifier = airflow.ModelConfigArgModifier()
3434

35-
for call in model_config_calls:
36-
node = ast_call(call)
35+
for template in MODEL_CONFIG_CALL_TEMPLATES:
36+
node = ast_call(template.format("instance_type, model"))
3737
assert modifier.node_should_be_modified(node) is True
3838

3939

40-
def test_node_should_be_modified_model_config_without_args():
41-
model_config_calls = (
42-
"model_config()",
43-
"airflow.model_config()",
44-
"workflow.airflow.model_config()",
45-
"sagemaker.workflow.airflow.model_config()",
46-
"model_config_from_estimator()",
47-
"airflow.model_config_from_estimator()",
48-
"workflow.airflow.model_config_from_estimator()",
49-
"sagemaker.workflow.airflow.model_config_from_estimator()",
50-
)
51-
40+
def test_arg_order_node_should_be_modified_model_config_without_args():
5241
modifier = airflow.ModelConfigArgModifier()
5342

54-
for call in model_config_calls:
55-
node = ast_call(call)
43+
for template in MODEL_CONFIG_CALL_TEMPLATES:
44+
node = ast_call(template.format(""))
5645
assert modifier.node_should_be_modified(node) is False
5746

5847

59-
def test_node_should_be_modified_random_function_call():
48+
def test_arg_order_node_should_be_modified_random_function_call():
6049
node = ast_call("sagemaker.workflow.airflow.prepare_framework_container_def()")
6150
modifier = airflow.ModelConfigArgModifier()
6251
assert modifier.node_should_be_modified(node) is False
6352

6453

65-
def test_modify_node():
54+
def test_arg_order_modify_node():
6655
model_config_calls = (
6756
("model_config(instance_type, model)", "model_config(model, instance_type=instance_type)"),
6857
(
@@ -89,3 +78,42 @@ def test_modify_node():
8978
node = ast_call(call)
9079
modifier.modify_node(node)
9180
assert expected == pasta.dump(node)
81+
82+
83+
def test_image_arg_node_should_be_modified_model_config_with_arg():
84+
modifier = airflow.ModelConfigImageURIRenamer()
85+
86+
for template in MODEL_CONFIG_CALL_TEMPLATES:
87+
node = ast_call(template.format("image=my_image"))
88+
assert modifier.node_should_be_modified(node) is True
89+
90+
91+
def test_image_arg_node_should_be_modified_model_config_without_arg():
92+
modifier = airflow.ModelConfigImageURIRenamer()
93+
94+
for template in MODEL_CONFIG_CALL_TEMPLATES:
95+
node = ast_call(template.format(""))
96+
assert modifier.node_should_be_modified(node) is False
97+
98+
99+
def test_image_arg_node_should_be_modified_random_function_call():
100+
node = ast_call("sagemaker.workflow.airflow.prepare_framework_container_def()")
101+
modifier = airflow.ModelConfigImageURIRenamer()
102+
assert modifier.node_should_be_modified(node) is False
103+
104+
105+
def test_image_arg_modify_node():
106+
model_config_calls = (
107+
("model_config(image='image:latest')", "model_config(image_uri='image:latest')"),
108+
(
109+
"model_config_from_estimator(image=my_image)",
110+
"model_config_from_estimator(image_uri=my_image)",
111+
),
112+
)
113+
114+
modifier = airflow.ModelConfigImageURIRenamer()
115+
116+
for call, expected in model_config_calls:
117+
node = ast_call(call)
118+
modifier.modify_node(node)
119+
assert expected == pasta.dump(node)

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_matching.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,8 @@ def test_matches_attr():
6666
def test_matches_namespace():
6767
assert matching.matches_namespace(ast_call("sagemaker.mxnet.MXNet()"), "sagemaker.mxnet")
6868
assert not matching.matches_namespace(ast_call("sagemaker.KMeans()"), "sagemaker.mxnet")
69+
70+
71+
def test_has_arg():
72+
assert matching.has_arg(ast_call("MXNet(framework_version=mxnet_version)"), "framework_version")
73+
assert not matching.has_arg(ast_call("MXNet()"), "framework_version")

0 commit comments

Comments
 (0)