Skip to content

feature: Add EMRStep support in Sagemaker pipeline #2848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions src/sagemaker/workflow/emr_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""
from __future__ import absolute_import

from typing import List

from sagemaker.workflow.entities import (
RequestType,
)
from sagemaker.workflow.properties import (
Properties,
)
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig


class EMRStepConfig:
    """Holds the Hadoop JAR settings for a single EMR step."""

    def __init__(
        self, jar, args: List[str] = None, main_class: str = None, properties: List[dict] = None
    ):
        """Record the Hadoop JAR settings that describe one EMR cluster (job flow) step.

        See AWS documentation on the ``StepConfig`` API for more details on the parameters.

        Args:
            jar(str): A path to a JAR file run during the step.
            args(List[str]):
                A list of command line arguments passed to
                the JAR file's main function when executed.
            main_class(str): The name of the main class in the specified Java file.
            properties(List(dict)): A list of key-value pairs that are set when the step runs.
        """
        self.jar = jar
        self.args = args
        self.main_class = main_class
        self.properties = properties

    def to_request(self) -> RequestType:
        """Build the request dict for this config.

        ``Jar`` is always present under ``HadoopJarStep``; ``Args``, ``MainClass``
        and ``Properties`` are included only when they are not ``None``.
        """
        hadoop_jar_step = {"Jar": self.jar}

        # Request key -> stored value, listed in the order the keys are emitted.
        optional_fields = (
            ("Args", self.args),
            ("MainClass", self.main_class),
            ("Properties", self.properties),
        )
        hadoop_jar_step.update(
            (request_key, value) for request_key, value in optional_fields if value is not None
        )

        return {"HadoopJarStep": hadoop_jar_step}


class EMRStep(Step):
    """EMR step for workflow."""

    def __init__(
        self,
        name: str,
        display_name: str,
        description: str,
        cluster_id: str,
        step_config: EMRStepConfig,
        depends_on: List[str] = None,
        cache_config: CacheConfig = None,
    ):
        """Constructs an EMRStep.

        Args:
            name(str): The name of the EMR step.
            display_name(str): The display name of the EMR step.
            description(str): The description of the EMR step.
            cluster_id(str): The ID of the running EMR cluster.
            step_config(EMRStepConfig): One StepConfig to be executed by the job flow.
            depends_on(List[str]):
                A list of step names this `sagemaker.workflow.emr_step.EMRStep` depends on
            cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.

        """
        super().__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)

        # The step arguments are fixed at construction time: the target cluster
        # plus the serialized Hadoop JAR configuration.
        self.args = {"ClusterId": cluster_id, "StepConfig": step_config.to_request()}
        self.cache_config = cache_config

        root_property = Properties(path=f"Steps.{name}", shape_name="Step", service_name="emr")
        # Written straight into the instance ``__dict__`` so that
        # ``step.properties.ClusterId`` yields the cluster id given to this step.
        root_property.__dict__["ClusterId"] = cluster_id
        self._properties = root_property

    @property
    def arguments(self) -> RequestType:
        """The arguments dict that is used to call `AddJobFlowSteps`.

        NOTE: The AddJobFlowSteps request is not quite the args list that workflow needs.
        The Name attribute in AddJobFlowSteps cannot be passed; it will be set during runtime.
        In addition to that, we will also need to include emr job inputs and output config.
        """
        return self.args

    @property
    def properties(self) -> Properties:
        """A Properties object representing the EMR DescribeStepResponse model"""
        return self._properties

    def to_request(self) -> RequestType:
        """Return the parent request dict, merged with the cache config when one is set."""
        request_dict = super().to_request()
        if self.cache_config:
            request_dict.update(self.cache_config.config)
        return request_dict
59 changes: 39 additions & 20 deletions src/sagemaker/workflow/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,24 @@


class PropertiesMeta(type):
"""Load an internal shapes attribute from the botocore sagemaker service model."""
"""Load an internal shapes attribute from the botocore service model

_shapes = None
for sagemaker and emr service.
"""

_shapes_map = dict()
_primitive_types = {"string", "boolean", "integer", "float"}

def __new__(mcs, *args, **kwargs):
"""Loads up the shapes from the botocore sagemaker service model."""
if mcs._shapes is None:
"""Loads up the shapes from the botocore service model."""
if len(mcs._shapes_map.keys()) == 0:
loader = botocore.loaders.Loader()
model = loader.load_service_model("sagemaker", "service-2")
mcs._shapes = model["shapes"]

sagemaker_model = loader.load_service_model("sagemaker", "service-2")
emr_model = loader.load_service_model("emr", "service-2")
mcs._shapes_map["sagemaker"] = sagemaker_model["shapes"]
mcs._shapes_map["emr"] = emr_model["shapes"]

return super().__new__(mcs, *args, **kwargs)


Expand All @@ -45,32 +52,41 @@ def __init__(
path: str,
shape_name: str = None,
shape_names: List[str] = None,
service_name: str = "sagemaker",
):
"""Create a Properties instance representing the given shape.

Args:
path (str): The parent path of the Properties instance.
shape_name (str): The botocore sagemaker service model shape name.
shape_names (str): A List of the botocore sagemaker service model shape name.
shape_name (str): The botocore service model shape name.
shape_names (str): A List of the botocore service model shape name.
"""
self._path = path
shape_names = [] if shape_names is None else shape_names
self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names

