Skip to content

infra: refactor matching logic in v2 migration tool #1654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from sagemaker.cli.compatibility.v2 import modifiers

FUNCTION_CALL_MODIFIERS = [
modifiers.predictors.PredictorConstructorRefactor(),
modifiers.framework_version.FrameworkVersionEnforcer(),
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
modifiers.deprecated_params.TensorFlowScriptModeParameterRemover(),
modifiers.tfs.TensorFlowServingConstructorRenamer(),
modifiers.predictors.PredictorConstructorRefactor(),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reordering is needed because the TFS constructor change (Predictor --> TensorFlowPredictor) is already covered by its own modifier, so running the predictor refactor after it prevents false positives from the RealTimePredictor --> Predictor rename.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 great call!

modifiers.airflow.ModelConfigArgModifier(),
]

Expand Down
29 changes: 6 additions & 23 deletions src/sagemaker/cli/compatibility/v2/modifiers/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@

import ast

from sagemaker.cli.compatibility.v2.modifiers import matching
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

# Names of the Airflow model config functions handled by this module.
FUNCTION_NAMES = ("model_config", "model_config_from_estimator")
# Fully and partially qualified module paths the functions may be called through.
NAMESPACES = ("sagemaker.workflow.airflow", "workflow.airflow", "airflow")
# Maps every function name to the same tuple of namespaces; this is the
# name-to-namespaces mapping that is passed to ``matching.matches_any``.
FUNCTIONS = {name: NAMESPACES for name in FUNCTION_NAMES}


class ModelConfigArgModifier(Modifier):
"""A class to handle argument changes for Airflow model config functions."""

FUNCTION_NAMES = ("model_config", "model_config_from_estimator")

def node_should_be_modified(self, node):
"""Checks if the ``ast.Call`` node creates an Airflow model config and
contains positional arguments.
Expand All @@ -44,27 +47,7 @@ def node_should_be_modified(self, node):
bool: If the ``ast.Call`` is either a ``model_config`` call or
a ``model_config_from_estimator`` call and has positional arguments.
"""
return self._is_model_config_call(node) and len(node.args) > 0

def _is_model_config_call(self, node):
"""Checks if the node is a ``model_config`` or ``model_config_from_estimator`` call."""
if isinstance(node.func, ast.Name):
return node.func.id in self.FUNCTION_NAMES

if not (isinstance(node.func, ast.Attribute) and node.func.attr in self.FUNCTION_NAMES):
return False

return self._is_in_module(node.func, "sagemaker.workflow.airflow".split("."))

def _is_in_module(self, node, module):
"""Checks if the node is in the module, including partial matches to the module path."""
if isinstance(node.value, ast.Name):
return node.value.id == module[-1]

if isinstance(node.value, ast.Attribute) and node.value.attr == module[-1]:
return self._is_in_module(node.value, module[:-1])

return False
return matching.matches_any(node, FUNCTIONS) and len(node.args) > 0

def modify_node(self, node):
"""Modifies the ``ast.Call`` node's arguments.
Expand Down
30 changes: 5 additions & 25 deletions src/sagemaker/cli/compatibility/v2/modifiers/deprecated_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
"""Classes to remove deprecated parameters."""
from __future__ import absolute_import

import ast

from sagemaker.cli.compatibility.v2.modifiers import matching
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

# Namespaces that are accepted as the qualifier of a ``TensorFlow`` constructor call.
TF_NAMESPACES = ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator")


class TensorFlowScriptModeParameterRemover(Modifier):
"""A class to remove ``script_mode`` from TensorFlow estimators (because it's the only mode)."""
Expand All @@ -37,29 +38,8 @@ def node_should_be_modified(self, node):
Returns:
bool: If the ``ast.Call`` is instantiating a TensorFlow estimator with ``script_mode``.
"""
return self._is_tf_constructor(node) and self._has_script_mode_param(node)

def _is_tf_constructor(self, node):
"""Checks if the ``ast.Call`` node represents a call of the form
``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
"""
# Check for TensorFlow()
if isinstance(node.func, ast.Name):
return node.func.id == "TensorFlow"

# Check for sagemaker.tensorflow.TensorFlow()
ends_with_tensorflow_constructor = (
isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"
)

is_in_tensorflow_module = (
isinstance(node.func.value, ast.Attribute)
and node.func.value.attr == "tensorflow"
and isinstance(node.func.value.value, ast.Name)
and node.func.value.value.id == "sagemaker"
)

return ends_with_tensorflow_constructor and is_in_tensorflow_module
is_tf_constructor = matching.matches_name_or_namespaces(node, "TensorFlow", TF_NAMESPACES)
return is_tf_constructor and self._has_script_mode_param(node)

def _has_script_mode_param(self, node):
"""Checks if the ``ast.Call`` node's keywords include ``script_mode``."""
Expand Down
51 changes: 14 additions & 37 deletions src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import ast

from sagemaker.cli.compatibility.v2.modifiers import matching
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

FRAMEWORK_ARG = "framework_version"
Expand All @@ -29,11 +30,19 @@
}

