Skip to content

Commit a71105b

Browse files
chenxyEthanShouhanCheng
authored and committed
feature: Add EMRStep support in Sagemaker pipeline
1 parent 0768754 commit a71105b

File tree

5 files changed

+145
-31
lines changed

5 files changed

+145
-31
lines changed

src/sagemaker/workflow/emr_step.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ def __init__(
6868
name: str,
6969
display_name: str,
7070
description: str,
71+
<<<<<<< HEAD
7172
job_flow_id: str,
73+
=======
74+
cluster_id: str,
75+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
7276
step_config: EMRStepConfig,
7377
depends_on: List[str] = None,
7478
cache_config: CacheConfig = None,
@@ -79,7 +83,11 @@ def __init__(
7983
name(str): The name of the EMR step.
8084
display_name(str): The display name of the EMR step.
8185
description(str): The description of the EMR step.
86+
<<<<<<< HEAD
8287
job_flow_id(str): A string that uniquely identifies the job flow(cluster).
88+
=======
89+
cluster_id(str): A string that uniquely identifies the cluster.
90+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
8391
step_config(EMRStepConfig): One StepConfig to be executed by the job flow.
8492
depends_on(List[str]):
8593
A list of step names this `sagemaker.workflow.steps.EMRStep` depends on
@@ -88,13 +96,23 @@ def __init__(
8896
"""
8997
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
9098

99+
<<<<<<< HEAD
91100
emr_step_args = {"JobFlowId": job_flow_id, "StepConfig": step_config.to_request()}
92101
self.args = emr_step_args
93102
self.cache_config = cache_config
94103

95104
self._properties = Properties(
96105
path=f"Steps.{name}", external_service_name="emr", shape_name="DescribeStepOutput"
97106
)
107+
=======
108+
emr_step_args = {"ClusterId": cluster_id, "StepConfig": step_config.to_request()}
109+
self.args = emr_step_args
110+
self.cache_config = cache_config
111+
112+
root_property = Properties(path=f"Steps.{name}", shape_name="Step", service_name="emr")
113+
root_property.__dict__["ClusterId"] = cluster_id
114+
self._properties = root_property
115+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
98116

99117
@property
100118
def arguments(self) -> RequestType:

src/sagemaker/workflow/properties.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,34 @@
2323

2424

2525
class PropertiesMeta(type):
26-
"""Load an internal shapes attribute from the botocore sagemaker service model."""
26+
"""Load an internal shapes attribute from the botocore service model
2727
28+
<<<<<<< HEAD
2829
_shapes = None
2930
_emr_shapes = None
31+
=======
32+
for sagemaker and emr service.
33+
"""
34+
35+
_shapes_map = dict()
36+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
3037
_primitive_types = {"string", "boolean", "integer", "float"}
3138

3239
def __new__(mcs, *args, **kwargs):
3340
"""Loads up the shapes from the botocore sagemaker service model."""
34-
if mcs._shapes is None:
41+
if len(mcs._shapes_map.keys()) == 0:
3542
loader = botocore.loaders.Loader()
43+
<<<<<<< HEAD
3644
model = loader.load_service_model("sagemaker", "service-2")
3745
emr_model = loader.load_service_model("emr", "service-2")
3846
mcs._shapes = model["shapes"]
3947
mcs._emr_shapes = emr_model["shapes"]
48+
=======
49+
sagemaker_model = loader.load_service_model("sagemaker", "service-2")
50+
emr_model = loader.load_service_model("emr", "service-2")
51+
mcs._shapes_map["sagemaker"] = sagemaker_model["shapes"]
52+
mcs._shapes_map["emr"] = emr_model["shapes"]
53+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
4054
return super().__new__(mcs, *args, **kwargs)
4155

4256

@@ -48,7 +62,11 @@ def __init__(
4862
path: str,
4963
shape_name: str = None,
5064
shape_names: List[str] = None,
65+
<<<<<<< HEAD
5166
external_service_name: str = None,
67+
=======
68+
service_name: str = "sagemaker",
69+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
5270
):
5371
"""Create a Properties instance representing the given shape.
5472
@@ -61,9 +79,13 @@ def __init__(
6179
shape_names = [] if shape_names is None else shape_names
6280
self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names
6381

82+
<<<<<<< HEAD
6483
shapes = Properties._shapes
6584
if external_service_name == "emr":
6685
shapes = Properties._emr_shapes
86+
=======
87+
shapes = Properties._shapes_map.get(service_name, {})
88+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
6789

6890
for name in self._shape_names:
6991
shape = shapes.get(name, {})
@@ -78,7 +100,9 @@ def __init__(
78100
elif Properties._shapes.get(info["shape"], {}).get("type") == "map":
79101
self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"])
80102
else:
81-
self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])
103+
self.__dict__[key] = Properties(
104+
f"{path}.{key}", info["shape"], service_name=service_name
105+
)
82106

83107
@property
84108
def expr(self):
@@ -89,16 +113,17 @@ def expr(self):
89113
class PropertiesList(Properties):
90114
"""PropertiesList for use in workflow expressions."""
91115

92-
def __init__(self, path: str, shape_name: str = None):
116+
def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"):
93117
"""Create a PropertiesList instance representing the given shape.
94118
95119
Args:
96120
path (str): The parent path of the PropertiesList instance.
97-
shape_name (str): The botocore sagemaker service model shape name.
98-
root_shape_name (str): The botocore sagemaker service model shape name.
121+
shape_name (str): The botocore service model shape name.
122+
service_name (str): The botocore service name.
99123
"""
100124
super(PropertiesList, self).__init__(path, shape_name)
101125
self.shape_name = shape_name
126+
self.service_name = service_name
102127
self._items: Dict[Union[int, str], Properties] = dict()
103128

104129
def __getitem__(self, item: Union[int, str]):
@@ -108,7 +133,7 @@ def __getitem__(self, item: Union[int, str]):
108133
item (Union[int, str]): The index of the item in sequence.
109134
"""
110135
if item not in self._items.keys():
111-
shape = Properties._shapes.get(self.shape_name)
136+
shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name)
112137
member = shape["member"]["shape"]
113138
if isinstance(item, str):
114139
property_item = Properties(f"{self._path}['{item}']", member)

tests/integ/test_workflow.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,9 +1160,9 @@ def test_one_step_emr_pipeline(sagemaker_session, role, pipeline_name, region_na
11601160
)
11611161

11621162
step_emr = EMRStep(
1163-
name="emr-step1",
1164-
job_flow_id="MyClusterID",
1165-
display_name="emr_step_1",
1163+
name="emr-step",
1164+
cluster_id="MyClusterID",
1165+
display_name="emr_step",
11661166
description="MyEMRStepDescription",
11671167
step_config=emr_step_config,
11681168
)
@@ -1177,18 +1177,19 @@ def test_one_step_emr_pipeline(sagemaker_session, role, pipeline_name, region_na
11771177
try:
11781178
response = pipeline.create(role)
11791179
create_arn = response["PipelineArn"]
1180-
assert re.match(
1181-
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
1182-
create_arn,
1183-
)
11841180

1185-
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
1186-
response = pipeline.update(role)
1187-
update_arn = response["PipelineArn"]
1188-
assert re.match(
1189-
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
1190-
update_arn,
1191-
)
1181+
execution = pipeline.start()
1182+
response = execution.describe()
1183+
assert response["PipelineArn"] == create_arn
1184+
1185+
try:
1186+
execution.wait(delay=60, max_attempts=10)
1187+
except WaiterError:
1188+
pass
1189+
1190+
execution_steps = execution.list_steps()
1191+
assert len(execution_steps) == 1
1192+
assert execution_steps[0]["StepName"] == "emr-step"
11921193
finally:
11931194
try:
11941195
pipeline.delete()
@@ -1209,19 +1210,19 @@ def test_two_steps_emr_pipeline_without_nullable_config_fields(
12091210
)
12101211

12111212
step_emr_1 = EMRStep(
1212-
name="emr_step_1",
1213-
job_flow_id="MyClusterID",
1214-
display_name="emr_step_1",
1213+
name="emr-step-1",
1214+
cluster_id="MyClusterID",
1215+
display_name="emr-step-1",
12151216
description="MyEMRStepDescription",
12161217
step_config=emr_step_config_1,
12171218
)
12181219

12191220
emr_step_config_2 = EMRStepConfig(jar="s3:/script-runner/script-runner_2.jar")
12201221

12211222
step_emr_2 = EMRStep(
1222-
name="emr_step_2",
1223-
job_flow_id="MyClusterID",
1224-
display_name="emr_step_2",
1223+
name="emr-step-2",
1224+
cluster_id="MyClusterID",
1225+
display_name="emr-step-2",
12251226
description="MyEMRStepDescription",
12261227
step_config=emr_step_config_2,
12271228
)
@@ -1236,10 +1237,20 @@ def test_two_steps_emr_pipeline_without_nullable_config_fields(
12361237
try:
12371238
response = pipeline.create(role)
12381239
create_arn = response["PipelineArn"]
1239-
assert re.match(
1240-
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
1241-
create_arn,
1242-
)
1240+
1241+
execution = pipeline.start()
1242+
response = execution.describe()
1243+
assert response["PipelineArn"] == create_arn
1244+
1245+
try:
1246+
execution.wait(delay=60, max_attempts=10)
1247+
except WaiterError:
1248+
pass
1249+
1250+
execution_steps = execution.list_steps()
1251+
assert len(execution_steps) == 2
1252+
assert execution_steps[0]["StepName"] == "emr-step-1"
1253+
assert execution_steps[1]["StepName"] == "emr-step-2"
12431254

12441255
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
12451256
response = pipeline.update(role)

tests/unit/sagemaker/workflow/test_emr_step.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
from mock import Mock
2020

2121
from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
22+
<<<<<<< HEAD
2223
from sagemaker.workflow.properties import Properties
24+
=======
25+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
2326
from sagemaker.workflow.steps import CacheConfig
2427
from sagemaker.workflow.pipeline import Pipeline
2528
from sagemaker.workflow.parameters import ParameterString
@@ -50,7 +53,11 @@ def test_emr_step_with_one_step_config(sagemaker_session):
5053
name="MyEMRStep",
5154
display_name="MyEMRStep",
5255
description="MyEMRStepDescription",
56+
<<<<<<< HEAD
5357
job_flow_id="MyClusterID",
58+
=======
59+
cluster_id="MyClusterID",
60+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
5461
step_config=emr_step_config,
5562
depends_on=["TestStep"],
5663
cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
@@ -60,7 +67,11 @@ def test_emr_step_with_one_step_config(sagemaker_session):
6067
"Name": "MyEMRStep",
6168
"Type": "EMR",
6269
"Arguments": {
70+
<<<<<<< HEAD
6371
"JobFlowId": "MyClusterID",
72+
=======
73+
"ClusterId": "MyClusterID",
74+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
6475
"StepConfig": {
6576
"HadoopJarStep": {
6677
"Args": ["--arg_0", "arg_0_value"],
@@ -78,10 +89,27 @@ def test_emr_step_with_one_step_config(sagemaker_session):
7889
"Description": "MyEMRStepDescription",
7990
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
8091
}
92+
<<<<<<< HEAD
8193
assert emr_step.properties.Step.expr == {"Get": "Steps.MyEMRStep.Step"}
8294

8395

8496
def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
97+
=======
98+
assert emr_step.properties.ClusterId == "MyClusterID"
99+
assert emr_step.properties.ActionOnFailure.expr == {"Get": "Steps.MyEMRStep.ActionOnFailure"}
100+
assert emr_step.properties.Config.Args.expr == {"Get": "Steps.MyEMRStep.Config.Args"}
101+
assert emr_step.properties.Config.Jar.expr == {"Get": "Steps.MyEMRStep.Config.Jar"}
102+
assert emr_step.properties.Config.MainClass.expr == {"Get": "Steps.MyEMRStep.Config.MainClass"}
103+
assert emr_step.properties.Id.expr == {"Get": "Steps.MyEMRStep.Id"}
104+
assert emr_step.properties.Name.expr == {"Get": "Steps.MyEMRStep.Name"}
105+
assert emr_step.properties.Status.State.expr == {"Get": "Steps.MyEMRStep.Status.State"}
106+
assert emr_step.properties.Status.FailureDetails.Reason.expr == {
107+
"Get": "Steps.MyEMRStep.Status.FailureDetails.Reason"
108+
}
109+
110+
111+
def test_pipeline_interpolates_emr_outputs(sagemaker_session):
112+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
85113
parameter = ParameterString("MyStr")
86114

87115
emr_step_config_1 = EMRStepConfig(
@@ -93,7 +121,11 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
93121

94122
step_emr_1 = EMRStep(
95123
name="emr_step_1",
124+
<<<<<<< HEAD
96125
job_flow_id="MyClusterID",
126+
=======
127+
cluster_id="MyClusterID",
128+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
97129
display_name="emr_step_1",
98130
description="MyEMRStepDescription",
99131
depends_on=["TestStep"],
@@ -104,7 +136,11 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
104136

105137
step_emr_2 = EMRStep(
106138
name="emr_step_2",
139+
<<<<<<< HEAD
107140
job_flow_id="MyClusterID",
141+
=======
142+
cluster_id="MyClusterID",
143+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
108144
display_name="emr_step_2",
109145
description="MyEMRStepDescription",
110146
depends_on=["TestStep"],
@@ -131,7 +167,11 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
131167
"Name": "emr_step_1",
132168
"Type": "EMR",
133169
"Arguments": {
170+
<<<<<<< HEAD
134171
"JobFlowId": "MyClusterID",
172+
=======
173+
"ClusterId": "MyClusterID",
174+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
135175
"StepConfig": {
136176
"HadoopJarStep": {
137177
"Args": ["--arg_0", "arg_0_value"],
@@ -152,7 +192,11 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
152192
"Name": "emr_step_2",
153193
"Type": "EMR",
154194
"Arguments": {
195+
<<<<<<< HEAD
155196
"JobFlowId": "MyClusterID",
197+
=======
198+
"ClusterId": "MyClusterID",
199+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
156200
"StepConfig": {
157201
"HadoopJarStep": {"Jar": "s3:/script-runner/script-runner_2.jar"}
158202
},
@@ -163,6 +207,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
163207
},
164208
],
165209
}
210+
<<<<<<< HEAD
166211

167212

168213
def test_properties_describe_step_output():
@@ -172,3 +217,5 @@ def test_properties_describe_step_output():
172217
for name in some_prop_names:
173218
assert name in prop.__dict__.keys()
174219
assert prop.Step.expr == {"Get": "Steps.MyStep.Step"}
220+
=======
221+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline

tests/unit/sagemaker/workflow/test_properties.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,25 @@ def test_properties_tuning_job():
7171

7272

7373
def test_properties_emr_step():
74+
<<<<<<< HEAD
7475
prop = Properties("Steps.MyStep", "DescribeStepOutput", external_service_name="emr")
7576
some_prop_names = ["Step"]
7677
for name in some_prop_names:
7778
assert name in prop.__dict__.keys()
7879

7980
assert prop.Step.expr == {"Get": "Steps.MyStep.Step"}
81+
=======
82+
prop = Properties("Steps.MyStep", "Step", service_name="emr")
83+
some_prop_names = ["Id", "Name", "Config", "ActionOnFailure", "Status"]
84+
for name in some_prop_names:
85+
assert name in prop.__dict__.keys()
86+
87+
assert prop.Id.expr == {"Get": "Steps.MyStep.Id"}
88+
assert prop.Name.expr == {"Get": "Steps.MyStep.Name"}
89+
assert prop.ActionOnFailure.expr == {"Get": "Steps.MyStep.ActionOnFailure"}
90+
assert prop.Config.Jar.expr == {"Get": "Steps.MyStep.Config.Jar"}
91+
assert prop.Status.State.expr == {"Get": "Steps.MyStep.Status.State"}
92+
>>>>>>> feature: Add EMRStep support in Sagemaker pipeline
8093

8194

8295
def test_properties_describe_model_package_output():

0 commit comments

Comments
 (0)