95
95
from sagemaker .feature_store .feature_group import FeatureGroup , FeatureDefinition , FeatureTypeEnum
96
96
from tests .integ import DATA_DIR
97
97
from tests .integ .kms_utils import get_or_create_kms_key
98
from tests.integ.retry import retries
from tests.integ.vpc_test_utils import get_or_create_vpc_resources
99
103
100
104
101
105
def ordered (obj ):
@@ -281,6 +285,75 @@ def build_jar():
281
285
subprocess .run (["rm" , os .path .join (jar_file_path , java_file_path , "HelloJavaSparkApp.class" )])
282
286
283
287
288
@pytest.fixture(scope="module")
def emr_script_path(sagemaker_session):
    """Upload the EMR test script once per module and return its uploaded location.

    The local file ``<DATA_DIR>/workflow/emr-script.sh`` is uploaded through
    ``sagemaker_session.upload_data`` under the ``integ-test-data/workflow``
    key prefix; whatever ``upload_data`` returns is handed to the tests.
    """
    local_script = os.path.join(DATA_DIR, "workflow", "emr-script.sh")
    return sagemaker_session.upload_data(
        path=local_script,
        key_prefix="integ-test-data/workflow",
    )
295
+
296
+
297
@pytest.fixture(scope="module")
def emr_cluster_id(sagemaker_session, role):
    """Return the id of an EMR cluster usable by the EMR step tests.

    An existing cluster whose name starts with ``emr-step-test-cluster`` is
    reused when ``get_existing_emr_cluster_id`` finds one; otherwise a new
    cluster is created and its id is returned.

    Args:
        sagemaker_session: Session whose ``boto_session`` supplies the EMR client.
        role: Unused here; kept so the fixture signature stays unchanged.

    Returns:
        The cluster id of the reused or newly created cluster.
    """
    emr_client = sagemaker_session.boto_session.client("emr")
    cluster_name = "emr-step-test-cluster"
    cluster_id = get_existing_emr_cluster_id(emr_client, cluster_name)

    if cluster_id is None:
        # Keep the id of the freshly created cluster; previously the return
        # value was dropped and the fixture yielded None in this branch.
        cluster_id = create_new_emr_cluster(sagemaker_session, emr_client, cluster_name)
    return cluster_id
306
+
307
+
308
def get_existing_emr_cluster_id(emr_client, cluster_name):
    """Find a RUNNING or WAITING EMR cluster whose name starts with ``cluster_name``.

    Every page of ``list_clusters`` results is examined by following the
    ``Marker`` value returned in each response, so a matching cluster is found
    even when it is not on the first page.

    Args:
        emr_client: Client object exposing ``list_clusters``.
        cluster_name: Name prefix to match against each cluster's ``Name``.

    Returns:
        The ``Id`` of the first matching cluster, or ``None`` when no cluster
        matches. Errors raised by the client propagate to the caller.
    """
    page_args = {}
    while True:
        response = emr_client.list_clusters(ClusterStates=["RUNNING", "WAITING"], **page_args)
        for cluster in response["Clusters"]:
            if cluster["Name"].startswith(cluster_name):
                cluster_id = cluster["Id"]
                print("Using existing cluster: {}".format(cluster_id))
                return cluster_id

        marker = response.get("Marker")
        if not marker:
            # No further pages and nothing matched.
            return None
        page_args = {"Marker": marker}
318
+
319
+
320
def create_new_emr_cluster(sagemaker_session, emr_client, cluster_name):
    """Start a single-node, long-running EMR cluster for the EMR step tests.

    The cluster runs release ``emr-6.3.0`` with Hadoop and Spark, uses one
    on-demand ``m4.large`` master node, stays alive when it has no steps, and
    writes logs under ``emr-test-logs`` in the session's default bucket. It is
    placed in the first subnet returned by ``get_or_create_vpc_resources``.

    Args:
        sagemaker_session: Session providing the EC2 client and default bucket.
        emr_client: Client object exposing ``run_job_flow``.
        cluster_name: Name given to the new cluster.

    Returns:
        The ``JobFlowId`` of the created cluster. Errors raised by the
        clients propagate to the caller.
    """
    ec2_client = sagemaker_session.boto_session.client("ec2")
    # Only the subnets are needed; the security group id is not used here.
    subnet_ids, _ = get_or_create_vpc_resources(ec2_client)

    response = emr_client.run_job_flow(
        # Use the caller-supplied name so lookup-by-name and creation agree.
        Name=cluster_name,
        LogUri="s3://{}/{}".format(sagemaker_session.default_bucket(), "emr-test-logs"),
        ReleaseLabel="emr-6.3.0",
        Applications=[
            {"Name": "Hadoop"},
            {"Name": "Spark"},
        ],
        Instances={
            "InstanceGroups": [
                {
                    "Name": "Master nodes",
                    "Market": "ON_DEMAND",
                    "InstanceRole": "MASTER",
                    "InstanceType": "m4.large",
                    "InstanceCount": 1,
                }
            ],
            "KeepJobFlowAliveWhenNoSteps": True,
            "TerminationProtected": False,
            "Ec2SubnetId": subnet_ids[0],
        },
        VisibleToAllUsers=True,
        JobFlowRole="EMR_EC2_DefaultRole",
        ServiceRole="EMR_DefaultRole",
    )
    cluster_id = response["JobFlowId"]
    print("Created new cluster: {}".format(cluster_id))
    return cluster_id