FRAMEWORK_CLASSES = list(FRAMEWORK_DEFAULTS.keys())
MODEL_CLASSES = ["{}Model".format(fw) for fw in FRAMEWORK_CLASSES]

ESTIMATORS = {
fw: ("sagemaker.{}".format(fw.lower()), "sagemaker.{}.estimator".format(fw.lower()))
for fw in FRAMEWORK_CLASSES
}
# TODO: check for sagemaker.tensorflow.serving.Model
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORK_CLASSES]
FRAMEWORK_SUBMODULES = ("model", "estimator")
MODELS = {
"{}Model".format(fw): (
"sagemaker.{}".format(fw.lower()),
"sagemaker.{}.model".format(fw.lower()),
)
for fw in FRAMEWORK_CLASSES
}


class FrameworkVersionEnforcer(Modifier):
Expand Down Expand Up @@ -61,10 +70,10 @@ def node_should_be_modified(self, node):
bool: If the ``ast.Call`` is instantiating a framework class that
should specify ``framework_version``, but doesn't.
"""
if _is_named_constructor(node, FRAMEWORK_CLASSES):
if matching.matches_any(node, ESTIMATORS):
return _version_args_needed(node, "image_name")

if _is_named_constructor(node, MODEL_CLASSES):
if matching.matches_any(node, MODELS):
return _version_args_needed(node, "image")

return False
Expand Down Expand Up @@ -160,38 +169,6 @@ def _framework_from_node(node):
return framework, is_model


def _is_named_constructor(node, names):
    """Checks if the ``ast.Call`` node represents a call to particular named constructors.

    Accepted forms are a bare ``<Framework>`` call, or an attribute call ending
    in ``<Framework>`` whose qualifier passes ``_is_in_framework_module``. If
    the last qualifier component is one of ``FRAMEWORK_SUBMODULES``, it is
    skipped before that check.

    Args:
        node (ast.Call): a node that represents a function call.
        names (list[str]): the constructor names to look for.

    Returns:
        bool: True if the call matches one of the accepted forms.
    """
    func = node.func

    # Bare constructor call, e.g. ``TensorFlow(...)``.
    if isinstance(func, ast.Name):
        return func.id in names

    # Anything else has to be ``<qualifier>.<Framework>``.
    if not isinstance(func, ast.Attribute) or func.attr not in names:
        return False

    # ``sagemaker.<framework>.<estimator/model>.<Framework>``: step over the
    # submodule so that the framework module check starts one level up.
    owner = func.value
    has_submodule = isinstance(owner, ast.Attribute) and owner.attr in FRAMEWORK_SUBMODULES
    return _is_in_framework_module(owner if has_submodule else func)


def _is_in_framework_module(node):
    """Checks if node is an ``ast.Attribute`` representing a ``sagemaker.<framework>`` module.

    Args:
        node (ast.Attribute): the attribute node whose qualifier (``node.value``)
            is inspected.

    Returns:
        bool: True if the qualifier is exactly ``sagemaker.<framework>`` with
            ``<framework>`` listed in ``FRAMEWORK_MODULES``.
    """
    module = node.value
    if not isinstance(module, ast.Attribute):
        return False
    if module.attr not in FRAMEWORK_MODULES:
        return False

    root = module.value
    return isinstance(root, ast.Name) and root.id == "sagemaker"


def _version_args_needed(node, image_arg):
"""Determines if image_arg or version_arg was supplied