shapes = Properties._shapes_map.get(service_name, {})

for name in self._shape_names:
shape = Properties._shapes.get(name, {})
shape = shapes.get(name, {})
shape_type = shape.get("type")
if shape_type in Properties._primitive_types:
self.__str__ = name
elif shape_type == "structure":
members = shape["members"]
for key, info in members.items():
if Properties._shapes.get(info["shape"], {}).get("type") == "list":
self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
elif Properties._shapes.get(info["shape"], {}).get("type") == "map":
self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"])
if shapes.get(info["shape"], {}).get("type") == "list":
self.__dict__[key] = PropertiesList(
f"{path}.{key}", info["shape"], service_name
)
elif shapes.get(info["shape"], {}).get("type") == "map":
self.__dict__[key] = PropertiesMap(
f"{path}.{key}", info["shape"], service_name
)
else:
self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])
self.__dict__[key] = Properties(
f"{path}.{key}", info["shape"], service_name=service_name
)

@property
def expr(self):
Expand All @@ -81,16 +97,17 @@ def expr(self):
class PropertiesList(Properties):
"""PropertiesList for use in workflow expressions."""

def __init__(self, path: str, shape_name: str = None):
def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"):
"""Create a PropertiesList instance representing the given shape.

Args:
path (str): The parent path of the PropertiesList instance.
shape_name (str): The botocore sagemaker service model shape name.
root_shape_name (str): The botocore sagemaker service model shape name.
shape_name (str): The botocore service model shape name.
service_name (str): The botocore service name.
"""
super(PropertiesList, self).__init__(path, shape_name)
self.shape_name = shape_name
self.service_name = service_name
self._items: Dict[Union[int, str], Properties] = dict()

def __getitem__(self, item: Union[int, str]):
Expand All @@ -100,7 +117,7 @@ def __getitem__(self, item: Union[int, str]):
item (Union[int, str]): The index of the item in sequence.
"""
if item not in self._items.keys():
shape = Properties._shapes.get(self.shape_name)
shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name)
member = shape["member"]["shape"]
if isinstance(item, str):
property_item = Properties(f"{self._path}['{item}']", member)
Expand All @@ -114,15 +131,17 @@ def __getitem__(self, item: Union[int, str]):
class PropertiesMap(Properties):
"""PropertiesMap for use in workflow expressions."""

def __init__(self, path: str, shape_name: str = None):
def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"):
"""Create a PropertiesMap instance representing the given shape.

Args:
path (str): The parent path of the PropertiesMap instance.
shape_name (str): The botocore sagemaker service model shape name.
service_name (str): The botocore service name.
"""
super(PropertiesMap, self).__init__(path, shape_name)
self.shape_name = shape_name
self.service_name = service_name
self._items: Dict[Union[int, str], Properties] = dict()

def __getitem__(self, item: Union[int, str]):
Expand All @@ -132,7 +151,7 @@ def __getitem__(self, item: Union[int, str]):
item (Union[int, str]): The index of the item in sequence.
"""
if item not in self._items.keys():
shape = Properties._shapes.get(self.shape_name)
shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name)
member = shape["value"]["shape"]
if isinstance(item, str):
property_item = Properties(f"{self._path}['{item}']", member)
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/workflow/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
LAMBDA = "Lambda"
QUALITY_CHECK = "QualityCheck"
CLARIFY_CHECK = "ClarifyCheck"
EMR = "EMR"


@attr.s
Expand Down
2 changes: 2 additions & 0 deletions tests/data/workflow/emr-script.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
echo "This is emr test script..."
sleep 15
45 changes: 45 additions & 0 deletions tests/integ/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum
from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
from sagemaker.workflow.execution_variables import ExecutionVariables
Expand Down Expand Up @@ -1148,6 +1149,50 @@ def test_two_step_lambda_pipeline_with_output_reference(
pass


def test_two_steps_emr_pipeline(sagemaker_session, role, pipeline_name, region_name):
    """Create a pipeline of two EMR steps and check the ARN returned by ``create``.

    The second step takes its cluster id from the first step's ``ClusterId``
    property. The pipeline is deleted afterwards on a best-effort basis.
    """
    shared_step_config = EMRStepConfig(
        jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
        args=["dummy_emr_script_path"],
    )

    first_step = EMRStep(
        name="emr-step-1",
        display_name="emr_step_1",
        description="MyEMRStepDescription",
        cluster_id="j-1YONHTCP3YZKC",
        step_config=shared_step_config,
    )
    second_step = EMRStep(
        name="emr-step-2",
        display_name="emr_step_2",
        description="MyEMRStepDescription",
        cluster_id=first_step.properties.ClusterId,
        step_config=shared_step_config,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[ParameterInteger(name="InstanceCount", default_value=2)],
        steps=[first_step, second_step],
        sagemaker_session=sagemaker_session,
    )

    # Pattern the created pipeline's ARN must match for this region and name.
    expected_arn = fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}"

    try:
        create_arn = pipeline.create(role)["PipelineArn"]
        assert re.match(expected_arn, create_arn)
    finally:
        # Cleanup is best-effort: a failed delete must not mask the test result.
        try:
            pipeline.delete()
        except Exception:
            pass


def test_conditional_pytorch_training_model_registration(
sagemaker_session,
role,
Expand Down
Loading