Skip to content

Commit de838d1

Browse files
committed
Merge remote-tracking branch 'aws/zwei' into endpoint-rename
2 parents 10a7a7c + c233f67 commit de838d1

File tree

14 files changed

+306
-115
lines changed

14 files changed

+306
-115
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
2424
modifiers.deprecated_params.TensorFlowScriptModeParameterRemover(),
2525
modifiers.tfs.TensorFlowServingConstructorRenamer(),
26+
modifiers.airflow.ModelConfigArgModifier(),
2627
]
2728

2829
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
17+
airflow,
1718
deprecated_params,
1819
framework_version,
1920
tf_legacy_mode,
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""A class to handle argument changes for Airflow functions."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
19+
20+
21+
class ModelConfigArgModifier(Modifier):
22+
"""A class to handle argument changes for Airflow model config functions."""
23+
24+
FUNCTION_NAMES = ("model_config", "model_config_from_estimator")
25+
26+
def node_should_be_modified(self, node):
27+
"""Checks if the ``ast.Call`` node creates an Airflow model config and
28+
contains positional arguments.
29+
30+
This looks for the following formats:
31+
32+
- ``model_config``
33+
- ``airflow.model_config``
34+
- ``workflow.airflow.model_config``
35+
- ``sagemaker.workflow.airflow.model_config``
36+
37+
where ``model_config`` is either ``model_config`` or ``model_config_from_estimator``.
38+
39+
Args:
40+
node (ast.Call): a node that represents a function call. For more,
41+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
42+
43+
Returns:
44+
bool: If the ``ast.Call`` is either a ``model_config`` call or
45+
a ``model_config_from_estimator`` call and has positional arguments.
46+
"""
47+
return self._is_model_config_call(node) and len(node.args) > 0
48+
49+
def _is_model_config_call(self, node):
50+
"""Checks if the node is a ``model_config`` or ``model_config_from_estimator`` call."""
51+
if isinstance(node.func, ast.Name):
52+
return node.func.id in self.FUNCTION_NAMES
53+
54+
if not (isinstance(node.func, ast.Attribute) and node.func.attr in self.FUNCTION_NAMES):
55+
return False
56+
57+
return self._is_in_module(node.func, "sagemaker.workflow.airflow".split("."))
58+
59+
def _is_in_module(self, node, module):
60+
"""Checks if the node is in the module, including partial matches to the module path."""
61+
if isinstance(node.value, ast.Name):
62+
return node.value.id == module[-1]
63+
64+
if isinstance(node.value, ast.Attribute) and node.value.attr == module[-1]:
65+
return self._is_in_module(node.value, module[:-1])
66+
67+
return False
68+
69+
def modify_node(self, node):
70+
"""Modifies the ``ast.Call`` node's arguments.
71+
72+
The first argument, the instance type, is turned into a keyword arg,
73+
leaving the second argument, the model, to be the first argument.
74+
75+
Args:
76+
node (ast.Call): a node that represents either a ``model_config`` call or
77+
a ``model_config_from_estimator`` call.
78+
"""
79+
instance_type = node.args.pop(0)
80+
node.keywords.append(ast.keyword(arg="instance_type", value=instance_type))

src/sagemaker/workflow/airflow.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -557,29 +557,28 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
557557
return sagemaker.container_def(deploy_image, model.model_data, deploy_env)
558558

559559

560-
def model_config(instance_type, model, role=None, image=None):
560+
def model_config(model, instance_type=None, role=None, image=None):
561561
"""Export Airflow model config from a SageMaker model
562562
563563
Args:
564+
model (sagemaker.model.Model): The Model object from which to export the Airflow config
564565
instance_type (str): The EC2 instance type to deploy this Model to. For
565566
example, 'ml.p2.xlarge'
566-
model (sagemaker.model.FrameworkModel): The SageMaker model to export
567-
Airflow config from
568567
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
569568
image (str): An container image to use for deploying the model
570569
571570
Returns:
572571
dict: Model config that can be directly used by SageMakerModelOperator
573-
in Airflow. It can also be part of the config used by
574-
SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
572+
in Airflow. It can also be part of the config used by
573+
SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
575574
"""
576575
s3_operations = {}
577576
model.image = image or model.image
578577