Expand Down
103 changes: 103 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Functions for checking AST nodes for matches."""
from __future__ import absolute_import

import ast


def matches_any(node, name_to_namespaces_dict):
    """Determines if the ``ast.Call`` node matches any of the provided names and namespaces.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        name_to_namespaces_dict (dict[str, tuple]): a mapping of names to appropriate namespaces.

    Returns:
        bool: if the node matches any of the names and namespaces.
    """
    for name, namespaces in name_to_namespaces_dict.items():
        # Stop at the first entry the call satisfies.
        if matches_name_or_namespaces(node, name, namespaces):
            return True
    return False


def matches_name_or_namespaces(node, name, namespaces):
    """Determines if the ``ast.Call`` node matches the function name in the right namespace.

    A bare call with the given name always matches. A qualified call matches
    only if its final attribute is the name and its qualifier matches at least
    one of the namespaces.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        name (str): the function name.
        namespaces (tuple): the possible namespaces to match to.

    Returns:
        bool: if the node matches the name and any of the namespaces.
    """
    if matches_name(node, name):
        return True

    if matches_attr(node, name):
        for namespace in namespaces:
            if matches_namespace(node, namespace):
                return True

    return False


def matches_name(node, name):
    """Determines if the ``ast.Call`` node points to an ``ast.Name`` node with a matching name.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        name (str): the function name.

    Returns:
        bool: if ``node.func`` is an ``ast.Name`` node with a matching name.
    """
    func = node.func
    if not isinstance(func, ast.Name):
        return False
    return func.id == name


def matches_attr(node, name):
    """Determines if the ``ast.Call`` node points to an ``ast.Attribute`` node with a matching name.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        name (str): the function name.

    Returns:
        bool: if ``node.func`` is an ``ast.Attribute`` node with a matching name.
    """
    func = node.func
    if not isinstance(func, ast.Attribute):
        return False
    return func.attr == name


def matches_namespace(node, namespace):
    """Determines if the ``ast.Call`` node corresponds to a matching namespace.

    The dotted ``namespace`` is compared, right to left, against the attribute
    chain that qualifies the called name. The walk stops as soon as the chain
    ends in an ``ast.Name``, so a trailing part of the namespace also matches
    (for example, ``tensorflow.TensorFlow()`` matches ``"sagemaker.tensorflow"``).
    A qualifier with more components than the namespace does not match.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        namespace (str): the namespace.

    Returns:
        bool: if the node's namespaces matches the given namespace. Always
            False when ``node.func`` is not an ``ast.Attribute``.
    """
    func = node.func
    # Only attribute calls carry a qualifier. Without this guard a bare call
    # such as ``TensorFlow()`` raises ``AttributeError`` on ``node.func.value``.
    if not isinstance(func, ast.Attribute):
        return False

    names = namespace.split(".")
    name, value = names.pop(), func.value
    while isinstance(value, ast.Attribute) and names:
        if value.attr != name:
            return False
        name, value = names.pop(), value.value

    return isinstance(value, ast.Name) and value.id == name
42 changes: 4 additions & 38 deletions src/sagemaker/cli/compatibility/v2/modifiers/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
"""
from __future__ import absolute_import

import ast

from sagemaker.cli.compatibility.v2.modifiers import matching
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

BASE_PREDICTOR = "RealTimePredictor"
Expand Down Expand Up @@ -54,7 +53,7 @@ def node_should_be_modified(self, node):
Returns:
bool: If the ``ast.Call`` instantiates a class of interest.
"""
return any(_matching(node, name, namespaces) for name, namespaces in PREDICTORS.items())
return matching.matches_any(node, PREDICTORS)

def modify_node(self, node):
"""Modifies the ``ast.Call`` node to call ``Predictor`` instead.
Expand All @@ -68,44 +67,11 @@ def modify_node(self, node):
_rename_endpoint(node)


def _matching(node, name, namespaces):
    """Determines if the node matches the constructor name in the right namespace.

    A bare call with the given name always matches. A qualified call matches
    only if its final attribute is the name and its qualifier matches at least
    one of the namespaces.

    Args:
        node (ast.Call): a node that represents a function call.
        name (str): the constructor name.
        namespaces (tuple): the possible namespaces to match to.

    Returns:
        bool: if the node matches the name and any of the namespaces.
    """
    if _matching_name(node, name):
        return True

    if _matching_attr(node, name):
        for namespace in namespaces:
            if _matching_namespace(node, namespace):
                return True

    return False


def _matching_name(node, name):
"""Determines if the node is an ast.Name node with a matching name"""
return isinstance(node.func, ast.Name) and node.func.id == name


def _matching_attr(node, name):
"""Determines if the node is an ast.Attribute node with a matching name"""
return isinstance(node.func, ast.Attribute) and node.func.attr == name


def _matching_namespace(node, namespace):
"""Determines if the node corresponds to a matching namespace"""
names = namespace.split(".")
name, value = names.pop(), node.func.value
while isinstance(value, ast.Attribute) and len(names) > 0:
if value.attr != name:
return False
name, value = names.pop(), value.value

return isinstance(value, ast.Name) and value.id == name


def _rename_class(node):
"""Renames the RealTimePredictor base class to Predictor"""
if _matching_name(node, BASE_PREDICTOR):
if matching.matches_name(node, BASE_PREDICTOR):
node.func.id = "Predictor"
elif _matching_attr(node, BASE_PREDICTOR):
elif matching.matches_attr(node, BASE_PREDICTOR):
node.func.attr = "Predictor"


Expand Down
Loading