@@ -61,6 +61,59 @@ def to_request(self) -> RequestType:
61
61
return config
62
62
63
63
64
+ def validate_cluster_config (cluster_config , name ):
65
+ """Validates user provided cluster_config.
66
+
67
+ Args:
68
+ cluster_config(Union[Dict[str, Any], List[Dict[str, Any]]]):
69
+ user provided cluster configuration.
70
+ name: name of the EMR cluster.
71
+ """
72
+
73
+ instances = "Instances"
74
+ instancegroups = "InstanceGroups"
75
+ instancefleets = "InstanceFleets"
76
+ prefix_with_in = "In EMRStep " + name + ", "
77
+
78
+ if (
79
+ "Name" in cluster_config
80
+ or "AutoTerminationPolicy" in cluster_config
81
+ or "Steps" in cluster_config
82
+ ):
83
+ raise Exception (
84
+ prefix_with_in + "cluster_config should not contain any of Name, "
85
+ "AutoTerminationPolicy and/or Steps"
86
+ )
87
+
88
+ if instances not in cluster_config :
89
+ raise Exception (prefix_with_in + "cluster_config must contain Instances" )
90
+
91
+ if (
92
+ "KeepJobFlowAliveWhenNoSteps" in cluster_config [instances ]
93
+ or "TerminationProtected" in cluster_config [instances ]
94
+ ):
95
+ raise Exception (
96
+ prefix_with_in + instances + " should not contain "
97
+ "KeepJobFlowAliveWhenNoSteps or "
98
+ "TerminationProtected"
99
+ )
100
+
101
+ if (
102
+ instancegroups in cluster_config [instances ] and instancefleets in cluster_config [instances ]
103
+ ) or (
104
+ instancegroups not in cluster_config [instances ]
105
+ and instancefleets not in cluster_config [instances ]
106
+ ):
107
+ raise Exception (
108
+ prefix_with_in
109
+ + instances
110
+ + " should contain either "
111
+ + instancegroups
112
+ + " or "
113
+ + instancefleets
114
+ )
115
+
116
+
64
117
class EMRStep (Step ):
65
118
"""EMR step for workflow."""
66
119
@@ -73,6 +126,7 @@ def __init__(
73
126
step_config : EMRStepConfig ,
74
127
depends_on : Optional [List [Union [str , Step , StepCollection ]]] = None ,
75
128
cache_config : CacheConfig = None ,
129
+ cluster_config : RequestType = None ,
76
130
):
77
131
"""Constructs a EMRStep.
78
132
@@ -86,16 +140,46 @@ def __init__(
86
140
names or `Step` instances or `StepCollection` instances that this `EMRStep`
87
141
depends on.
88
142
cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
143
+ cluster_config(Union[Dict[str, Any], List[Dict[str, Any]]]): The recipe of the
144
+ EMR Cluster. It is a dictionary.
145
+ The elements are defined in the Request Syntax Section:
146
+ https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html
147
+ However, the following five elements are restricted, and must not present
148
+ in the dictionary:
149
+ 1. cluster_config[Name]
150
+ 2. cluster_config[Steps]
151
+ 3. cluster_config[AutoTerminationPolicy]
152
+ 4. cluster_config[Instances][KeepJobFlowAliveWhenNoSteps]
153
+ 5. cluster_config[Instances][TerminationProtected]
154
+ Note that, if user wants to use cluster_config, then they have to explicitly set
155
+ cluster_id as None
89
156
90
157
"""
91
158
super (EMRStep , self ).__init__ (name , display_name , description , StepTypeEnum .EMR , depends_on )
92
159
93
- emr_step_args = {"ClusterId" : cluster_id , "StepConfig" : step_config .to_request ()}
160
+ emr_step_args = {"StepConfig" : step_config .to_request ()}
161
+ root_property = Properties (step_name = name , shape_name = "Step" , service_name = "emr" )
162
+
163
+ if cluster_id is None and cluster_config is None :
164
+ raise Exception ("EMRStep " + name + " must have either cluster_id or cluster_config" )
165
+
166
+ if cluster_id is not None and cluster_config is not None :
167
+ raise Exception (
168
+ "EMRStep " + name + " can not have both cluster_id or cluster_config. "
169
+ "If user wants to use cluster_config, then they "
170
+ "have to explicitly set cluster_id as None"
171
+ )
172
+
173
+ if cluster_id is not None :
174
+ emr_step_args ["ClusterId" ] = cluster_id
175
+ root_property .__dict__ ["ClusterId" ] = cluster_id
176
+ elif cluster_config is not None :
177
+ validate_cluster_config (cluster_config , name )
178
+ emr_step_args ["ClusterConfig" ] = cluster_config
179
+ root_property .__dict__ ["ClusterConfig" ] = cluster_config
180
+
94
181
self .args = emr_step_args
95
182
self .cache_config = cache_config
96
-
97
- root_property = Properties (step_name = name , shape_name = "Step" , service_name = "emr" )
98
- root_property .__dict__ ["ClusterId" ] = cluster_id
99
183
self ._properties = root_property
100
184
101
185
@property
0 commit comments