13
13
"""The step definitions for workflow."""
14
14
from __future__ import absolute_import
15
15
16
- from typing import List , Union , Optional
16
+ from typing import Any , Dict , List , Union , Optional
17
17
18
18
from sagemaker .workflow .entities import (
19
19
RequestType ,
@@ -33,7 +33,8 @@ def __init__(
33
33
):
34
34
"""Create a definition for input data used by an EMR cluster(job flow) step.
35
35
36
- See AWS documentation on the ``StepConfig`` API for more details on the parameters.
36
+ See AWS documentation for more information about the `StepConfig
37
+ <https://docs.aws.amazon.com/emr/latest/APIReference/API_StepConfig.html>`_ API parameters.
37
38
38
39
Args:
39
40
args(List[str]):
@@ -61,9 +62,89 @@ def to_request(self) -> RequestType:
61
62
return config
62
63
63
64
65
+ INSTANCES = "Instances"
66
+ INSTANCEGROUPS = "InstanceGroups"
67
+ INSTANCEFLEETS = "InstanceFleets"
68
+ ERR_STR_WITH_NAME_AUTO_TERMINATION_OR_STEPS = (
69
+ "In EMRStep {step_name}, cluster_config "
70
+ "should not contain any of the Name, "
71
+ "AutoTerminationPolicy and/or Steps."
72
+ )
73
+
74
+ ERR_STR_WITHOUT_INSTANCE = "In EMRStep {step_name}, cluster_config must contain " + INSTANCES + "."
75
+
76
+ ERR_STR_WITH_KEEPJOBFLOW_OR_TERMINATIONPROTECTED = (
77
+ "In EMRStep {step_name}, " + INSTANCES + " should not contain "
78
+ "KeepJobFlowAliveWhenNoSteps or "
79
+ "TerminationProtected."
80
+ )
81
+
82
+ ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS = (
83
+ "In EMRStep {step_name}, "
84
+ + INSTANCES
85
+ + " should contain either "
86
+ + INSTANCEGROUPS
87
+ + " or "
88
+ + INSTANCEFLEETS
89
+ + "."
90
+ )
91
+
92
+ ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG = (
93
+ "EMRStep {step_name} can not have both cluster_id"
94
+ "or cluster_config."
95
+ "To use EMRStep with "
96
+ "cluster_config, cluster_id "
97
+ "must be explicitly set to None."
98
+ )
99
+
100
+ ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG = (
101
+ "EMRStep {step_name} must have either cluster_id or cluster_config"
102
+ )
103
+
104
+
64
105
class EMRStep (Step ):
65
106
"""EMR step for workflow."""
66
107
108
+ def _validate_cluster_config (self , cluster_config , step_name ):
109
+ """Validates user provided cluster_config.
110
+
111
+ Args:
112
+ cluster_config(Union[Dict[str, Any], List[Dict[str, Any]]]):
113
+ user provided cluster configuration.
114
+ step_name: The name of the EMR step.
115
+ """
116
+
117
+ if (
118
+ "Name" in cluster_config
119
+ or "AutoTerminationPolicy" in cluster_config
120
+ or "Steps" in cluster_config
121
+ ):
122
+ raise ValueError (
123
+ ERR_STR_WITH_NAME_AUTO_TERMINATION_OR_STEPS .format (step_name = step_name )
124
+ )
125
+
126
+ if INSTANCES not in cluster_config :
127
+ raise ValueError (ERR_STR_WITHOUT_INSTANCE .format (step_name = step_name ))
128
+
129
+ if (
130
+ "KeepJobFlowAliveWhenNoSteps" in cluster_config [INSTANCES ]
131
+ or "TerminationProtected" in cluster_config [INSTANCES ]
132
+ ):
133
+ raise ValueError (
134
+ ERR_STR_WITH_KEEPJOBFLOW_OR_TERMINATIONPROTECTED .format (step_name = step_name )
135
+ )
136
+
137
+ if (
138
+ INSTANCEGROUPS in cluster_config [INSTANCES ]
139
+ and INSTANCEFLEETS in cluster_config [INSTANCES ]
140
+ ) or (
141
+ INSTANCEGROUPS not in cluster_config [INSTANCES ]
142
+ and INSTANCEFLEETS not in cluster_config [INSTANCES ]
143
+ ):
144
+ raise ValueError (
145
+ ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS .format (step_name = step_name )
146
+ )
147
+
67
148
def __init__ (
68
149
self ,
69
150
name : str ,
@@ -73,8 +154,9 @@ def __init__(
73
154
step_config : EMRStepConfig ,
74
155
depends_on : Optional [List [Union [str , Step , StepCollection ]]] = None ,
75
156
cache_config : CacheConfig = None ,
157
+ cluster_config : Dict [str , Any ] = None ,
76
158
):
77
- """Constructs a EMRStep.
159
+ """Constructs an ` EMRStep` .
78
160
79
161
Args:
80
162
name(str): The name of the EMR step.
@@ -86,16 +168,46 @@ def __init__(
86
168
names or `Step` instances or `StepCollection` instances that this `EMRStep`
87
169
depends on.
88
170
cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
171
+ cluster_config(Dict[str, Any]): The recipe of the
172
+ EMR cluster, passed as a dictionary.
173
+ The elements are defined in the request syntax for `RunJobFlow`.
174
+ However, the following elements are not recognized as part of the cluster
175
+ configuration and you should not include them in the dictionary:
176
+
177
+ * ``cluster_config[Name]``
178
+ * ``cluster_config[Steps]``
179
+ * ``cluster_config[AutoTerminationPolicy]``
180
+ * ``cluster_config[Instances][KeepJobFlowAliveWhenNoSteps]``
181
+ * ``cluster_config[Instances][TerminationProtected]``
182
+
183
+ For more information about the fields you can include in your cluster
184
+ configuration, see
185
+ https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html.
186
+ Note that if you want to use ``cluster_config``, then you have to set
187
+ ``cluster_id`` as None.
89
188
90
189
"""
91
190
super (EMRStep , self ).__init__ (name , display_name , description , StepTypeEnum .EMR , depends_on )
92
191
93
- emr_step_args = {"ClusterId" : cluster_id , "StepConfig" : step_config .to_request ()}
192
+ emr_step_args = {"StepConfig" : step_config .to_request ()}
193
+ root_property = Properties (step_name = name , shape_name = "Step" , service_name = "emr" )
194
+
195
+ if cluster_id is None and cluster_config is None :
196
+ raise ValueError (ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG .format (step_name = name ))
197
+
198
+ if cluster_id is not None and cluster_config is not None :
199
+ raise ValueError (ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG .format (step_name = name ))
200
+
201
+ if cluster_id is not None :
202
+ emr_step_args ["ClusterId" ] = cluster_id
203
+ root_property .__dict__ ["ClusterId" ] = cluster_id
204
+ elif cluster_config is not None :
205
+ self ._validate_cluster_config (cluster_config , name )
206
+ emr_step_args ["ClusterConfig" ] = cluster_config
207
+ root_property .__dict__ ["ClusterConfig" ] = cluster_config
208
+
94
209
self .args = emr_step_args
95
210
self .cache_config = cache_config
96
-
97
- root_property = Properties (step_name = name , shape_name = "Step" , service_name = "emr" )
98
- root_property .__dict__ ["ClusterId" ] = cluster_id
99
211
self ._properties = root_property
100
212
101
213
@property
0 commit comments