Skip to content

Commit 651aad5

Browse files
author
Shegufta Ahsan
committed
Addressed aoguo64's comments
(cherry picked from commit c2cb68e9b44d3600537daa992668a38f7529da72)
1 parent 2f5c63a commit 651aad5

File tree

2 files changed

+118
-97
lines changed

2 files changed

+118
-97
lines changed

src/sagemaker/workflow/emr_step.py

Lines changed: 92 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -61,62 +61,89 @@ def to_request(self) -> RequestType:
6161
return config
6262

6363

64-
def validate_cluster_config(cluster_config, step_name):
65-
"""Validates user provided cluster_config.
66-
67-
Args:
68-
cluster_config(Union[Dict[str, Any], List[Dict[str, Any]]]):
69-
user provided cluster configuration.
70-
step_name: The name of the EMR step.
71-
"""
72-
73-
instances = "Instances"
74-
instancegroups = "InstanceGroups"
75-
instancefleets = "InstanceFleets"
76-
prefix_with_in = "In EMRStep " + step_name + ", "
77-
78-
if (
79-
"Name" in cluster_config
80-
or "AutoTerminationPolicy" in cluster_config
81-
or "Steps" in cluster_config
82-
):
83-
raise Exception(
84-
prefix_with_in + "cluster_config should not contain any of Name, "
85-
"AutoTerminationPolicy and/or Steps"
86-
)
64+
instances = "Instances"
65+
instancegroups = "InstanceGroups"
66+
instancefleets = "InstanceFleets"
67+
err_str_with_name_auto_termination_or_steps = (
68+
"In EMRStep {step_name}, cluster_config "
69+
"should not contain any of the Name, "
70+
"AutoTerminationPolicy and/or Steps."
71+
)
8772

88-
if instances not in cluster_config:
89-
raise Exception(prefix_with_in + "cluster_config must contain Instances")
73+
err_str_without_instance = "In EMRStep {step_name}, cluster_config must contain " + instances + "."
9074

91-
if (
92-
"KeepJobFlowAliveWhenNoSteps" in cluster_config[instances]
93-
or "TerminationProtected" in cluster_config[instances]
94-
):
95-
raise Exception(
96-
prefix_with_in + instances + " should not contain "
97-
"KeepJobFlowAliveWhenNoSteps or "
98-
"TerminationProtected"
99-
)
100-
101-
if (
102-
instancegroups in cluster_config[instances] and instancefleets in cluster_config[instances]
103-
) or (
104-
instancegroups not in cluster_config[instances]
105-
and instancefleets not in cluster_config[instances]
106-
):
107-
raise Exception(
108-
prefix_with_in
109-
+ instances
110-
+ " should contain either "
111-
+ instancegroups
112-
+ " or "
113-
+ instancefleets
114-
)
75+
err_str_with_keepjobflow_or_terminationprotected = (
76+
"In EMRStep {step_name}, " + instances + " should not contain "
77+
"KeepJobFlowAliveWhenNoSteps or "
78+
"TerminationProtected."
79+
)
80+
81+
err_str_both_or_none_instancegroups_or_instancefleets = (
82+
"In EMRStep {step_name}, "
83+
+ instances
84+
+ " should contain either "
85+
+ instancegroups
86+
+ " or "
87+
+ instancefleets
88+
+ "."
89+
)
90+
91+
err_str_with_both_cluster_id_and_cluster_cfg = (
92+
"EMRStep {step_name} can not have both cluster_id"
93+
"or cluster_config."
94+
"To use EMRStep with "
95+
"cluster_config, cluster_id "
96+
"must be explicitly set to None."
97+
)
98+
99+
err_str_without_cluster_id_and_cluster_cfg = (
100+
"EMRStep {step_name} must have either cluster_id or cluster_config"
101+
)
115102

116103

117104
class EMRStep(Step):
118105
"""EMR step for workflow."""
119106

