Skip to content

Commit 2f52467

Browse files
chenxyEthanShouhanCheng
authored and committed
feature: Add EMRStep support in Sagemaker pipeline
1 parent 8532613 commit 2f52467

File tree

3 files changed

+105
-72
lines changed

3 files changed

+105
-72
lines changed

src/sagemaker/workflow/properties.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,13 @@ def __init__(
7676
members = shape["members"]
7777
for key, info in members.items():
7878
if shapes.get(info["shape"], {}).get("type") == "list":
79-
self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
80-
elif Properties._shapes.get(info["shape"], {}).get("type") == "map":
81-
self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"])
79+
self.__dict__[key] = PropertiesList(
80+
f"{path}.{key}", info["shape"], service_name
81+
)
82+
elif shapes.get(info["shape"], {}).get("type") == "map":
83+
self.__dict__[key] = PropertiesMap(
84+
f"{path}.{key}", info["shape"], service_name
85+
)
8286
else:
8387
self.__dict__[key] = Properties(
8488
f"{path}.{key}", info["shape"], service_name=service_name
@@ -127,15 +131,17 @@ def __getitem__(self, item: Union[int, str]):
127131
class PropertiesMap(Properties):
128132
"""PropertiesMap for use in workflow expressions."""
129133

130-
def __init__(self, path: str, shape_name: str = None):
134+
def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"):
131135
"""Create a PropertiesMap instance representing the given shape.
132136
133137
Args:
134138
path (str): The parent path of the PropertiesMap instance.
135139
shape_name (str): The botocore sagemaker service model shape name.
140+
service_name (str): The botocore service name.
136141
"""
137142
super(PropertiesMap, self).__init__(path, shape_name)
138143
self.shape_name = shape_name
144+
self.service_name = service_name
139145
self._items: Dict[Union[int, str], Properties] = dict()
140146

141147
def __getitem__(self, item: Union[int, str]):
@@ -145,7 +151,7 @@ def __getitem__(self, item: Union[int, str]):
145151
item (Union[int, str]): The index of the item in sequence.
146152
"""
147153
if item not in self._items.keys():
148-
shape = Properties._shapes.get(self.shape_name)
154+
shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name)
149155
member = shape["value"]["shape"]
150156
if isinstance(item, str):
151157
property_item = Properties(f"{self._path}['{item}']", member)

tests/data/workflow/emr-script.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
echo "This is emr test script..."
2+
sleep 15

tests/integ/test_workflow.py

Lines changed: 92 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,11 @@
9595
from sagemaker.feature_store.feature_group import FeatureGroup, FeatureDefinition, FeatureTypeEnum
9696
from tests.integ import DATA_DIR
9797
from tests.integ.kms_utils import get_or_create_kms_key
98+
<<<<<<< HEAD
9899
from tests.integ.retry import retries
100+
=======
101+
from tests.integ.vpc_test_utils import get_or_create_vpc_resources
102+
>>>>>>> 48cd0d8f (feature: Add EMRStep support in Sagemaker pipeline)
99103

100104

101105
def ordered(obj):
@@ -281,6 +285,75 @@ def build_jar():
281285
subprocess.run(["rm", os.path.join(jar_file_path, java_file_path, "HelloJavaSparkApp.class")])
282286

283287

288+
@pytest.fixture(scope="module")
def emr_script_path(sagemaker_session):
    """Upload the EMR test shell script and return the upload location.

    The script lives at ``<DATA_DIR>/workflow/emr-script.sh`` and is uploaded
    under the ``integ-test-data/workflow`` key prefix.

    Args:
        sagemaker_session: Session object exposing ``upload_data``.

    Returns:
        Whatever ``sagemaker_session.upload_data`` returns (presumably the S3
        URI of the uploaded script -- confirm against the session API).
    """
    local_script = os.path.join(DATA_DIR, "workflow", "emr-script.sh")
    return sagemaker_session.upload_data(
        path=local_script,
        key_prefix="integ-test-data/workflow",
    )
295+
296+
297+
@pytest.fixture(scope="module")
def emr_cluster_id(sagemaker_session, role):
    """Provide the id of an EMR cluster for the EMR pipeline tests.

    Looks for an existing cluster whose name starts with
    ``emr-step-test-cluster`` and, when none is found, creates a new one.

    Args:
        sagemaker_session: Session object whose ``boto_session`` is used to
            build the EMR client.
        role: Unused here; kept so the fixture signature stays unchanged.

    Returns:
        The id of the reused or newly created cluster.
    """
    emr_client = sagemaker_session.boto_session.client("emr")
    cluster_name = "emr-step-test-cluster"
    cluster_id = get_existing_emr_cluster_id(emr_client, cluster_name)

    if cluster_id is None:
        # Keep the id returned by the creation helper; discarding it made the
        # fixture hand ``None`` to the tests whenever a cluster had to be created.
        cluster_id = create_new_emr_cluster(sagemaker_session, emr_client, cluster_name)
    return cluster_id
306+
307+
308+
def get_existing_emr_cluster_id(emr_client, cluster_name):
    """Return the id of a RUNNING/WAITING cluster whose name starts with ``cluster_name``.

    Every page of ``list_clusters`` results is inspected by following the
    ``Marker`` token in the response, so a matching cluster is found even when
    it is not on the first page.

    Args:
        emr_client: Client object exposing ``list_clusters``.
        cluster_name (str): Name prefix to match against each cluster's ``Name``.

    Returns:
        The ``Id`` of the first matching cluster, or ``None`` when no cluster
        matches. Errors raised by the client propagate to the caller.
    """
    request = {"ClusterStates": ["RUNNING", "WAITING"]}
    while True:
        response = emr_client.list_clusters(**request)
        for cluster in response["Clusters"]:
            if cluster["Name"].startswith(cluster_name):
                cluster_id = cluster["Id"]
                print("Using existing cluster: {}".format(cluster_id))
                return cluster_id

        # A missing or empty marker means the last page has been reached.
        marker = response.get("Marker")
        if not marker:
            return None
        request["Marker"] = marker
318+
319+
320+
def create_new_emr_cluster(sagemaker_session, emr_client, cluster_name):
    """Start a single-master EMR cluster for the tests and return its id.

    Args:
        sagemaker_session: Session object providing ``boto_session`` (for the
            EC2 client) and ``default_bucket()`` (for the log location).
        emr_client: Client object exposing ``run_job_flow``.
        cluster_name (str): Name given to the new cluster.

    Returns:
        The ``JobFlowId`` from the ``run_job_flow`` response. Errors raised by
        the clients propagate to the caller.
    """
    ec2_client = sagemaker_session.boto_session.client("ec2")
    # Only the subnet ids are needed; the second element (named
    # security_group_id in the original unpacking) was never used.
    subnet_ids, _ = get_or_create_vpc_resources(ec2_client)

    response = emr_client.run_job_flow(
        # Honour the caller-supplied name instead of a hard-coded duplicate.
        Name=cluster_name,
        LogUri="s3://{}/{}".format(sagemaker_session.default_bucket(), "emr-test-logs"),
        ReleaseLabel="emr-6.3.0",
        Applications=[
            {"Name": "Hadoop"},
            {"Name": "Spark"},
        ],
        Instances={
            "InstanceGroups": [
                {
                    "Name": "Master nodes",
                    "Market": "ON_DEMAND",
                    "InstanceRole": "MASTER",
                    "InstanceType": "m4.large",
                    "InstanceCount": 1,
                }
            ],
            "KeepJobFlowAliveWhenNoSteps": True,
            "TerminationProtected": False,
            "Ec2SubnetId": subnet_ids[0],
        },
        VisibleToAllUsers=True,
        JobFlowRole="EMR_EC2_DefaultRole",
        ServiceRole="EMR_DefaultRole",
    )
    cluster_id = response["JobFlowId"]
    print("Created new cluster: {}".format(cluster_id))
    return cluster_id
355+
356+
284357
def test_three_step_definition(
285358
sagemaker_session,
286359
region_name,
@@ -1149,82 +1222,30 @@ def test_two_step_lambda_pipeline_with_output_reference(
11491222
pass
11501223

11511224

1152-
def test_one_step_emr_pipeline(sagemaker_session, role, pipeline_name, region_name):
1153-
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
1154-
1155-
emr_step_config = EMRStepConfig(
1156-
jar="s3:/script-runner/script-runner.jar",
1157-
args=["--arg_0", "arg_0_value"],
1158-
main_class="com.my.main",
1159-
properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}],
1160-
)
1161-
1162-
step_emr = EMRStep(
1163-
name="emr-step",
1164-
cluster_id="MyClusterID",
1165-
display_name="emr_step",
1166-
description="MyEMRStepDescription",
1167-
step_config=emr_step_config,
1168-
)
1169-
1170-
pipeline = Pipeline(
1171-
name=pipeline_name,
1172-
parameters=[instance_count],
1173-
steps=[step_emr],
1174-
sagemaker_session=sagemaker_session,
1175-
)
1176-
1177-
try:
1178-
response = pipeline.create(role)
1179-
create_arn = response["PipelineArn"]
1180-
1181-
execution = pipeline.start()
1182-
response = execution.describe()
1183-
assert response["PipelineArn"] == create_arn
1184-
1185-
try:
1186-
execution.wait(delay=60, max_attempts=10)
1187-
except WaiterError:
1188-
pass
1189-
1190-
execution_steps = execution.list_steps()
1191-
assert len(execution_steps) == 1
1192-
assert execution_steps[0]["StepName"] == "emr-step"
1193-
finally:
1194-
try:
1195-
pipeline.delete()
1196-
except Exception:
1197-
pass
1198-
1199-
1200-
def test_two_steps_emr_pipeline_without_nullable_config_fields(
1201-
sagemaker_session, role, pipeline_name, region_name
1225+
def test_two_steps_emr_pipeline(
1226+
sagemaker_session, role, pipeline_name, region_name, emr_cluster_id, emr_script_path
12021227
):
12031228
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
12041229

1205-
emr_step_config_1 = EMRStepConfig(
1206-
jar="s3:/script-runner/script-runner_1.jar",
1207-
args=["--arg_0", "arg_0_value"],
1208-
main_class="com.my.main",
1209-
properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}],
1230+
emr_step_config = EMRStepConfig(
1231+
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
1232+
args=[emr_script_path],
12101233
)
12111234

12121235
step_emr_1 = EMRStep(
12131236
name="emr-step-1",
1214-
cluster_id="MyClusterID",
1215-
display_name="emr-step-1",
1237+
cluster_id=emr_cluster_id,
1238+
display_name="emr_step_1",
12161239
description="MyEMRStepDescription",
1217-
step_config=emr_step_config_1,
1240+
step_config=emr_step_config,
12181241
)
12191242

1220-
emr_step_config_2 = EMRStepConfig(jar="s3:/script-runner/script-runner_2.jar")
1221-
12221243
step_emr_2 = EMRStep(
12231244
name="emr-step-2",
1224-
cluster_id="MyClusterID",
1225-
display_name="emr-step-2",
1245+
cluster_id=step_emr_1.properties.ClusterId,
1246+
display_name="emr_step_2",
12261247
description="MyEMRStepDescription",
1227-
step_config=emr_step_config_2,
1248+
step_config=emr_step_config,
12281249
)
12291250

12301251
pipeline = Pipeline(
@@ -1237,20 +1258,24 @@ def test_two_steps_emr_pipeline_without_nullable_config_fields(
12371258
try:
12381259
response = pipeline.create(role)
12391260
create_arn = response["PipelineArn"]
1261+
assert re.match(
1262+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
1263+
)
12401264

12411265
execution = pipeline.start()
1242-
response = execution.describe()
1243-
assert response["PipelineArn"] == create_arn
1244-
12451266
try:
1246-
execution.wait(delay=60, max_attempts=10)
1267+
execution.wait(delay=60, max_attempts=5)
12471268
except WaiterError:
12481269
pass
12491270

12501271
execution_steps = execution.list_steps()
12511272
assert len(execution_steps) == 2
12521273
assert execution_steps[0]["StepName"] == "emr-step-1"
1274+
assert execution_steps[0].get("FailureReason", "") == ""
1275+
assert execution_steps[0]["StepStatus"] == "Succeeded"
12531276
assert execution_steps[1]["StepName"] == "emr-step-2"
1277+
assert execution_steps[1].get("FailureReason", "") == ""
1278+
assert execution_steps[1]["StepStatus"] == "Succeeded"
12541279

12551280
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
12561281
response = pipeline.update(role)

0 commit comments

Comments
 (0)