Skip to content

feature: support cluster lifecycle management for Sagemaker EMR step #3604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/workflows/pipelines/sagemaker.workflow.pipelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,8 @@ Steps

.. autoclass:: sagemaker.workflow.fail_step.FailStep

.. autoclass:: sagemaker.workflow.emr_step.EMRStepConfig

.. autoclass:: sagemaker.workflow.emr_step.EMRStep

.. autoclass:: sagemaker.workflow.automl_step.AutoMLStep
126 changes: 119 additions & 7 deletions src/sagemaker/workflow/emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""The step definitions for workflow."""
from __future__ import absolute_import

from typing import List, Union, Optional
from typing import Any, Dict, List, Union, Optional

from sagemaker.workflow.entities import (
RequestType,
Expand All @@ -33,7 +33,8 @@ def __init__(
):
"""Create a definition for input data used by an EMR cluster(job flow) step.

See AWS documentation on the ``StepConfig`` API for more details on the parameters.
See AWS documentation for more information about the `StepConfig
<https://docs.aws.amazon.com/emr/latest/APIReference/API_StepConfig.html>`_ API parameters.

Args:
args(List[str]):
Expand Down Expand Up @@ -61,9 +62,89 @@ def to_request(self) -> RequestType:
return config


# Keys of the EMR ``RunJobFlow`` request that the cluster_config validation inspects.
INSTANCES = "Instances"
INSTANCEGROUPS = "InstanceGroups"
INSTANCEFLEETS = "InstanceFleets"

# Error message templates. Each one takes a ``step_name`` keyword via ``str.format``.

# Raised when cluster_config carries a top-level key that must not be supplied.
ERR_STR_WITH_NAME_AUTO_TERMINATION_OR_STEPS = (
    "In EMRStep {step_name}, cluster_config "
    "should not contain any of the Name, "
    "AutoTerminationPolicy and/or Steps."
)

# Raised when cluster_config has no ``Instances`` entry.
ERR_STR_WITHOUT_INSTANCE = "In EMRStep {step_name}, cluster_config must contain " + INSTANCES + "."

# Raised when ``Instances`` carries a key that must not be supplied.
ERR_STR_WITH_KEEPJOBFLOW_OR_TERMINATIONPROTECTED = (
    "In EMRStep {step_name}, " + INSTANCES + " should not contain "
    "KeepJobFlowAliveWhenNoSteps or "
    "TerminationProtected."
)

# Raised when ``Instances`` has both, or neither, of InstanceGroups / InstanceFleets.
ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS = (
    "In EMRStep {step_name}, "
    + INSTANCES
    + " should contain either "
    + INSTANCEGROUPS
    + " or "
    + INSTANCEFLEETS
    + "."
)

# Raised when both cluster_id and cluster_config are given.
# The adjacent literals are concatenated verbatim, so each fragment must carry
# its own trailing space; without them the message ran words together.
ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG = (
    "EMRStep {step_name} can not have both cluster_id "
    "and cluster_config. "
    "To use EMRStep with "
    "cluster_config, cluster_id "
    "must be explicitly set to None."
)

# Raised when neither cluster_id nor cluster_config is given.
ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG = (
    "EMRStep {step_name} must have either cluster_id or cluster_config"
)


class EMRStep(Step):
"""EMR step for workflow."""

def _validate_cluster_config(self, cluster_config, step_name):
    """Check a user-supplied ``cluster_config`` before it is attached to the step.

    Args:
        cluster_config(Dict[str, Any]): The cluster configuration supplied by
            the user.
        step_name(str): The name of the EMR step; interpolated into error messages.

    Raises:
        ValueError: If ``cluster_config`` contains ``Name``,
            ``AutoTerminationPolicy`` or ``Steps``; if it has no ``Instances``
            entry; if ``Instances`` contains ``KeepJobFlowAliveWhenNoSteps`` or
            ``TerminationProtected``; or if ``Instances`` contains both, or
            neither, of ``InstanceGroups`` and ``InstanceFleets``.
    """
    disallowed_cluster_keys = ("Name", "AutoTerminationPolicy", "Steps")
    if any(key in cluster_config for key in disallowed_cluster_keys):
        message = ERR_STR_WITH_NAME_AUTO_TERMINATION_OR_STEPS.format(step_name=step_name)
        raise ValueError(message)

    if INSTANCES not in cluster_config:
        raise ValueError(ERR_STR_WITHOUT_INSTANCE.format(step_name=step_name))

    instances = cluster_config[INSTANCES]

    disallowed_instance_keys = ("KeepJobFlowAliveWhenNoSteps", "TerminationProtected")
    if any(key in instances for key in disallowed_instance_keys):
        message = ERR_STR_WITH_KEEPJOBFLOW_OR_TERMINATIONPROTECTED.format(step_name=step_name)
        raise ValueError(message)

    # Exactly one of the two instance layouts must be present: equal flags
    # mean either both were given or neither was.
    has_groups = INSTANCEGROUPS in instances
    has_fleets = INSTANCEFLEETS in instances
    if has_groups == has_fleets:
        message = ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS.format(
            step_name=step_name
        )
        raise ValueError(message)