579578
if isinstance(model, sagemaker.model.FrameworkModel):
580579
container_def = prepare_framework_container_def(model, instance_type, s3_operations)
581580
else:
582-
container_def = model.prepare_container_def(instance_type)
581+
container_def = model.prepare_container_def()
583582
base_name = utils.base_name_from_image(container_def["Image"])
584583
model.name = model.name or utils.name_from_base(base_name)
585584

@@ -601,10 +600,10 @@ def model_config(instance_type, model, role=None, image=None):
601600

602601

603602
def model_config_from_estimator(
604-
instance_type,
605603
estimator,
606604
task_id,
607605
task_type,
606+
instance_type=None,
608607
role=None,
609608
image=None,
610609
name=None,
@@ -614,8 +613,6 @@ def model_config_from_estimator(
614613
"""Export Airflow model config from a SageMaker estimator
615614
616615
Args:
617-
instance_type (str): The EC2 instance type to deploy this Model to. For
618-
example, 'ml.p2.xlarge'
619616
estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to
620617
export Airflow config from. It has to be an estimator associated
621618
with a training job.
@@ -627,6 +624,8 @@ def model_config_from_estimator(
627624
task_type (str): Whether the task is from SageMakerTrainingOperator or
628625
SageMakerTuningOperator. Values can be 'training', 'tuning' or None
629626
(which means training job is not from any task).
627+
instance_type (str): The EC2 instance type to deploy this Model to. For
628+
example, 'ml.p2.xlarge'
630629
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
631630
image (str): An container image to use for deploying the model
632631
name (str): Name of the model
@@ -667,7 +666,7 @@ def model_config_from_estimator(
667666
)
668667
model.name = name
669668

670-
return model_config(instance_type, model, role, image)
669+
return model_config(model, instance_type, role, image)
671670

672671

673672
def transform_config(
@@ -914,10 +913,10 @@ def transform_config_from_estimator(
914913
SageMakerTransformOperator in Airflow.
915914
"""
916915
model_base_config = model_config_from_estimator(
917-
instance_type=instance_type,
918916
estimator=estimator,
919917
task_id=task_id,
920918
task_type=task_type,
919+
instance_type=instance_type,
921920
role=role,
922921
image=image,
923922
name=model_name,
@@ -997,7 +996,7 @@ def deploy_config(model, initial_instance_count, instance_type, endpoint_name=No
997996
dict: Deploy config that can be directly used by
998997
SageMakerEndpointOperator in Airflow.
999998
"""
1000-
model_base_config = model_config(instance_type, model)
999+
model_base_config = model_config(model, instance_type)
10011000

10021001
production_variant = sagemaker.production_variant(
10031002
model.name, instance_type, initial_instance_count

tests/conftest.py

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,6 @@ def pytest_addoption(parser):
4444
parser.addoption("--sagemaker-client-config", action="store", default=None)
4545
parser.addoption("--sagemaker-runtime-config", action="store", default=None)
4646
parser.addoption("--boto-config", action="store", default=None)
47-
parser.addoption("--chainer-full-version", action="store", default="5.0.0")
48-
parser.addoption("--ei-mxnet-full-version", action="store", default="1.5.1")
49-
parser.addoption(
50-
"--rl-coach-mxnet-full-version",
51-
action="store",
52-
default=RLEstimator.COACH_LATEST_VERSION_MXNET,
53-
)
54-
parser.addoption(
55-
"--rl-coach-tf-full-version", action="store", default=RLEstimator.COACH_LATEST_VERSION_TF
56-
)
57-
parser.addoption(
58-
"--rl-ray-full-version", action="store", default=RLEstimator.RAY_LATEST_VERSION
59-
)
60-
parser.addoption("--sklearn-full-version", action="store", default="0.20.0")
61-
parser.addoption("--ei-tf-full-version", action="store")
62-
parser.addoption("--xgboost-full-version", action="store", default="1.0-1")
6347

6448

6549
def pytest_configure(config):
@@ -249,8 +233,13 @@ def rl_ray_version(request):
249233

250234

251235
@pytest.fixture(scope="module")
252-
def chainer_full_version(request):
253-
return request.config.getoption("--chainer-full-version")
236+
def chainer_full_version():
237+
return "5.0.0"
238+
239+
240+
@pytest.fixture(scope="module")
241+
def chainer_full_py_version():
242+
return "py3"
254243

255244

256245
@pytest.fixture(scope="module")
@@ -264,8 +253,8 @@ def mxnet_full_py_version():
264253

265254

266255
@pytest.fixture(scope="module")
267-
def ei_mxnet_full_version(request):
268-
return request.config.getoption("--ei-mxnet-full-version")
256+
def ei_mxnet_full_version():
257+
return "1.5.1"
269258

270259

271260
@pytest.fixture(scope="module")
@@ -284,23 +273,28 @@ def pytorch_full_ei_version():
284273

285274

286275
@pytest.fixture(scope="module")
287-
def rl_coach_mxnet_full_version(request):
288-
return request.config.getoption("--rl-coach-mxnet-full-version")
276+
def rl_coach_mxnet_full_version():
277+
return RLEstimator.COACH_LATEST_VERSION_MXNET
278+
279+
280+
@pytest.fixture(scope="module")
281+
def rl_coach_tf_full_version():
282+
return RLEstimator.COACH_LATEST_VERSION_TF
289283

290284

291285
@pytest.fixture(scope="module")
292-
def rl_coach_tf_full_version(request):
293-
return request.config.getoption("--rl-coach-tf-full-version")
286+
def rl_ray_full_version():
287+
return RLEstimator.RAY_LATEST_VERSION
294288

295289

296290
@pytest.fixture(scope="module")
297-
def rl_ray_full_version(request):
298-
return request.config.getoption("--rl-ray-full-version")
291+
def sklearn_full_version():
292+
return "0.20.0"
299293

300294

301295
@pytest.fixture(scope="module")
302-
def sklearn_full_version(request):
303-
return request.config.getoption("--sklearn-full-version")
296+
def sklearn_full_py_version():
297+
return "py3"
304298

305299

306300
@pytest.fixture(scope="module")
@@ -343,13 +337,19 @@ def tf_full_py_version(tf_full_version):
343337
return "py37"
344338

345339

346-
@pytest.fixture(scope="module", params=["1.15.0", "2.0.0"])
347-
def ei_tf_full_version(request):
348-
tf_ei_version = request.config.getoption("--ei-tf-full-version")
349-
if tf_ei_version is None:
350-
return request.param
351-
else:
352-
tf_ei_version
340+
@pytest.fixture(scope="module")
341+
def ei_tf_full_version():
342+
return "2.0.0"
343+
344+
345+
@pytest.fixture(scope="module")
346+
def xgboost_full_version():
347+
return "1.0-1"
348+
349+
350+
@pytest.fixture(scope="module")
351+
def xgboost_full_py_version():
352+
return "py3"
353353

354354

355355
@pytest.fixture(scope="session")
@@ -405,8 +405,3 @@ def pytest_generate_tests(metafunc):
405405
):
406406
params.append("ml.p2.xlarge")
407407
metafunc.parametrize("instance_type", params, scope="session")
408-
409-
410-
@pytest.fixture(scope="module")
411-
def xgboost_full_version(request):
412-
return request.config.getoption("--xgboost-full-version")

tests/integ/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import logging
1616
import os
17-
import sys
1817

1918
import boto3
2019

@@ -23,7 +22,6 @@
2322
TUNING_DEFAULT_TIMEOUT_MINUTES = 20
2423
TRANSFORM_DEFAULT_TIMEOUT_MINUTES = 20
2524
AUTO_ML_DEFAULT_TIMEMOUT_MINUTES = 60
26-
PYTHON_VERSION = "py{}".format(sys.version_info.major)
2725

2826
# these regions have some p2 and p3 instances, but not enough for continuous testing
2927
HOSTING_NO_P2_REGIONS = [

0 commit comments

Comments
 (0)