Skip to content

Commit 9defc87

Browse files
author
Shegufta Ahsan
committed
feature: support cluster lifecycle management for Sagemaker EMR step
1 parent 7faf9ce commit 9defc87

File tree

3 files changed

+414
-4
lines changed

3 files changed

+414
-4
lines changed

src/sagemaker/workflow/emr_step.py

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,59 @@ def to_request(self) -> RequestType:
6161
return config
6262

6363

64+
def validate_cluster_config(cluster_config, name):
65+
"""Validates user provided cluster_config.
66+
67+
Args:
68+
cluster_config(Union[Dict[str, Any], List[Dict[str, Any]]]):
69+
user provided cluster configuration.
70+
name: name of the EMR cluster.
71+
"""
72+
73+
instances = "Instances"
74+
instancegroups = "InstanceGroups"
75+
instancefleets = "InstanceFleets"
76+
prefix_with_in = "In EMRStep " + name + ", "
77+
78+
if (
79+
"Name" in cluster_config
80+
or "AutoTerminationPolicy" in cluster_config
81+
or "Steps" in cluster_config
82+
):
83+
raise Exception(
84+
prefix_with_in + "cluster_config should not contain any of Name, "
85+
"AutoTerminationPolicy and/or Steps"
86+
)
87+
88+
if instances not in cluster_config:
89+
raise Exception(prefix_with_in + "cluster_config must contain Instances")
90+
91+
if (
92+
"KeepJobFlowAliveWhenNoSteps" in cluster_config[instances]
93+
or "TerminationProtected" in cluster_config[instances]
94+
):
95+
raise Exception(
96+
prefix_with_in + instances + " should not contain "
97+
"KeepJobFlowAliveWhenNoSteps or "
98+
"TerminationProtected"
99+
)
100+
101+
if (
102+
instancegroups in cluster_config[instances] and instancefleets in cluster_config[instances]
103+
) or (
104+
instancegroups not in cluster_config[instances]
105+
and instancefleets not in cluster_config[instances]
106+
):
107+
raise Exception(
108+
prefix_with_in
109+
+ instances
110+
+ " should contain either "
111+
+ instancegroups
112+
+ " or "
113+
+ instancefleets
114+
)
115+
116+
64117
class EMRStep(Step):
65118
"""EMR step for workflow."""
66119

@@ -73,6 +126,7 @@ def __init__(
73126
step_config: EMRStepConfig,
74127
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
75128
cache_config: CacheConfig = None,
129+
cluster_config: RequestType = None,
76130
):
77131
"""Constructs a EMRStep.
78132
@@ -86,16 +140,46 @@ def __init__(
86140
names or `Step` instances or `StepCollection` instances that this `EMRStep`
87141
depends on.
88142
cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
143+
cluster_config(Union[Dict[str, Any], List[Dict[str, Any]]]): The recipe of the
144+
EMR Cluster. It is a dictionary.
145+
The elements are defined in the Request Syntax Section:
146+
https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html
147+
However, the following five elements are restricted, and must not present
148+
in the dictionary:
149+
1. cluster_config[Name]
150+
2. cluster_config[Steps]
151+
3. cluster_config[AutoTerminationPolicy]
152+
4. cluster_config[Instances][KeepJobFlowAliveWhenNoSteps]
153+
5. cluster_config[Instances][TerminationProtected]
154+
Note that, if user wants to use cluster_config, then they have to explicitly set
155+
cluster_id as None
89156
90157
"""
91158
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
92159

93-
emr_step_args = {"ClusterId": cluster_id, "StepConfig": step_config.to_request()}
160+
emr_step_args = {"StepConfig": step_config.to_request()}
161+
root_property = Properties(step_name=name, shape_name="Step", service_name="emr")
162+
163+
if cluster_id is None and cluster_config is None:
164+
raise Exception("EMRStep " + name + " must have either cluster_id or cluster_config")
165+
166+
if cluster_id is not None and cluster_config is not None:
167+
raise Exception(
168+
"EMRStep " + name + " can not have both cluster_id or cluster_config. "
169+
"If user wants to use cluster_config, then they "
170+
"have to explicitly set cluster_id as None"
171+
)
172+
173+
if cluster_id is not None:
174+
emr_step_args["ClusterId"] = cluster_id
175+
root_property.__dict__["ClusterId"] = cluster_id
176+
elif cluster_config is not None:
177+
validate_cluster_config(cluster_config, name)
178+
emr_step_args["ClusterConfig"] = cluster_config
179+
root_property.__dict__["ClusterConfig"] = cluster_config
180+
94181
self.args = emr_step_args
95182
self.cache_config = cache_config
96-
97-
root_property = Properties(step_name=name, shape_name="Step", service_name="emr")
98-
root_property.__dict__["ClusterId"] = cluster_id
99183
self._properties = root_property
100184

101185
@property

tests/integ/sagemaker/workflow/test_emr_steps.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,57 @@ def test_two_steps_emr_pipeline(sagemaker_session, role, pipeline_name, region_n
8080
pipeline.delete()
8181
except Exception:
8282
pass
83+
84+
85+
def test_emr_with_cluster_config(sagemaker_session, role, pipeline_name, region_name):
86+
87+
emr_step_config = EMRStepConfig(
88+
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
89+
args=["dummy_emr_script_path"],
90+
)
91+
92+
cluster_config = {
93+
"Instances": {
94+
"InstanceGroups": [
95+
{
96+
"Name": "Master Instance Group",
97+
"InstanceRole": "MASTER",
98+
"InstanceCount": 1,
99+
"InstanceType": "m1.small",
100+
"Market": "ON_DEMAND",
101+
}
102+
],
103+
"InstanceCount": 1,
104+
"HadoopVersion": "MyHadoopVersion",
105+
},
106+
"AmiVersion": "3.8.0",
107+
"AdditionalInfo": "MyAdditionalInfo",
108+
}
109+
110+
step_emr_with_cluster_config = EMRStep(
111+
name="MyEMRStep-name",
112+
display_name="MyEMRStep-display_name",
113+
description="MyEMRStepDescription",
114+
cluster_id=None,
115+
step_config=emr_step_config,
116+
cluster_config=cluster_config,
117+
)
118+
119+
pipeline = Pipeline(
120+
name=pipeline_name,
121+
steps=[step_emr_with_cluster_config],
122+
sagemaker_session=sagemaker_session,
123+
)
124+
125+
try:
126+
response = pipeline.create(role)
127+
create_arn = response["PipelineArn"]
128+
assert re.match(
129+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
130+
create_arn,
131+
)
132+
finally:
133+
try:
134+
pipeline.delete()
135+
except Exception:
136+
pass

0 commit comments

Comments
 (0)