
Commit c50e3b8

feat: Support selective pipeline execution for function step
1 parent f2b47ab commit c50e3b8

11 files changed: +205 −155 lines changed
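For context, a minimal usage sketch of the feature this commit enables (not part of the commit; the step name, instance type, role ARN, and bucket layout are assumed): define a `@step`-decorated function, run the pipeline once, then re-run only that step against the earlier execution via `SelectiveExecutionConfig` and read back its deserialized return value.

from sagemaker.workflow.function_step import step
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig

role = "arn:aws:iam::123456789012:role/SageMakerExecutionRole"  # assumed role ARN


@step(name="sum_step", instance_type="ml.m5.large")  # assumed step name and instance type
def add(a, b):
    return a + b


pipeline = Pipeline(name="MyFunctionStepPipeline", steps=[add(1, 2)])
pipeline.upsert(role_arn=role)

# First execution runs every step.
first_execution = pipeline.start()
first_execution.wait()

# Selective execution: re-run only "sum_step", reusing the first execution for everything else.
second_execution = pipeline.start(
    selective_execution_config=SelectiveExecutionConfig(
        source_pipeline_execution_arn=first_execution.arn,
        selected_steps=["sum_step"],
    )
)
second_execution.wait()
print(second_execution.result(step_name="sum_step"))  # expected: 3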

src/sagemaker/local/entities.py

Lines changed: 0 additions & 1 deletion
@@ -765,7 +765,6 @@ def result(self, step_name: str):
         return get_function_step_result(
             step_name=step_name,
             step_list=self.list_steps()["PipelineExecutionSteps"],
-            execution_id=self.pipeline_execution_name,
             sagemaker_session=self.local_session,
         )
 
src/sagemaker/remote_function/core/pipeline_variables.py

Lines changed: 25 additions & 72 deletions
@@ -15,10 +15,10 @@
 
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
-from typing import Any, Union, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple
 
-from sagemaker.s3 import s3_path_join
 from sagemaker.remote_function.core.serialization import deserialize_obj_from_s3
+from sagemaker.workflow.step_outputs import get_step
 
 
 @dataclass
@@ -77,22 +77,11 @@ class _ExecutionVariable:
     name: str
 
 
-@dataclass
-class _S3BaseUriIdentifier:
-    """Identifies that the class refers to function step s3 base uri.
-
-    The s3_base_uri = s3_root_uri + pipeline_name.
-    This identifier is resolved in function step runtime by SDK.
-    """
-
-    NAME = "S3_BASE_URI"
-
-
 @dataclass
 class _DelayedReturn:
     """Delayed return from a function."""
 
-    uri: List[Union[str, _Parameter, _ExecutionVariable]]
+    uri: _Properties
     reference_path: Tuple = field(default_factory=tuple)
 
 
@@ -164,26 +153,18 @@ def __init__(
         self,
         delayed_returns: List[_DelayedReturn],
         hmac_key: str,
-        parameter_resolver: _ParameterResolver,
-        execution_variable_resolver: _ExecutionVariableResolver,
-        s3_base_uri: str,
+        properties_resolver: _PropertiesResolver,
         **settings,
     ):
         """Resolve delayed return.
 
         Args:
            delayed_returns: list of delayed returns to resolve.
            hmac_key: key used to encrypt serialized and deserialized function and arguments.
-            parameter_resolver: resolver used to pipeline parameters.
-            execution_variable_resolver: resolver used to resolve execution variables.
-            s3_base_uri (str): the s3 base uri of the function step that
-                the serialized artifacts will be uploaded to.
-                The s3_base_uri = s3_root_uri + pipeline_name.
+            properties_resolver: resolver used to resolve step properties.
            **settings: settings to pass to the deserialization function.
         """
-        self._s3_base_uri = s3_base_uri
-        self._parameter_resolver = parameter_resolver
-        self._execution_variable_resolver = execution_variable_resolver
+        self._properties_resolver = properties_resolver
         # different delayed returns can have the same uri, so we need to dedupe
         uris = {
             self._resolve_delayed_return_uri(delayed_return) for delayed_return in delayed_returns
@@ -214,18 +195,7 @@ def resolve(self, delayed_return: _DelayedReturn) -> Any:
 
     def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
         """Resolve the s3 uri of the delayed return."""
-
-        uri = []
-        for component in delayed_return.uri:
-            if isinstance(component, _Parameter):
-                uri.append(self._parameter_resolver.resolve(component))
-            elif isinstance(component, _ExecutionVariable):
-                uri.append(self._execution_variable_resolver.resolve(component))
-            elif isinstance(component, _S3BaseUriIdentifier):
-                uri.append(self._s3_base_uri)
-            else:
-                uri.append(component)
-        return s3_path_join(*uri)
+        return self._properties_resolver.resolve(delayed_return.uri)
 
 
 def _retrieve_child_item(delayed_return: _DelayedReturn, deserialized_obj: Any):
@@ -241,7 +211,6 @@ def resolve_pipeline_variables(
     func_args: Tuple,
     func_kwargs: Dict,
     hmac_key: str,
-    s3_base_uri: str,
     **settings,
 ):
     """Resolve pipeline variables.
@@ -251,8 +220,6 @@ def resolve_pipeline_variables(
        func_args: function args.
        func_kwargs: function kwargs.
        hmac_key: key used to encrypt serialized and deserialized function and arguments.
-        s3_base_uri: the s3 base uri of the function step that the serialized artifacts
-            will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
        **settings: settings to pass to the deserialization function.
     """
 
@@ -274,9 +241,7 @@ def resolve_pipeline_variables(
     delayed_return_resolver = _DelayedReturnResolver(
         delayed_returns=delayed_returns,
         hmac_key=hmac_key,
-        parameter_resolver=parameter_resolver,
-        execution_variable_resolver=execution_variable_resolver,
-        s3_base_uri=s3_base_uri,
+        properties_resolver=properties_resolver,
         **settings,
     )
 
@@ -322,39 +287,27 @@ def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict
         func_args: function args.
         func_kwargs: function kwargs.
     """
+    converted_func_args = tuple(_convert_pipeline_variable_to_pickleable(arg) for arg in func_args)
+    converted_func_kwargs = {
+        key: _convert_pipeline_variable_to_pickleable(arg) for key, arg in func_kwargs.items()
+    }
 
-    from sagemaker.workflow.entities import PipelineVariable
-
-    from sagemaker.workflow.execution_variables import ExecutionVariables
+    return converted_func_args, converted_func_kwargs
 
-    from sagemaker.workflow.function_step import DelayedReturn
 
-    # Notes:
-    # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
-    #    when defining function steps. After step-level arg serialization,
-    #    it's hard to update the s3_base_uri in pipeline compile time.
-    #    Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
-    # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
-    #    the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
-    #    should be retrieved from the pipeline's sagemaker_session.
-    def convert(arg):
-        if isinstance(arg, DelayedReturn):
-            return _DelayedReturn(
-                uri=[
-                    _S3BaseUriIdentifier(),
-                    ExecutionVariables.PIPELINE_EXECUTION_ID._pickleable,
-                    arg._step.name,
-                    "results",
-                ],
-                reference_path=arg._reference_path,
-            )
+def _convert_pipeline_variable_to_pickleable(arg):
+    """Convert a pipeline variable to pickleable."""
+    from sagemaker.workflow.entities import PipelineVariable
 
-        if isinstance(arg, PipelineVariable):
-            return arg._pickleable
+    from sagemaker.workflow.function_step import DelayedReturn
 
-        return arg
+    if isinstance(arg, DelayedReturn):
+        return _DelayedReturn(
+            uri=get_step(arg)._properties.OutputDataConfig.S3OutputPath._pickleable,
+            reference_path=arg._reference_path,
+        )
 
-    converted_func_args = tuple(convert(arg) for arg in func_args)
-    converted_func_kwargs = {key: convert(arg) for key, arg in func_kwargs.items()}
+    if isinstance(arg, PipelineVariable):
+        return arg._pickleable
 
-    return converted_func_args, converted_func_kwargs
+    return arg
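Design note on the change above (commentary, not part of the commit): a DelayedReturn's URI is no longer an ordered list of placeholders joined inside the job; it is now the pickleable form of the producing step's OutputDataConfig.S3OutputPath property (exposed via the `_properties` attribute added to function steps below), so the runtime resolves whatever output path the source training job actually used. That is what lets a selectively executed step consume results produced by a prior execution. Roughly, for a step assumed to be named "my_step", the two shapes compare as follows.

# Old shape: path pieces assembled at job runtime with s3_path_join.
old_uri = ["<S3_BASE_URI placeholder>", "<execution-id placeholder>", "my_step", "results"]

# New shape (assumed rendering): a single property reference; the job receives its
# resolved value through container_args as a {"Get": ...} expression (see job.py below).
new_uri_expr = {"Get": "Steps.my_step.OutputDataConfig.S3OutputPath"}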

src/sagemaker/remote_function/core/stored_function.py

Lines changed: 0 additions & 1 deletion
@@ -175,7 +175,6 @@ def load_and_invoke(self) -> Any:
             args,
             kwargs,
             hmac_key=self.hmac_key,
-            s3_base_uri=self.s3_base_uri,
             sagemaker_session=self.sagemaker_session,
         )
 
src/sagemaker/remote_function/job.py

Lines changed: 24 additions & 2 deletions
@@ -59,7 +59,10 @@
 from sagemaker.s3 import s3_path_join, S3Uploader
 from sagemaker import vpc_utils
 from sagemaker.remote_function.core.stored_function import StoredFunction, _SerializedData
-from sagemaker.remote_function.core.pipeline_variables import Context
+from sagemaker.remote_function.core.pipeline_variables import (
+    Context,
+    _convert_pipeline_variable_to_pickleable,
+)
 from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
     RuntimeEnvironmentManager,
     _DependencySettings,
@@ -72,6 +75,7 @@
     copy_workdir,
     resolve_custom_file_filter_from_config_file,
 )
+from sagemaker.workflow.function_step import DelayedReturn
 
 if TYPE_CHECKING:
     from sagemaker.workflow.entities import PipelineVariable
@@ -701,6 +705,7 @@ def compile(
         """Build the artifacts and generate the training job request."""
         from sagemaker.workflow.properties import Properties
         from sagemaker.workflow.parameters import Parameter
+        from sagemaker.workflow.functions import Join
         from sagemaker.workflow.execution_variables import ExecutionVariables, ExecutionVariable
         from sagemaker.workflow.utilities import load_step_compilation_context
 
@@ -760,7 +765,19 @@ def compile(
             job_settings=job_settings, s3_base_uri=s3_base_uri
         )
 
-        output_config = {"S3OutputPath": s3_base_uri}
+        if step_compilation_context:
+            s3_output_path = Join(
+                on="/",
+                values=[
+                    s3_base_uri,
+                    ExecutionVariables.PIPELINE_EXECUTION_ID,
+                    step_compilation_context.step_name,
+                    "results",
+                ],
+            )
+            output_config = {"S3OutputPath": s3_output_path}
+        else:
+            output_config = {"S3OutputPath": s3_base_uri}
         if job_settings.s3_kms_key is not None:
             output_config["KmsKeyId"] = job_settings.s3_kms_key
         request_dict["OutputDataConfig"] = output_config
@@ -804,6 +821,11 @@ def compile(
             if isinstance(arg, (Parameter, ExecutionVariable, Properties)):
                 container_args.extend([arg.expr["Get"], arg.to_string()])
 
+            if isinstance(arg, DelayedReturn):
+                # The uri is a _Properties object
+                uri = _convert_pipeline_variable_to_pickleable(arg).uri
+                container_args.extend([uri.path, {"Get": uri.path}])
+
         if run_info is not None:
             container_args.extend(["--run_in_context", json.dumps(dataclasses.asdict(run_info))])
         elif _RunContext.get_current_run() is not None:
src/sagemaker/workflow/function_step.py

Lines changed: 7 additions & 0 deletions
@@ -34,6 +34,7 @@
 )
 
 from sagemaker.workflow.execution_variables import ExecutionVariables
+from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.retry import RetryPolicy
 from sagemaker.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum
 from sagemaker.workflow.step_collections import StepCollection
@@ -101,6 +102,12 @@ def __init__(
 
         self.__job_settings = None
 
+        # It's for internal usage to retrieve execution id from the properties.
+        # However, we won't expose the properties of function step to customers.
+        self._properties = Properties(
+            step_name=name, step=self, shape_name="DescribeTrainingJobResponse"
+        )
+
         (
             self._converted_func_args,
             self._converted_func_kwargs,

src/sagemaker/workflow/pipeline.py

Lines changed: 1 addition & 5 deletions
@@ -28,7 +28,6 @@
 from sagemaker._studio import _append_project_tags
 from sagemaker.config import PIPELINE_ROLE_ARN_PATH, PIPELINE_TAGS_PATH
 from sagemaker.remote_function.core.serialization import deserialize_obj_from_s3
-from sagemaker.remote_function.core.stored_function import RESULTS_FOLDER
 from sagemaker.remote_function.errors import RemoteFunctionError
 from sagemaker.remote_function.job import JOBS_CONTAINER_ENTRYPOINT
 from sagemaker.s3_utils import s3_path_join
@@ -977,23 +976,20 @@ def result(self, step_name: str):
         return get_function_step_result(
             step_name=step_name,
             step_list=self.list_steps(),
-            execution_id=self.arn.split("/")[-1],
             sagemaker_session=self.sagemaker_session,
         )
 
 
 def get_function_step_result(
     step_name: str,
     step_list: list,
-    execution_id: str,
     sagemaker_session: Session,
 ):
     """Helper function to retrieve the output of a ``@step`` decorated function.
 
     Args:
         step_name (str): The name of the pipeline step.
         step_list (list): A list of executed pipeline steps of the specified execution.
-        execution_id (str): The specified id of the pipeline execution.
         sagemaker_session (Session): Session object which manages interactions
             with Amazon SageMaker APIs and any other AWS services needed.
     Returns:
@@ -1043,7 +1039,7 @@ def get_function_step_result(
     if job_status == "Completed":
         return deserialize_obj_from_s3(
             sagemaker_session=sagemaker_session,
-            s3_uri=s3_path_join(s3_output_path, execution_id, step_name, RESULTS_FOLDER),
+            s3_uri=s3_path_join(s3_output_path),
            hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"],
         )
 
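Follow-up on the change above: result() can drop the execution id because, with the Join added in job.py, the training job's S3OutputPath already ends in .../<execution-id>/<step-name>/results, so the client-side suffixing became redundant. A before/after sketch:

# Before: the result URI was reassembled client-side from the execution context.
s3_uri = s3_path_join(s3_output_path, execution_id, step_name, RESULTS_FOLDER)

# After: the job's S3OutputPath already points at the results prefix.
s3_uri = s3_path_join(s3_output_path)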
tests/integ/sagemaker/workflow/helpers.py

Lines changed: 25 additions & 9 deletions
@@ -39,18 +39,24 @@ def create_and_execute_pipeline(
     step_result_type=None,
     step_result_value=None,
     wait_duration=400, # seconds
+    selective_execution_config=None,
 ):
-    response = pipeline.create(role)
-
-    create_arn = response["PipelineArn"]
-    assert re.match(
-        rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
-        create_arn,
+    create_arn = None
+    if not selective_execution_config:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+    execution = pipeline.start(
+        parameters=execution_parameters, selective_execution_config=selective_execution_config
     )
 
-    execution = pipeline.start(parameters=execution_parameters)
-    response = execution.describe()
-    assert response["PipelineArn"] == create_arn
+    if create_arn:
+        response = execution.describe()
+        assert response["PipelineArn"] == create_arn
 
     wait_pipeline_execution(execution=execution, delay=20, max_attempts=int(wait_duration / 20))
 
@@ -71,6 +77,16 @@ def create_and_execute_pipeline(
     if step_result_value:
         result = execution.result(execution_steps[0]["StepName"])
         assert result == step_result_value, f"Expected {step_result_value}, instead found {result}"
+
+    if selective_execution_config:
+        for exe_step in execution_steps:
+            if exe_step["StepName"] in selective_execution_config.selected_steps:
+                continue
+            assert (
+                exe_step["SelectiveExecutionResult"]["SourcePipelineExecutionArn"]
+                == selective_execution_config.source_pipeline_execution_arn
+            )
+
     return execution, execution_steps
 
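An illustrative pair of calls to the updated helper (argument names beyond those visible in this hunk are assumed), exercising the selective-execution branch added above:

from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig

# First run: creates the pipeline and executes every step.
first_execution, _ = create_and_execute_pipeline(
    pipeline=pipeline,
    pipeline_name=pipeline_name,
    region_name=region_name,
    role=role,
    execution_parameters={},
    step_result_value=3,  # assumed expected step output
)

# Second run: pipeline.create() is skipped and only "sum_step" is re-run; for every
# step not in selected_steps, the helper asserts its SelectiveExecutionResult points
# back to the source execution.
create_and_execute_pipeline(
    pipeline=pipeline,
    pipeline_name=pipeline_name,
    region_name=region_name,
    role=role,
    execution_parameters={},
    step_result_value=3,
    selective_execution_config=SelectiveExecutionConfig(
        source_pipeline_execution_arn=first_execution.arn,
        selected_steps=["sum_step"],
    ),
)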