Skip to content

feature: EMR step runtime role support #3703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/sagemaker/workflow/emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ def to_request(self) -> RequestType:
"must be explicitly set to None."
)

# Error message template used when an EMRStep is given ``execution_role_arn``
# but no ``cluster_id``. ``{step_name}`` is filled in with ``str.format``.
# Adjacent string literals are concatenated verbatim, so every fragment that is
# followed by another word must end with an explicit space.
ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID = (
    "EMRStep {step_name} cannot have execution_role_arn "
    "without cluster_id. "
    "To use EMRStep with "
    "execution_role_arn, cluster_id "
    "must not be None."
)

# Error message template stating that an EMRStep needs either ``cluster_id`` or
# ``cluster_config``. ``{step_name}`` is a ``str.format`` placeholder for the
# name of the offending step.
ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG = (
"EMRStep {step_name} must have either cluster_id or cluster_config"
)
Expand Down Expand Up @@ -155,6 +163,7 @@ def __init__(
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
cache_config: CacheConfig = None,
cluster_config: Dict[str, Any] = None,
execution_role_arn: str = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class's `__init__` method is public-facing, which means that once we release this change, we will not be able to change the parameters here until the next major version bump, e.g. v2 -> v3.
So, have we considered why this is a good place to add this execution_role_arn arg? Or, in other words, why is it not appropriate to add execution_role_arn to step_config?

Also let's add @aoguo64 to review as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

step_config is mapped to StepConfig in EMR's AddJobFlowSteps, it's not a single-step configuration but applies to all steps to run in a single AddJobFlowSteps (though Pipelines does only run one job step per EMR step).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qidewenwhen are you suggesting using kwargs instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, then step_config is not a good place to go, in terms of execution_role_arn.

I'm just hoping to raise discussion and caution on this public facing interface change as it's not revertible.
And I do have a concern that if EMR side keeps adding new args, our EMRStep interface will keep expanding and eventually it may be hard to use, see example of the obsoleted RegisterModel interface which contains ~30 args: https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/workflow/step_collections.py#L57-L94

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So **kwargs seems to be able to resolve my concern above but usually it may not be a good practice to include **kwargs.
Thus I'd suggest to have thorough consideration on this step interface and have more eyes on it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Synced up offline: there's no better approach that avoids the issue called out above. We will introduce a step_arg parameter to group all EMRStep argument-related parameters in the future when the step interface grows too big.

):
"""Constructs an `EMRStep`.

Expand Down Expand Up @@ -185,7 +194,11 @@ def __init__(
https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html.
Note that if you want to use ``cluster_config``, then you have to set
``cluster_id`` as None.

execution_role_arn(str): The ARN of the runtime role assumed by this `EMRStep`. The
job submitted to your EMR cluster uses this role to access AWS resources. This
value is passed as ExecutionRoleArn to the AddJobFlowSteps request (an EMR request)
called on the cluster specified by ``cluster_id``, so you can only include this
field if ``cluster_id`` is not None.
"""
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)

Expand All @@ -198,9 +211,18 @@ def __init__(
if cluster_id is not None and cluster_config is not None:
raise ValueError(ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG.format(step_name=name))

if execution_role_arn is not None and cluster_id is None:
raise ValueError(
ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID.format(step_name=name)
)

if cluster_id is not None:
emr_step_args["ClusterId"] = cluster_id
root_property.__dict__["ClusterId"] = cluster_id

if execution_role_arn is not None:
emr_step_args["ExecutionRoleArn"] = execution_role_arn
root_property.__dict__["ExecutionRoleArn"] = execution_role_arn
elif cluster_config is not None:
self._validate_cluster_config(cluster_config, name)
emr_step_args["ClusterConfig"] = cluster_config
Expand Down
38 changes: 36 additions & 2 deletions tests/unit/sagemaker/workflow/test_emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@
ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS,
ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG,
ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG,
ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID,
)
from sagemaker.workflow.steps import CacheConfig
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
from sagemaker.workflow.parameters import ParameterString
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered


def test_emr_step_with_one_step_config(sagemaker_session):
@pytest.mark.parametrize("execution_role_arn", [None, "arn:aws:iam:000000000000:role/runtime-role"])
def test_emr_step_with_one_step_config(sagemaker_session, execution_role_arn):
emr_step_config = EMRStepConfig(
jar="s3:/script-runner/script-runner.jar",
args=["--arg_0", "arg_0_value"],
Expand All @@ -47,9 +49,11 @@ def test_emr_step_with_one_step_config(sagemaker_session):
step_config=emr_step_config,
depends_on=["TestStep"],
cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
execution_role_arn=execution_role_arn,
)
emr_step.add_depends_on(["SecondTestStep"])
assert emr_step.to_request() == {

expected_request = {
"Name": "MyEMRStep",
"Type": "EMR",
"Arguments": {
Expand All @@ -72,7 +76,16 @@ def test_emr_step_with_one_step_config(sagemaker_session):
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
}

if execution_role_arn is not None:
expected_request["Arguments"]["ExecutionRoleArn"] = execution_role_arn

assert emr_step.to_request() == expected_request
assert emr_step.properties.ClusterId == "MyClusterID"
assert (
emr_step.properties.ExecutionRoleArn == execution_role_arn
if execution_role_arn is not None
else True
)
assert emr_step.properties.ActionOnFailure.expr == {"Get": "Steps.MyEMRStep.ActionOnFailure"}
assert emr_step.properties.Config.Args.expr == {"Get": "Steps.MyEMRStep.Config.Args"}
assert emr_step.properties.Config.Jar.expr == {"Get": "Steps.MyEMRStep.Config.Jar"}
Expand Down Expand Up @@ -239,6 +252,27 @@ def test_emr_step_throws_exception_when_both_cluster_id_and_cluster_config_are_n
assert actual_error_msg == expected_error_msg


def test_emr_step_throws_exception_when_both_execution_role_arn_and_cluster_config_are_present():
    """Check that EMRStep rejects ``execution_role_arn`` when ``cluster_id`` is None.

    The step is constructed with ``cluster_id=None``, a ``cluster_config`` and an
    ``execution_role_arn``. The test expects a ``ValueError`` whose message equals
    ``ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID`` formatted with the step name.
    """
    with pytest.raises(ValueError) as exceptionInfo:
        EMRStep(
            name=g_emr_step_name,
            display_name="MyEMRStep",
            description="MyEMRStepDescription",
            step_config=g_emr_step_config,
            cluster_id=None,
            cluster_config=g_cluster_config,
            depends_on=["TestStep"],
            cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
            # IAM ARNs have an empty region field, hence the double colon
            # between the service name and the account id.
            execution_role_arn="arn:aws:iam::000000000000:role/some-role",
        )
    expected_error_msg = ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID.format(
        step_name=g_emr_step_name
    )
    actual_error_msg = exceptionInfo.value.args[0]

    assert actual_error_msg == expected_error_msg


def test_emr_step_with_valid_cluster_config():
emr_step = EMRStep(
name=g_emr_step_name,
Expand Down