feat: Support selective pipeline execution for function step #4372

Merged: 1 commit, Jan 24, 2024
27 changes: 10 additions & 17 deletions src/sagemaker/remote_function/core/pipeline_variables.py
@@ -19,6 +19,7 @@

from sagemaker.s3 import s3_path_join
from sagemaker.remote_function.core.serialization import deserialize_obj_from_s3
from sagemaker.workflow.step_outputs import get_step


@dataclass
@@ -92,7 +93,7 @@ class _S3BaseUriIdentifier:
class _DelayedReturn:
"""Delayed return from a function."""

uri: List[Union[str, _Parameter, _ExecutionVariable]]
uri: Union[_Properties, List[Union[str, _Parameter, _ExecutionVariable]]]
reference_path: Tuple = field(default_factory=tuple)


@@ -164,6 +165,7 @@ def __init__(
self,
delayed_returns: List[_DelayedReturn],
hmac_key: str,
properties_resolver: _PropertiesResolver,
parameter_resolver: _ParameterResolver,
execution_variable_resolver: _ExecutionVariableResolver,
s3_base_uri: str,
@@ -174,6 +176,7 @@
Args:
delayed_returns: list of delayed returns to resolve.
hmac_key: key used to encrypt serialized and deserialized function and arguments.
properties_resolver: resolver used to resolve step properties.
parameter_resolver: resolver used to resolve pipeline parameters.
execution_variable_resolver: resolver used to resolve execution variables.
s3_base_uri (str): the s3 base uri of the function step that
@@ -184,6 +187,7 @@
self._s3_base_uri = s3_base_uri
self._parameter_resolver = parameter_resolver
self._execution_variable_resolver = execution_variable_resolver
self._properties_resolver = properties_resolver
# different delayed returns can have the same uri, so we need to dedupe
uris = {
self._resolve_delayed_return_uri(delayed_return) for delayed_return in delayed_returns
@@ -214,7 +218,10 @@ def resolve(self, delayed_return: _DelayedReturn) -> Any:

def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
"""Resolve the s3 uri of the delayed return."""
if isinstance(delayed_return.uri, _Properties):
return self._properties_resolver.resolve(delayed_return.uri)

# Keep the old resolution logic below for backward compatibility
uri = []
for component in delayed_return.uri:
if isinstance(component, _Parameter):
@@ -274,6 +281,7 @@ def resolve_pipeline_variables(
delayed_return_resolver = _DelayedReturnResolver(
delayed_returns=delayed_returns,
hmac_key=hmac_key,
properties_resolver=properties_resolver,
parameter_resolver=parameter_resolver,
execution_variable_resolver=execution_variable_resolver,
s3_base_uri=s3_base_uri,
@@ -325,27 +333,12 @@ def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict

from sagemaker.workflow.entities import PipelineVariable

from sagemaker.workflow.execution_variables import ExecutionVariables

from sagemaker.workflow.function_step import DelayedReturn

# Notes:
# 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
# when defining function steps. After step-level arg serialization,
# it's hard to update the s3_base_uri in pipeline compile time.
# Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
# 2. For saying s3_root_uri is unknown, it's because when defining function steps,
# the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
# should be retrieved from the pipeline's sagemaker_session.
def convert(arg):
if isinstance(arg, DelayedReturn):
return _DelayedReturn(
uri=[
_S3BaseUriIdentifier(),
ExecutionVariables.PIPELINE_EXECUTION_ID._pickleable,
arg._step.name,
"results",
],
uri=get_step(arg)._properties.OutputDataConfig.S3OutputPath._pickleable,
reference_path=arg._reference_path,
)

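Reviewer context: the net effect of this file's changes is that a `_DelayedReturn.uri` can now be a single `_Properties` reference to the step's `OutputDataConfig.S3OutputPath`, while the old list-of-components form is still resolved for backward compatibility. A minimal self-contained sketch of that dispatch (`StubProperties` and the literal URIs are illustrative stand-ins, not SDK types):

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class StubProperties:
    """Illustrative stand-in for the SDK-internal _Properties reference."""

    resolved_value: str  # what the properties resolver would return at runtime


def resolve_delayed_return_uri(uri: Union[StubProperties, List[str]]) -> str:
    """Mirrors the dispatch added to _DelayedReturnResolver._resolve_delayed_return_uri."""
    if isinstance(uri, StubProperties):
        # New behavior: a single Properties reference to the step's
        # OutputDataConfig.S3OutputPath, which is already execution-aware.
        return uri.resolved_value
    # Old behavior, kept for backward compatibility: join the pickled
    # components into s3_base_uri/<execution-id>/<step-name>/results.
    return "/".join(uri)


# Both layouts resolve to the same results prefix.
new_style = StubProperties("s3://bucket/pipeline/exec-123/step_a/results")
old_style = ["s3://bucket/pipeline", "exec-123", "step_a", "results"]
assert resolve_delayed_return_uri(new_style) == resolve_delayed_return_uri(old_style)
```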
23 changes: 22 additions & 1 deletion src/sagemaker/remote_function/job.py
@@ -60,6 +60,7 @@
from sagemaker import vpc_utils
from sagemaker.remote_function.core.stored_function import StoredFunction, _SerializedData
from sagemaker.remote_function.core.pipeline_variables import Context

from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
RuntimeEnvironmentManager,
_DependencySettings,
@@ -72,6 +73,8 @@
copy_workdir,
resolve_custom_file_filter_from_config_file,
)
from sagemaker.workflow.function_step import DelayedReturn
from sagemaker.workflow.step_outputs import get_step

if TYPE_CHECKING:
from sagemaker.workflow.entities import PipelineVariable
@@ -701,6 +704,7 @@ def compile(
"""Build the artifacts and generate the training job request."""
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.functions import Join
from sagemaker.workflow.execution_variables import ExecutionVariables, ExecutionVariable
from sagemaker.workflow.utilities import load_step_compilation_context

@@ -760,7 +764,19 @@
job_settings=job_settings, s3_base_uri=s3_base_uri
)

output_config = {"S3OutputPath": s3_base_uri}
if step_compilation_context:
s3_output_path = Join(
on="/",
values=[
s3_base_uri,
ExecutionVariables.PIPELINE_EXECUTION_ID,
step_compilation_context.step_name,
"results",
],
)
output_config = {"S3OutputPath": s3_output_path}
else:
output_config = {"S3OutputPath": s3_base_uri}
if job_settings.s3_kms_key is not None:
output_config["KmsKeyId"] = job_settings.s3_kms_key
request_dict["OutputDataConfig"] = output_config
@@ -804,6 +820,11 @@ def compile(
if isinstance(arg, (Parameter, ExecutionVariable, Properties)):
container_args.extend([arg.expr["Get"], arg.to_string()])

if isinstance(arg, DelayedReturn):
# The uri is a Properties object
uri = get_step(arg)._properties.OutputDataConfig.S3OutputPath
container_args.extend([uri.expr["Get"], uri.to_string()])

if run_info is not None:
container_args.extend(["--run_in_context", json.dumps(dataclasses.asdict(run_info))])
elif _RunContext.get_current_run() is not None:
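For reviewers tracing the compile path: inside a pipeline, `S3OutputPath` is now a `Join` expression that the service resolves per execution, which is what lets a selectively executed step read a source execution's results. A rough sketch of the compiled value, assuming an illustrative base uri and step name (`Join` and `ExecutionVariables` are the real `sagemaker.workflow` classes):

```python
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.functions import Join

# Illustrative inputs; in the SDK these come from the job settings and the
# step compilation context.
s3_base_uri = "s3://my-bucket/my-pipeline"
step_name = "step_a"

# Compiled inside a pipeline: a per-execution expression rather than a fixed
# string, so each execution writes under its own execution id.
s3_output_path = Join(
    on="/",
    values=[s3_base_uri, ExecutionVariables.PIPELINE_EXECUTION_ID, step_name, "results"],
)
output_config = {"S3OutputPath": s3_output_path}
```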
src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py
@@ -252,7 +252,7 @@ def _is_file_exists(self, dependencies):

def _install_requirements_txt(self, local_path, python_executable):
"""Install requirements.txt file"""
cmd = f"{python_executable} -m pip install -r {local_path}"
cmd = f"{python_executable} -m pip install -r {local_path} -U"
logger.info("Running command: '%s' in the dir: '%s' ", cmd, os.getcwd())
_run_shell_cmd(cmd)
logger.info("Command %s ran successfully", cmd)
@@ -268,7 +268,7 @@ def _create_conda_env(self, env_name, local_path):
def _install_req_txt_in_conda_env(self, env_name, local_path):
"""Install requirements.txt in the given conda environment"""

cmd = f"{self._get_conda_exe()} run -n {env_name} pip install -r {local_path}"
cmd = f"{self._get_conda_exe()} run -n {env_name} pip install -r {local_path} -U"
logger.info("Activating conda env and installing requirements: %s", cmd)
_run_shell_cmd(cmd)
logger.info("Requirements installed successfully in conda env %s", env_name)
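The only change in this file is appending `-U` (`--upgrade`), so pip upgrades packages already present in the image to the versions listed in requirements.txt instead of skipping them. A trivial sketch with illustrative paths:

```python
python_executable = "/usr/local/bin/python3"  # illustrative
local_path = "/tmp/requirements.txt"          # illustrative

# Mirrors _install_requirements_txt: without -U, packages pre-installed in
# the image would stay at their existing versions even when requirements.txt
# asks for newer ones.
cmd = f"{python_executable} -m pip install -r {local_path} -U"
```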
7 changes: 7 additions & 0 deletions src/sagemaker/workflow/function_step.py
@@ -34,6 +34,7 @@
)

from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.retry import RetryPolicy
from sagemaker.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum
from sagemaker.workflow.step_collections import StepCollection
@@ -101,6 +102,12 @@ def __init__(

self.__job_settings = None

# Internal use only: lets us retrieve the execution id from the step
# properties. The properties of a function step are not exposed to customers.
self._properties = Properties(
step_name=name, step=self, shape_name="DescribeTrainingJobResponse"
)

(
self._converted_func_args,
self._converted_func_kwargs,
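With `shape_name="DescribeTrainingJobResponse"`, the private `_properties` object exposes the training job response shape, so `OutputDataConfig.S3OutputPath` can be referenced as a pipeline expression. A rough illustration of what it compiles to; constructing `Properties` with `step=None` outside a real step is an assumption made for the sketch (in the SDK the owning step passes `step=self`):

```python
from sagemaker.workflow.properties import Properties

# Illustrative only: in the SDK the owning _FunctionStep passes step=self.
props = Properties(step_name="step_a", step=None, shape_name="DescribeTrainingJobResponse")
path = props.OutputDataConfig.S3OutputPath
print(path.expr)  # expected: {"Get": "Steps.step_a.OutputDataConfig.S3OutputPath"}
```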
10 changes: 9 additions & 1 deletion src/sagemaker/workflow/pipeline.py
@@ -1039,11 +1039,19 @@ def get_function_step_result(
raise ValueError(_ERROR_MSG_OF_WRONG_STEP_TYPE)
s3_output_path = describe_training_job_response["OutputDataConfig"]["S3OutputPath"]

s3_uri_suffix = s3_path_join(execution_id, step_name, RESULTS_FOLDER)
if s3_output_path.endswith(s3_uri_suffix) or s3_output_path[0:-1].endswith(s3_uri_suffix):
s3_uri = s3_output_path
else:
# This is the obsolete layout of s3_output_path;
# keep handling it for backward compatibility
s3_uri = s3_path_join(s3_output_path, s3_uri_suffix)

job_status = describe_training_job_response["TrainingJobStatus"]
if job_status == "Completed":
return deserialize_obj_from_s3(
sagemaker_session=sagemaker_session,
s3_uri=s3_path_join(s3_output_path, execution_id, step_name, RESULTS_FOLDER),
s3_uri=s3_uri,
hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"],
)

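The lookup now has to accept both layouts: jobs compiled by this PR already write to `<base>/<execution-id>/<step-name>/results` (the execution-aware `S3OutputPath`), while jobs from pipelines compiled by older SDKs used the bare base uri. A self-contained sketch of the branch, simplified in that `rstrip("/")` stands in for the trailing-slash check:

```python
def results_uri(s3_output_path: str, execution_id: str, step_name: str) -> str:
    """Simplified mirror of the branch added to get_function_step_result."""
    suffix = "/".join([execution_id, step_name, "results"])
    if s3_output_path.rstrip("/").endswith(suffix):
        # New layout: S3OutputPath already points at the results prefix.
        return s3_output_path
    # Obsolete layout, kept for backward compatibility: append the suffix.
    return "/".join([s3_output_path.rstrip("/"), suffix])


assert results_uri("s3://b/p/exec-1/step_a/results", "exec-1", "step_a") \
    == "s3://b/p/exec-1/step_a/results"
assert results_uri("s3://b/p", "exec-1", "step_a") == "s3://b/p/exec-1/step_a/results"
```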
34 changes: 25 additions & 9 deletions tests/integ/sagemaker/workflow/helpers.py
@@ -39,18 +39,24 @@ def create_and_execute_pipeline(
step_result_type=None,
step_result_value=None,
wait_duration=400, # seconds
selective_execution_config=None,
):
response = pipeline.create(role)

create_arn = response["PipelineArn"]
assert re.match(
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
create_arn,
create_arn = None
if not selective_execution_config:
response = pipeline.create(role)
create_arn = response["PipelineArn"]
assert re.match(
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
create_arn,
)

execution = pipeline.start(
parameters=execution_parameters, selective_execution_config=selective_execution_config
)

execution = pipeline.start(parameters=execution_parameters)
response = execution.describe()
assert response["PipelineArn"] == create_arn
if create_arn:
response = execution.describe()
assert response["PipelineArn"] == create_arn

wait_pipeline_execution(execution=execution, delay=20, max_attempts=int(wait_duration / 20))

@@ -71,6 +77,16 @@ def create_and_execute_pipeline(
if step_result_value:
result = execution.result(execution_steps[0]["StepName"])
assert result == step_result_value, f"Expected {step_result_value}, instead found {result}"

if selective_execution_config:
for exe_step in execution_steps:
if exe_step["StepName"] in selective_execution_config.selected_steps:
continue
assert (
exe_step["SelectiveExecutionResult"]["SourcePipelineExecutionArn"]
== selective_execution_config.source_pipeline_execution_arn
)

return execution, execution_steps


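End to end, the updated helper supports a flow like the following sketch: one full run, then a selective re-run of a single step that pulls the other steps' results from the source execution. The ARN and step names are placeholders, and `pipeline` is assumed to be an existing `sagemaker.workflow.pipeline.Pipeline`:

```python
from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig

# Placeholder ARN of a previous, complete execution of the same pipeline.
source_arn = (
    "arn:aws:sagemaker:us-west-2:123456789012:pipeline/my-pipeline/execution/abc123"
)

config = SelectiveExecutionConfig(
    source_pipeline_execution_arn=source_arn,
    selected_steps=["step_b"],  # only step_b re-runs; step_a's output is reused
)

# The pipeline already exists from the first run, so pipeline.create() is
# skipped, which is exactly the branch the helper above adds.
execution = pipeline.start(selective_execution_config=config)
```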