Commit 0768754

chenxy authored and EthanShouhanCheng committed
feature: Add EMRStep support in Sagemaker pipeline
1 parent b3c19d8 commit 0768754

File tree: 6 files changed, +420 −2 lines changed

src/sagemaker/workflow/emr_step.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""The step definitions for workflow."""
+from __future__ import absolute_import
+
+from typing import List
+
+from sagemaker.workflow.entities import (
+    RequestType,
+)
+from sagemaker.workflow.properties import (
+    Properties,
+)
+from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
+
+
+class EMRStepConfig:
+    """Config for a Hadoop Jar step."""
+
+    def __init__(
+        self, jar, args: List[str] = None, main_class: str = None, properties: List[dict] = None
+    ):
+        """Create a definition for input data used by an EMR cluster (job flow) step.
+
+        See AWS documentation on the ``StepConfig`` API for more details on the parameters.
+
+        Args:
+            args(List[str]):
+                A list of command line arguments passed to
+                the JAR file's main function when executed.
+            jar(str): A path to a JAR file run during the step.
+            main_class(str): The name of the main class in the specified Java file.
+            properties(List(dict)): A list of key-value pairs that are set when the step runs.
+        """
+        self.jar = jar
+        self.args = args
+        self.main_class = main_class
+        self.properties = properties
+
+    def to_request(self) -> RequestType:
+        """Convert EMRStepConfig object to request dict."""
+        config = {"HadoopJarStep": {"Jar": self.jar}}
+        if self.args is not None:
+            config["HadoopJarStep"]["Args"] = self.args
+        if self.main_class is not None:
+            config["HadoopJarStep"]["MainClass"] = self.main_class
+        if self.properties is not None:
+            config["HadoopJarStep"]["Properties"] = self.properties
+
+        return config
+
+
+class EMRStep(Step):
+    """EMR step for workflow."""
+
+    def __init__(
+        self,
+        name: str,
+        display_name: str,
+        description: str,
+        job_flow_id: str,
+        step_config: EMRStepConfig,
+        depends_on: List[str] = None,
+        cache_config: CacheConfig = None,
+    ):
+        """Constructs an EMRStep.
+
+        Args:
+            name(str): The name of the EMR step.
+            display_name(str): The display name of the EMR step.
+            description(str): The description of the EMR step.
+            job_flow_id(str): A string that uniquely identifies the job flow (cluster).
+            step_config(EMRStepConfig): One StepConfig to be executed by the job flow.
+            depends_on(List[str]):
+                A list of step names this `sagemaker.workflow.steps.EMRStep` depends on.
+            cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
+
+        """
+        super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
+
+        emr_step_args = {"JobFlowId": job_flow_id, "StepConfig": step_config.to_request()}
+        self.args = emr_step_args
+        self.cache_config = cache_config
+
+        self._properties = Properties(
+            path=f"Steps.{name}", external_service_name="emr", shape_name="DescribeStepOutput"
+        )
+
+    @property
+    def arguments(self) -> RequestType:
+        """The arguments dict that is used to call `AddJobFlowSteps`.
+
+        NOTE: The AddJobFlowSteps request is not quite the args list that workflow needs.
+        The Name attribute in AddJobFlowSteps cannot be passed; it will be set during runtime.
+        In addition to that, we will also need to include EMR job inputs and output config.
+        """
+        return self.args
+
+    @property
+    def properties(self) -> RequestType:
+        """A Properties object representing the EMR DescribeStepResponse model."""
+        return self._properties
+
+    def to_request(self) -> RequestType:
+        """Updates the dictionary with cache configuration."""
+        request_dict = super().to_request()
+        if self.cache_config:
+            request_dict.update(self.cache_config.config)
+        return request_dict

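For orientation, here is a minimal sketch (not part of the commit) of the request dict the two classes above assemble; the jar path, step names, and job flow id are placeholder values.

# Illustrative sketch (not part of the commit): the request dict produced by
# EMRStepConfig/EMRStep as defined above. Placeholder values throughout.
from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig

emr_config = EMRStepConfig(
    jar="s3://my-bucket/script-runner.jar",
    args=["--arg_0", "arg_0_value"],
    main_class="com.my.Main",
    properties=[{"Key": "Foo", "Value": "Foo_value"}],
)

emr_step = EMRStep(
    name="MyEMRStep",
    display_name="My EMR step",
    description="Runs a Hadoop JAR step on an existing cluster",
    job_flow_id="j-XXXXXXXXXXXXX",  # placeholder EMR cluster (job flow) id
    step_config=emr_config,
)

# EMRStep.arguments wraps the StepConfig request under JobFlowId/StepConfig:
print(emr_step.arguments)
# {'JobFlowId': 'j-XXXXXXXXXXXXX',
#  'StepConfig': {'HadoopJarStep': {'Jar': 's3://my-bucket/script-runner.jar',
#                                   'Args': ['--arg_0', 'arg_0_value'],
#                                   'MainClass': 'com.my.Main',
#                                   'Properties': [{'Key': 'Foo', 'Value': 'Foo_value'}]}}}
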
src/sagemaker/workflow/properties.py

Lines changed: 10 additions & 2 deletions
@@ -26,14 +26,17 @@ class PropertiesMeta(type):
     """Load an internal shapes attribute from the botocore sagemaker service model."""
 
     _shapes = None
+    _emr_shapes = None
     _primitive_types = {"string", "boolean", "integer", "float"}
 
     def __new__(mcs, *args, **kwargs):
         """Loads up the shapes from the botocore sagemaker service model."""
         if mcs._shapes is None:
             loader = botocore.loaders.Loader()
             model = loader.load_service_model("sagemaker", "service-2")
+            emr_model = loader.load_service_model("emr", "service-2")
             mcs._shapes = model["shapes"]
+            mcs._emr_shapes = emr_model["shapes"]
         return super().__new__(mcs, *args, **kwargs)
 
 
@@ -45,6 +48,7 @@ def __init__(
         path: str,
         shape_name: str = None,
         shape_names: List[str] = None,
+        external_service_name: str = None,
     ):
         """Create a Properties instance representing the given shape.
 
@@ -57,15 +61,19 @@ def __init__(
         shape_names = [] if shape_names is None else shape_names
         self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names
 
+        shapes = Properties._shapes
+        if external_service_name == "emr":
+            shapes = Properties._emr_shapes
+
         for name in self._shape_names:
-            shape = Properties._shapes.get(name, {})
+            shape = shapes.get(name, {})
             shape_type = shape.get("type")
             if shape_type in Properties._primitive_types:
                 self.__str__ = name
             elif shape_type == "structure":
                 members = shape["members"]
                 for key, info in members.items():
-                    if Properties._shapes.get(info["shape"], {}).get("type") == "list":
+                    if shapes.get(info["shape"], {}).get("type") == "list":
                         self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
                     elif Properties._shapes.get(info["shape"], {}).get("type") == "map":
                         self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"])

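The new external_service_name="emr" branch resolves property shapes from the EMR service model instead of the SageMaker one. A quick inspection sketch (not part of the commit; exact member lists depend on the installed botocore version) mirrors the loader call added above:

# Inspection sketch: load the same EMR service model that PropertiesMeta now
# caches in _emr_shapes and look at the shape EMRStep points its Properties at.
import botocore.loaders

loader = botocore.loaders.Loader()
emr_shapes = loader.load_service_model("emr", "service-2")["shapes"]

# EMRStep uses shape_name="DescribeStepOutput"; its single member is "Step".
print(emr_shapes["DescribeStepOutput"]["members"])
# expected to look like: {'Step': {'shape': 'Step'}}
print(sorted(emr_shapes["Step"]["members"]))
# e.g. ['ActionOnFailure', 'Config', 'Id', 'Name', 'Status'], depending on the botocore version
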
src/sagemaker/workflow/steps.py

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
     LAMBDA = "Lambda"
     QUALITY_CHECK = "QualityCheck"
     CLARIFY_CHECK = "ClarifyCheck"
+    EMR = "EMR"
 
 
 @attr.s

tests/integ/test_workflow.py

Lines changed: 107 additions & 0 deletions
@@ -70,6 +70,7 @@
 from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum
 from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
 from sagemaker.workflow.properties import PropertyFile
+from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
 from sagemaker.wrangler.processing import DataWranglerProcessor
 from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
 from sagemaker.workflow.execution_variables import ExecutionVariables
@@ -1148,6 +1149,112 @@ def test_two_step_lambda_pipeline_with_output_reference(
         pass
 
 
+def test_one_step_emr_pipeline(sagemaker_session, role, pipeline_name, region_name):
+    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
+
+    emr_step_config = EMRStepConfig(
+        jar="s3://script-runner/script-runner.jar",
+        args=["--arg_0", "arg_0_value"],
+        main_class="com.my.main",
+        properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}],
+    )
+
+    step_emr = EMRStep(
+        name="emr-step1",
+        job_flow_id="MyClusterID",
+        display_name="emr_step_1",
+        description="MyEMRStepDescription",
+        step_config=emr_step_config,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_count],
+        steps=[step_emr],
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
+        response = pipeline.update(role)
+        update_arn = response["PipelineArn"]
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            update_arn,
+        )
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
+def test_two_steps_emr_pipeline_without_nullable_config_fields(
+    sagemaker_session, role, pipeline_name, region_name
+):
+    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
+
+    emr_step_config_1 = EMRStepConfig(
+        jar="s3://script-runner/script-runner_1.jar",
+        args=["--arg_0", "arg_0_value"],
+        main_class="com.my.main",
+        properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}],
+    )
+
+    step_emr_1 = EMRStep(
+        name="emr_step_1",
+        job_flow_id="MyClusterID",
+        display_name="emr_step_1",
+        description="MyEMRStepDescription",
+        step_config=emr_step_config_1,
+    )
+
+    emr_step_config_2 = EMRStepConfig(jar="s3://script-runner/script-runner_2.jar")
+
+    step_emr_2 = EMRStep(
+        name="emr_step_2",
+        job_flow_id="MyClusterID",
+        display_name="emr_step_2",
+        description="MyEMRStepDescription",
+        step_config=emr_step_config_2,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_count],
+        steps=[step_emr_1, step_emr_2],
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
+        response = pipeline.update(role)
+        update_arn = response["PipelineArn"]
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            update_arn,
+        )
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_conditional_pytorch_training_model_registration(
     sagemaker_session,
     role,

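The integration tests above do not exercise the optional cache_config path, so here is a hedged sketch (not part of the commit) of how EMRStep.to_request() merges it. Keys contributed by the base Step.to_request(), such as Name/Type/Arguments, are shown approximately; the cluster id and jar path are placeholders.

# Illustrative sketch: merging CacheConfig into the EMR step request,
# per the to_request() override in emr_step.py above.
from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
from sagemaker.workflow.steps import CacheConfig

step = EMRStep(
    name="emr_step_cached",
    display_name="emr_step_cached",
    description="MyEMRStepDescription",
    job_flow_id="MyClusterID",  # placeholder cluster (job flow) id
    step_config=EMRStepConfig(jar="s3://my-bucket/script-runner.jar"),
    cache_config=CacheConfig(enable_caching=True, expire_after="P30D"),
)

request = step.to_request()
# The base step request (Name, Type="EMR", Arguments={"JobFlowId": ..., "StepConfig": ...})
# is updated in place with the cache settings:
print(request["CacheConfig"])  # {'Enabled': True, 'ExpireAfter': 'P30D'}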