Skip to content

Commit f68d41e

Browse files
authored
Merge branch 'aws:master' into model_config
2 parents b472b3f + f09f64b commit f68d41e

File tree

4 files changed

+449
-8
lines changed

4 files changed

+449
-8
lines changed

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,8 @@ Steps
169169

170170
.. autoclass:: sagemaker.workflow.fail_step.FailStep
171171

172+
.. autoclass:: sagemaker.workflow.emr_step.EMRStepConfig
173+
174+
.. autoclass:: sagemaker.workflow.emr_step.EMRStep
175+
172176
.. autoclass:: sagemaker.workflow.automl_step.AutoMLStep

src/sagemaker/workflow/emr_step.py

Lines changed: 119 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Union, Optional
16+
from typing import Any, Dict, List, Union, Optional
1717

1818
from sagemaker.workflow.entities import (
1919
RequestType,
@@ -33,7 +33,8 @@ def __init__(
3333
):
3434
"""Create a definition for input data used by an EMR cluster(job flow) step.
3535
36-
See AWS documentation on the ``StepConfig`` API for more details on the parameters.
36+
See AWS documentation for more information about the `StepConfig
37+
<https://docs.aws.amazon.com/emr/latest/APIReference/API_StepConfig.html>`_ API parameters.
3738
3839
Args:
3940
args(List[str]):
@@ -61,9 +62,89 @@ def to_request(self) -> RequestType:
6162
return config
6263

6364

65+
INSTANCES = "Instances"
66+
INSTANCEGROUPS = "InstanceGroups"
67+
INSTANCEFLEETS = "InstanceFleets"
68+
ERR_STR_WITH_NAME_AUTO_TERMINATION_OR_STEPS = (
69+
"In EMRStep {step_name}, cluster_config "
70+
"should not contain any of the Name, "
71+
"AutoTerminationPolicy and/or Steps."
72+
)
73+
74+
ERR_STR_WITHOUT_INSTANCE = "In EMRStep {step_name}, cluster_config must contain " + INSTANCES + "."
75+
76+
ERR_STR_WITH_KEEPJOBFLOW_OR_TERMINATIONPROTECTED = (
77+
"In EMRStep {step_name}, " + INSTANCES + " should not contain "
78+
"KeepJobFlowAliveWhenNoSteps or "
79+
"TerminationProtected."
80+
)
81+
82+
ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS = (
83+
"In EMRStep {step_name}, "
84+
+ INSTANCES
85+
+ " should contain either "
86+
+ INSTANCEGROUPS
87+
+ " or "
88+
+ INSTANCEFLEETS
89+
+ "."
90+
)
91+
92+
ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG = (
93+
"EMRStep {step_name} can not have both cluster_id"
94+
"or cluster_config."
95+
"To use EMRStep with "
96+
"cluster_config, cluster_id "
97+
"must be explicitly set to None."
98+
)
99+
100+
ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG = (
101+
"EMRStep {step_name} must have either cluster_id or cluster_config"
102+
)
103+
104+
64105
class EMRStep(Step):
65106
"""EMR step for workflow."""
66107

108+
def _validate_cluster_config(self, cluster_config, step_name):
109+
"""Validates user provided cluster_config.
110+
111+
Args:
112+
cluster_config(Union[Dict[str, Any], List[Dict[str, Any]]]):
113+
user provided cluster configuration.
114+
step_name: The name of the EMR step.
115+
"""
116+
117+
if (
118+
"Name" in cluster_config
119+
or "AutoTerminationPolicy" in cluster_config
120+
or "Steps" in cluster_config
121+
):
122+
raise ValueError(
123+
ERR_STR_WITH_NAME_AUTO_TERMINATION_OR_STEPS.format(step_name=step_name)
124+
)
125+
126+
if INSTANCES not in cluster_config:
127+
raise ValueError(ERR_STR_WITHOUT_INSTANCE.format(step_name=step_name))
128+
129+
if (
130+
"KeepJobFlowAliveWhenNoSteps" in cluster_config[INSTANCES]
131+
or "TerminationProtected" in cluster_config[INSTANCES]
132+
):
133+
raise ValueError(
134+
ERR_STR_WITH_KEEPJOBFLOW_OR_TERMINATIONPROTECTED.format(step_name=step_name)
135+
)
136+
137+
if (
138+
INSTANCEGROUPS in cluster_config[INSTANCES]
139+
and INSTANCEFLEETS in cluster_config[INSTANCES]
140+
) or (
141+
INSTANCEGROUPS not in cluster_config[INSTANCES]
142+
and INSTANCEFLEETS not in cluster_config[INSTANCES]
143+
):
144+
raise ValueError(
145+
ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS.format(step_name=step_name)
146+
)
147+
67148
def __init__(
68149
self,
69150
name: str,
@@ -73,8 +154,9 @@ def __init__(
73154
step_config: EMRStepConfig,
74155
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
75156
cache_config: CacheConfig = None,
157+
cluster_config: Dict[str, Any] = None,
76158
):
77-
"""Constructs a EMRStep.
159+
"""Constructs an `EMRStep`.
78160
79161
Args:
80162
name(str): The name of the EMR step.
@@ -86,16 +168,46 @@ def __init__(
86168
names or `Step` instances or `StepCollection` instances that this `EMRStep`
87169
depends on.
88170
cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
171+
cluster_config(Dict[str, Any]): The recipe of the
172+
EMR cluster, passed as a dictionary.
173+
The elements are defined in the request syntax for `RunJobFlow`.
174+
However, the following elements are not recognized as part of the cluster
175+
configuration and you should not include them in the dictionary:
176+
177+
* ``cluster_config[Name]``
178+
* ``cluster_config[Steps]``
179+
* ``cluster_config[AutoTerminationPolicy]``
180+
* ``cluster_config[Instances][KeepJobFlowAliveWhenNoSteps]``
181+
* ``cluster_config[Instances][TerminationProtected]``
182+
183+
For more information about the fields you can include in your cluster
184+
configuration, see
185+
https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html.
186+
Note that if you want to use ``cluster_config``, then you have to set
187+
``cluster_id`` as None.
89188
90189
"""
91190
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
92191

93-
emr_step_args = {"ClusterId": cluster_id, "StepConfig": step_config.to_request()}
192+
emr_step_args = {"StepConfig": step_config.to_request()}
193+
root_property = Properties(step_name=name, shape_name="Step", service_name="emr")
194+
195+
if cluster_id is None and cluster_config is None:
196+
raise ValueError(ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG.format(step_name=name))
197+
198+
if cluster_id is not None and cluster_config is not None:
199+
raise ValueError(ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG.format(step_name=name))
200+
201+
if cluster_id is not None:
202+
emr_step_args["ClusterId"] = cluster_id
203+
root_property.__dict__["ClusterId"] = cluster_id
204+
elif cluster_config is not None:
205+
self._validate_cluster_config(cluster_config, name)
206+
emr_step_args["ClusterConfig"] = cluster_config
207+
root_property.__dict__["ClusterConfig"] = cluster_config
208+
94209
self.args = emr_step_args
95210
self.cache_config = cache_config
96-
97-
root_property = Properties(step_name=name, shape_name="Step", service_name="emr")
98-
root_property.__dict__["ClusterId"] = cluster_id
99211
self._properties = root_property
100212

101213
@property

tests/integ/sagemaker/workflow/test_emr_steps.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,57 @@ def test_two_steps_emr_pipeline(sagemaker_session, role, pipeline_name, region_n
8080
pipeline.delete()
8181
except Exception:
8282
pass
83+
84+
85+
def test_emr_with_cluster_config(sagemaker_session, role, pipeline_name, region_name):
86+
87+
emr_step_config = EMRStepConfig(
88+
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
89+
args=["dummy_emr_script_path"],
90+
)
91+
92+
cluster_config = {
93+
"Instances": {
94+
"InstanceGroups": [
95+
{
96+
"Name": "Master Instance Group",
97+
"InstanceRole": "MASTER",
98+
"InstanceCount": 1,
99+
"InstanceType": "m1.small",
100+
"Market": "ON_DEMAND",
101+
}
102+
],
103+
"InstanceCount": 1,
104+
"HadoopVersion": "MyHadoopVersion",
105+
},
106+
"AmiVersion": "3.8.0",
107+
"AdditionalInfo": "MyAdditionalInfo",
108+
}
109+
110+
step_emr_with_cluster_config = EMRStep(
111+
name="MyEMRStep-name",
112+
display_name="MyEMRStep-display_name",
113+
description="MyEMRStepDescription",
114+
cluster_id=None,
115+
step_config=emr_step_config,
116+
cluster_config=cluster_config,
117+
)
118+
119+
pipeline = Pipeline(
120+
name=pipeline_name,
121+
steps=[step_emr_with_cluster_config],
122+
sagemaker_session=sagemaker_session,
123+
)
124+
125+
try:
126+
response = pipeline.create(role)
127+
create_arn = response["PipelineArn"]
128+
assert re.match(
129+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
130+
create_arn,
131+
)
132+
finally:
133+
try:
134+
pipeline.delete()
135+
except Exception:
136+
pass

0 commit comments

Comments
 (0)