355
+
356
+
284
357
def test_three_step_definition (
285
358
sagemaker_session ,
286
359
region_name ,
@@ -1149,82 +1222,30 @@ def test_two_step_lambda_pipeline_with_output_reference(
1149
1222
pass
1150
1223
1151
1224
1152
- def test_one_step_emr_pipeline (sagemaker_session , role , pipeline_name , region_name ):
1153
- instance_count = ParameterInteger (name = "InstanceCount" , default_value = 2 )
1154
-
1155
- emr_step_config = EMRStepConfig (
1156
- jar = "s3:/script-runner/script-runner.jar" ,
1157
- args = ["--arg_0" , "arg_0_value" ],
1158
- main_class = "com.my.main" ,
1159
- properties = [{"Key" : "Foo" , "Value" : "Foo_value" }, {"Key" : "Bar" , "Value" : "Bar_value" }],
1160
- )
1161
-
1162
- step_emr = EMRStep (
1163
- name = "emr-step" ,
1164
- cluster_id = "MyClusterID" ,
1165
- display_name = "emr_step" ,
1166
- description = "MyEMRStepDescription" ,
1167
- step_config = emr_step_config ,
1168
- )
1169
-
1170
- pipeline = Pipeline (
1171
- name = pipeline_name ,
1172
- parameters = [instance_count ],
1173
- steps = [step_emr ],
1174
- sagemaker_session = sagemaker_session ,
1175
- )
1176
-
1177
- try :
1178
- response = pipeline .create (role )
1179
- create_arn = response ["PipelineArn" ]
1180
-
1181
- execution = pipeline .start ()
1182
- response = execution .describe ()
1183
- assert response ["PipelineArn" ] == create_arn
1184
-
1185
- try :
1186
- execution .wait (delay = 60 , max_attempts = 10 )
1187
- except WaiterError :
1188
- pass
1189
-
1190
- execution_steps = execution .list_steps ()
1191
- assert len (execution_steps ) == 1
1192
- assert execution_steps [0 ]["StepName" ] == "emr-step"
1193
- finally :
1194
- try :
1195
- pipeline .delete ()
1196
- except Exception :
1197
- pass
1198
-
1199
-
1200
- def test_two_steps_emr_pipeline_without_nullable_config_fields (
1201
- sagemaker_session , role , pipeline_name , region_name
1225
+ def test_two_steps_emr_pipeline (
1226
+ sagemaker_session , role , pipeline_name , region_name , emr_cluster_id , emr_script_path
1202
1227
):
1203
1228
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 2 )
1204
1229
1205
- emr_step_config_1 = EMRStepConfig (
1206
- jar = "s3:/script-runner/script-runner_1.jar" ,
1207
- args = ["--arg_0" , "arg_0_value" ],
1208
- main_class = "com.my.main" ,
1209
- properties = [{"Key" : "Foo" , "Value" : "Foo_value" }, {"Key" : "Bar" , "Value" : "Bar_value" }],
1230
+ emr_step_config = EMRStepConfig (
1231
+ jar = "s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar" ,
1232
+ args = [emr_script_path ],
1210
1233
)
1211
1234
1212
1235
step_emr_1 = EMRStep (
1213
1236
name = "emr-step-1" ,
1214
- cluster_id = "MyClusterID" ,
1215
- display_name = "emr-step-1 " ,
1237
+ cluster_id = emr_cluster_id ,
1238
+ display_name = "emr_step_1 " ,
1216
1239
description = "MyEMRStepDescription" ,
1217
- step_config = emr_step_config_1 ,
1240
+ step_config = emr_step_config ,
1218
1241
)
1219
1242
1220
- emr_step_config_2 = EMRStepConfig (jar = "s3:/script-runner/script-runner_2.jar" )
1221
-
1222
1243
step_emr_2 = EMRStep (
1223
1244
name = "emr-step-2" ,
1224
- cluster_id = "MyClusterID" ,
1225
- display_name = "emr-step-2 " ,
1245
+ cluster_id = step_emr_1 . properties . ClusterId ,
1246
+ display_name = "emr_step_2 " ,
1226
1247
description = "MyEMRStepDescription" ,
1227
- step_config = emr_step_config_2 ,
1248
+ step_config = emr_step_config ,
1228
1249
)
1229
1250
1230
1251
pipeline = Pipeline (
@@ -1237,20 +1258,24 @@ def test_two_steps_emr_pipeline_without_nullable_config_fields(
1237
1258
try :
1238
1259
response = pipeline .create (role )
1239
1260
create_arn = response ["PipelineArn" ]
1261
+ assert re .match (
1262
+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
1263
+ )
1240
1264
1241
1265
execution = pipeline .start ()
1242
- response = execution .describe ()
1243
- assert response ["PipelineArn" ] == create_arn
1244
-
1245
1266
try :
1246
- execution .wait (delay = 60 , max_attempts = 10 )
1267
+ execution .wait (delay = 60 , max_attempts = 5 )
1247
1268
except WaiterError :
1248
1269
pass
1249
1270
1250
1271
execution_steps = execution .list_steps ()
1251
1272
assert len (execution_steps ) == 2
1252
1273
assert execution_steps [0 ]["StepName" ] == "emr-step-1"
1274
+ assert execution_steps [0 ].get ("FailureReason" , "" ) == ""
1275
+ assert execution_steps [0 ]["StepStatus" ] == "Succeeded"
1253
1276
assert execution_steps [1 ]["StepName" ] == "emr-step-2"
1277
+ assert execution_steps [1 ].get ("FailureReason" , "" ) == ""
1278
+ assert execution_steps [1 ]["StepStatus" ] == "Succeeded"
1254
1279
1255
1280
pipeline .parameters = [ParameterInteger (name = "InstanceCount" , default_value = 1 )]
1256
1281
response = pipeline .update (role )
0 commit comments