107+
def _validate_cluster_config(self, cluster_config, step_name):
108+
"""Validates user provided cluster_config.
109+
110+
Args:
111+
cluster_config(Union[Dict[str, Any], List[Dict[str, Any]]]):
112+
user provided cluster configuration.
113+
step_name: The name of the EMR step.
114+
"""
115+
116+
if (
117+
"Name" in cluster_config
118+
or "AutoTerminationPolicy" in cluster_config
119+
or "Steps" in cluster_config
120+
):
121+
raise ValueError(
122+
err_str_with_name_auto_termination_or_steps.format(step_name=step_name)
123+
)
124+
125+
if instances not in cluster_config:
126+
raise ValueError(err_str_without_instance.format(step_name=step_name))
127+
128+
if (
129+
"KeepJobFlowAliveWhenNoSteps" in cluster_config[instances]
130+
or "TerminationProtected" in cluster_config[instances]
131+
):
132+
raise ValueError(
133+
err_str_with_keepjobflow_or_terminationprotected.format(step_name=step_name)
134+
)
135+
136+
if (
137+
instancegroups in cluster_config[instances]
138+
and instancefleets in cluster_config[instances]
139+
) or (
140+
instancegroups not in cluster_config[instances]
141+
and instancefleets not in cluster_config[instances]
142+
):
143+
raise ValueError(
144+
err_str_both_or_none_instancegroups_or_instancefleets.format(step_name=step_name)
145+
)
146+
120147
def __init__(
121148
self,
122149
name: str,
@@ -128,7 +155,7 @@ def __init__(
128155
cache_config: CacheConfig = None,
129156
cluster_config: RequestType = None,
130157
):
131-
"""Constructs a EMRStep.
158+
"""Constructs an EMRStep.
132159
133160
Args:
134161
name(str): The name of the EMR step.
@@ -141,40 +168,36 @@ def __init__(
141168
depends on.
142169
cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
143170
cluster_config(Union[Dict[str, Any], List[Dict[str, Any]]]): The recipe of the
144-
EMR Cluster. It is a dictionary.
145-
The elements are defined in the Request Syntax Section:
171+
EMR cluster, passed as a dictionary. The elements are defined in the request syntax
172+
for RunJobFlow. However, the following elements are not recognized as part of the
173+
cluster configuration and you should not include them in the dictionary:
174+
1. cluster_config[Name]
175+
2. cluster_config[Steps]
176+
3. cluster_config[AutoTerminationPolicy]
177+
4. cluster_config[Instances][KeepJobFlowAliveWhenNoSteps]
178+
5. cluster_config[Instances][TerminationProtected]
179+
For more information about the fields you can include in your cluster
180+
configuration, see
146181
https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html
147-
However, the following five elements are restricted, and must not present
148-
in the dictionary:
149-
1. cluster_config[Name]
150-
2. cluster_config[Steps]
151-
3. cluster_config[AutoTerminationPolicy]
152-
4. cluster_config[Instances][KeepJobFlowAliveWhenNoSteps]
153-
5. cluster_config[Instances][TerminationProtected]
154-
Note that, if user wants to use cluster_config, then they have to explicitly set
155-
cluster_id as None
156-
182+
Note that if you want to use cluster_config,
183+
then you have to set cluster_id as None.
157184
"""
158185
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
159186

160187
emr_step_args = {"StepConfig": step_config.to_request()}
161188
root_property = Properties(step_name=name, shape_name="Step", service_name="emr")
162189

163190
if cluster_id is None and cluster_config is None:
164-
raise Exception("EMRStep " + name + " must have either cluster_id or cluster_config")
191+
raise ValueError(err_str_without_cluster_id_and_cluster_cfg.format(step_name=name))
165192

166193
if cluster_id is not None and cluster_config is not None:
167-
raise Exception(
168-
"EMRStep " + name + " can not have both cluster_id or cluster_config. "
169-
"If user wants to use cluster_config, then they "
170-
"have to explicitly set cluster_id as None"
171-
)
194+
raise ValueError(err_str_with_both_cluster_id_and_cluster_cfg.format(step_name=name))
172195

173196
if cluster_id is not None:
174197
emr_step_args["ClusterId"] = cluster_id
175198
root_property.__dict__["ClusterId"] = cluster_id
176199
elif cluster_config is not None:
177-
validate_cluster_config(cluster_config, name)
200+
self._validate_cluster_config(cluster_config, name)
178201
emr_step_args["ClusterConfig"] = cluster_config
179202
root_property.__dict__["ClusterConfig"] = cluster_config
180203

tests/unit/sagemaker/workflow/test_emr_step.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,16 @@
1515
import json
1616
import pytest
1717

18-
from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
18+
from sagemaker.workflow.emr_step import (
19+
EMRStep,
20+
EMRStepConfig,
21+
err_str_with_name_auto_termination_or_steps,
22+
err_str_without_instance,
23+
err_str_with_keepjobflow_or_terminationprotected,
24+
err_str_both_or_none_instancegroups_or_instancefleets,
25+
err_str_with_both_cluster_id_and_cluster_cfg,
26+
err_str_without_cluster_id_and_cluster_cfg,
27+
)
1928
from sagemaker.workflow.steps import CacheConfig
2029
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
2130
from sagemaker.workflow.parameters import ParameterString
@@ -172,8 +181,6 @@ def test_pipeline_interpolates_emr_outputs(sagemaker_session):
172181

173182
g_emr_step_config = EMRStepConfig(jar="s3:/script-runner/script-runner.jar")
174183
g_emr_step_name = "MyEMRStep"
175-
g_prefix = "EMRStep " + g_emr_step_name + " "
176-
g_prefix_with_in = "In EMRStep " + g_emr_step_name + ", "
177184
g_cluster_config = {
178185
"Instances": {
179186
"InstanceGroups": [
@@ -194,7 +201,7 @@ def test_pipeline_interpolates_emr_outputs(sagemaker_session):
194201

195202

196203
def test_emr_step_throws_exception_when_both_cluster_id_and_cluster_config_are_present():
197-
with pytest.raises(Exception) as exceptionInfo:
204+
with pytest.raises(ValueError) as exceptionInfo:
198205
EMRStep(
199206
name=g_emr_step_name,
200207
display_name="MyEMRStep",
@@ -205,18 +212,16 @@ def test_emr_step_throws_exception_when_both_cluster_id_and_cluster_config_are_p
205212
depends_on=["TestStep"],
206213
cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
207214
)
208-
expected_error_msg = (
209-
g_prefix + "can not have both cluster_id or cluster_config. "
210-
"If user wants to use cluster_config, then they "
211-
"have to explicitly set cluster_id as None"
215+
expected_error_msg = err_str_with_both_cluster_id_and_cluster_cfg.format(
216+
step_name=g_emr_step_name
212217
)
213218
actual_error_msg = exceptionInfo.value.args[0]
214219

215220
assert actual_error_msg == expected_error_msg
216221

217222

218223
def test_emr_step_throws_exception_when_both_cluster_id_and_cluster_config_are_none():
219-
with pytest.raises(Exception) as exceptionInfo:
224+
with pytest.raises(ValueError) as exceptionInfo:
220225
EMRStep(
221226
name=g_emr_step_name,
222227
display_name="MyEMRStep",
@@ -226,7 +231,9 @@ def test_emr_step_throws_exception_when_both_cluster_id_and_cluster_config_are_n
226231
depends_on=["TestStep"],
227232
cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
228233
)
229-
expected_error_msg = g_prefix + "must have either cluster_id or cluster_config"
234+
expected_error_msg = err_str_without_cluster_id_and_cluster_cfg.format(
235+
step_name=g_emr_step_name
236+
)
230237
actual_error_msg = exceptionInfo.value.args[0]
231238

232239
assert actual_error_msg == expected_error_msg
@@ -327,8 +334,7 @@ def test_emr_step_with_valid_cluster_config():
327334
],
328335
},
329336
},
330-
g_prefix_with_in + "cluster_config should not contain any of "
331-
"Name, AutoTerminationPolicy and/or Steps",
337+
err_str_with_name_auto_termination_or_steps.format(step_name=g_emr_step_name),
332338
),
333339
(
334340
{
@@ -341,8 +347,7 @@ def test_emr_step_with_valid_cluster_config():
341347
],
342348
},
343349
},
344-
g_prefix_with_in + "cluster_config should not contain any of "
345-
"Name, AutoTerminationPolicy and/or Steps",
350+
err_str_with_name_auto_termination_or_steps.format(step_name=g_emr_step_name),
346351
),
347352
(
348353
{
@@ -355,22 +360,20 @@ def test_emr_step_with_valid_cluster_config():
355360
],
356361
},
357362
},
358-
g_prefix_with_in + "cluster_config should not contain any of "
359-
"Name, AutoTerminationPolicy and/or Steps",
363+
err_str_with_name_auto_termination_or_steps.format(step_name=g_emr_step_name),
360364
),
361365
(
362366
{
363367
"AmiVersion": "3.8.0",
364368
"AdditionalInfo": "MyAdditionalInfo",
365369
},
366-
g_prefix_with_in + "cluster_config must contain Instances",
370+
err_str_without_instance.format(step_name=g_emr_step_name),
367371
),
368372
(
369373
{
370374
"Instances": {},
371375
},
372-
g_prefix_with_in + "Instances should contain either "
373-
"InstanceGroups or InstanceFleets",
376+
err_str_both_or_none_instancegroups_or_instancefleets.format(step_name=g_emr_step_name),
374377
),
375378
(
376379
{
@@ -387,8 +390,7 @@ def test_emr_step_with_valid_cluster_config():
387390
],
388391
},
389392
},
390-
g_prefix_with_in + "Instances should contain either "
391-
"InstanceGroups or InstanceFleets",
393+
err_str_both_or_none_instancegroups_or_instancefleets.format(step_name=g_emr_step_name),
392394
),
393395
(
394396
{
@@ -401,9 +403,7 @@ def test_emr_step_with_valid_cluster_config():
401403
"KeepJobFlowAliveWhenNoSteps": True,
402404
},
403405
},
404-
g_prefix_with_in + "Instances should not contain "
405-
"KeepJobFlowAliveWhenNoSteps or "
406-
"TerminationProtected",
406+
err_str_with_keepjobflow_or_terminationprotected.format(step_name=g_emr_step_name),
407407
),
408408
(
409409
{
@@ -416,16 +416,14 @@ def test_emr_step_with_valid_cluster_config():
416416
"TerminationProtected": True,
417417
},
418418
},
419-
g_prefix_with_in + "Instances should not contain "
420-
"KeepJobFlowAliveWhenNoSteps or "
421-
"TerminationProtected",
419+
err_str_with_keepjobflow_or_terminationprotected.format(step_name=g_emr_step_name),
422420
),
423421
],
424422
)
425423
def test_emr_step_throws_exception_when_cluster_config_contains_restricted_entities(
426424
invalid_cluster_config, expected_error_msg
427425
):
428-
with pytest.raises(Exception) as exceptionInfo:
426+
with pytest.raises(ValueError) as exceptionInfo:
429427
EMRStep(
430428
name=g_emr_step_name,
431429
display_name="MyEMRStep",

0 commit comments

Comments
 (0)