def __init__(
self,
name: str,
Expand All @@ -73,8 +154,9 @@ def __init__(
step_config: EMRStepConfig,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
cache_config: CacheConfig = None,
cluster_config: Dict[str, Any] = None,
):
"""Constructs a EMRStep.
"""Constructs an `EMRStep`.

Args:
name(str): The name of the EMR step.
Expand All @@ -86,16 +168,46 @@ def __init__(
names or `Step` instances or `StepCollection` instances that this `EMRStep`
depends on.
cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
cluster_config(Dict[str, Any]): The recipe of the
EMR cluster, passed as a dictionary.
The elements are defined in the request syntax for `RunJobFlow`.
However, the following elements are not recognized as part of the cluster
configuration and you should not include them in the dictionary:

* ``cluster_config[Name]``
* ``cluster_config[Steps]``
* ``cluster_config[AutoTerminationPolicy]``
* ``cluster_config[Instances][KeepJobFlowAliveWhenNoSteps]``
* ``cluster_config[Instances][TerminationProtected]``

For more information about the fields you can include in your cluster
configuration, see
https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html.
Note that if you want to use ``cluster_config``, then you have to set
``cluster_id`` as None.

"""
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)

emr_step_args = {"ClusterId": cluster_id, "StepConfig": step_config.to_request()}
emr_step_args = {"StepConfig": step_config.to_request()}
root_property = Properties(step_name=name, shape_name="Step", service_name="emr")

if cluster_id is None and cluster_config is None:
raise ValueError(ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG.format(step_name=name))

if cluster_id is not None and cluster_config is not None:
raise ValueError(ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG.format(step_name=name))

if cluster_id is not None:
emr_step_args["ClusterId"] = cluster_id
root_property.__dict__["ClusterId"] = cluster_id
elif cluster_config is not None:
self._validate_cluster_config(cluster_config, name)
emr_step_args["ClusterConfig"] = cluster_config
root_property.__dict__["ClusterConfig"] = cluster_config

self.args = emr_step_args
self.cache_config = cache_config

root_property = Properties(step_name=name, shape_name="Step", service_name="emr")
root_property.__dict__["ClusterId"] = cluster_id
self._properties = root_property

@property
Expand Down
54 changes: 54 additions & 0 deletions tests/integ/sagemaker/workflow/test_emr_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,57 @@ def test_two_steps_emr_pipeline(sagemaker_session, role, pipeline_name, region_n
pipeline.delete()
except Exception:
pass


def test_emr_with_cluster_config(sagemaker_session, role, pipeline_name, region_name):
    """Create a pipeline whose single EMRStep is defined by ``cluster_config``.

    The test passes when ``pipeline.create`` returns a pipeline ARN matching the
    expected region and pipeline name. The pipeline is deleted afterwards on a
    best-effort basis.
    """
    master_instance_group = {
        "Name": "Master Instance Group",
        "InstanceRole": "MASTER",
        "InstanceCount": 1,
        "InstanceType": "m1.small",
        "Market": "ON_DEMAND",
    }
    cluster_config = {
        "Instances": {
            "InstanceGroups": [master_instance_group],
            "InstanceCount": 1,
            "HadoopVersion": "MyHadoopVersion",
        },
        "AmiVersion": "3.8.0",
        "AdditionalInfo": "MyAdditionalInfo",
    }

    step_config = EMRStepConfig(
        jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
        args=["dummy_emr_script_path"],
    )

    # cluster_id is passed explicitly as None so the step is driven by cluster_config.
    emr_step = EMRStep(
        name="MyEMRStep-name",
        display_name="MyEMRStep-display_name",
        description="MyEMRStepDescription",
        cluster_id=None,
        step_config=step_config,
        cluster_config=cluster_config,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        steps=[emr_step],
        sagemaker_session=sagemaker_session,
    )

    try:
        create_arn = pipeline.create(role)["PipelineArn"]
        expected_arn = rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}"
        assert re.match(expected_arn, create_arn)
    finally:
        # Cleanup is best-effort: a failed delete must not mask the test result.
        try:
            pipeline.delete()
        except Exception:
            pass
Loading