Skip to content

Commit 75e1d08

Browse files
committed
feature: method to build pipeline parameters from existing execution … (aws#951)
* feature: method to build pipeline parameters from existing execution with optional value overrides * fix style check * assert error message in unit test
1 parent b101b83 commit 75e1d08

File tree

2 files changed

+192
-1
lines changed

2 files changed

+192
-1
lines changed

src/sagemaker/workflow/pipeline.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,76 @@ def _get_latest_execution_arn(self):
463463
return response["PipelineExecutionSummaries"][0]["PipelineExecutionArn"]
464464
return None
465465

466+
def build_parameters_from_execution(
467+
self,
468+
pipeline_execution_arn: str,
469+
parameter_value_overrides: Dict[str, Union[str, bool, int, float]] = None,
470+
) -> Dict[str, Union[str, bool, int, float]]:
471+
"""Gets the parameters from an execution, update with optional parameter value overrides.
472+
473+
Args:
474+
pipeline_execution_arn (str): The arn of the pipeline execution.
475+
parameter_value_overrides (Dict[str, Union[str, bool, int, float]]): Parameter dict
476+
to be updated in the parameters from the referenced execution.
477+
478+
Returns:
479+
A parameter dict built from an execution and provided parameter value overrides.
480+
"""
481+
execution_parameters = self._get_parameters_for_execution(pipeline_execution_arn)
482+
if parameter_value_overrides is not None:
483+
self._validate_parameter_overrides(
484+
pipeline_execution_arn, execution_parameters, parameter_value_overrides
485+
)
486+
execution_parameters.update(parameter_value_overrides)
487+
return execution_parameters
488+
489+
def _get_parameters_for_execution(self, pipeline_execution_arn: str) -> Dict[str, str]:
490+
"""Gets all the parameters from an execution.
491+
492+
Args:
493+
pipeline_execution_arn (str): The arn of the pipeline execution.
494+
495+
Returns:
496+
A parameter dict from the execution.
497+
"""
498+
pipeline_execution = _PipelineExecution(
499+
arn=pipeline_execution_arn,
500+
sagemaker_session=self.sagemaker_session,
501+
)
502+
503+
response = pipeline_execution.list_parameters()
504+
parameter_list = response["PipelineParameters"]
505+
while response.get("NextToken") is not None:
506+
response = pipeline_execution.list_parameters(next_token=response["NextToken"])
507+
parameter_list.extend(response["PipelineParameters"])
508+
509+
return {parameter["Name"]: parameter["Value"] for parameter in parameter_list}
510+
511+
@staticmethod
512+
def _validate_parameter_overrides(
513+
pipeline_execution_arn: str,
514+
execution_parameters: Dict[str, str],
515+
parameter_overrides: Dict[str, Union[str, bool, int, float]],
516+
):
517+
"""Validates the parameter overrides are present in the execution parameters.
518+
519+
Args:
520+
pipeline_execution_arn (str): The arn of the pipeline execution.
521+
execution_parameters (Dict[str, str]): A parameter dict from the execution.
522+
parameter_overrides (Dict[str, Union[str, bool, int, float]]): Parameter dict to be
523+
updated in the parameters from the referenced execution.
524+
525+
Raises:
526+
ValueError: If any parameters in parameter overrides is not present in the
527+
execution parameters.
528+
"""
529+
invalid_parameters = set(parameter_overrides) - set(execution_parameters)
530+
if invalid_parameters:
531+
raise ValueError(
532+
f"The following parameter overrides provided: {str(invalid_parameters)} "
533+
+ f"are not present in the pipeline execution: {pipeline_execution_arn}"
534+
)
535+
466536

467537
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
468538
"""Formats start parameter overrides as a list of dicts.
@@ -652,6 +722,30 @@ def list_steps(self):
652722
)
653723
return response["PipelineExecutionSteps"]
654724

725+
def list_parameters(self, max_results: int = None, next_token: str = None):
726+
"""Gets a list of parameters for a pipeline execution.
727+
728+
Args:
729+
max_results (int): The maximum number of parameters to return in the response.
730+
next_token (str): If the result of the previous ListPipelineParametersForExecution
731+
request was truncated, the response includes a NextToken. To retrieve the next
732+
set of parameters, use the token in the next request.
733+
734+
Returns:
735+
Information about the parameters of the pipeline execution.
736+
See boto3 client list_pipeline_executions
737+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_pipeline_parameters_for_execution
738+
"""
739+
kwargs = dict(PipelineExecutionArn=self.arn)
740+
update_args(
741+
kwargs,
742+
MaxResults=max_results,
743+
NextToken=next_token,
744+
)
745+
return self.sagemaker_session.sagemaker_client.list_pipeline_parameters_for_execution(
746+
**kwargs
747+
)
748+
655749
def wait(self, delay=30, max_attempts=60):
656750
"""Waits for a pipeline execution.
657751

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import pytest
1919

20-
from mock import Mock, patch
20+
from mock import Mock, call, patch
2121

2222
from sagemaker import s3
2323
from sagemaker.session_settings import SessionSettings
@@ -718,13 +718,99 @@ def test_pipeline_list_executions(sagemaker_session_mock):
718718
assert executions["NextToken"] == "token"
719719

720720

721+
def test_pipeline_build_parameters_from_execution(sagemaker_session_mock):
722+
pipeline = Pipeline(
723+
name="MyPipeline",
724+
sagemaker_session=sagemaker_session_mock,
725+
)
726+
reference_execution_arn = "reference_execution_arn"
727+
parameter_value_overrides = {"TestParameterName": "NewParameterValue"}
728+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
729+
"PipelineParameters": [{"Name": "TestParameterName", "Value": "TestParameterValue"}]
730+
}
731+
parameters = pipeline.build_parameters_from_execution(
732+
pipeline_execution_arn=reference_execution_arn,
733+
parameter_value_overrides=parameter_value_overrides,
734+
)
735+
assert (
736+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
737+
PipelineExecutionArn=reference_execution_arn
738+
)
739+
)
740+
assert len(parameters) == 1
741+
assert parameters["TestParameterName"] == "NewParameterValue"
742+
743+
744+
def test_pipeline_build_parameters_from_execution_with_invalid_overrides(sagemaker_session_mock):
745+
pipeline = Pipeline(
746+
name="MyPipeline",
747+
sagemaker_session=sagemaker_session_mock,
748+
)
749+
reference_execution_arn = "reference_execution_arn"
750+
invalid_parameter_value_overrides = {"InvalidParameterName": "Value"}
751+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
752+
"PipelineParameters": [{"Name": "TestParameterName", "Value": "TestParameterValue"}]
753+
}
754+
with pytest.raises(ValueError) as error:
755+
pipeline.build_parameters_from_execution(
756+
pipeline_execution_arn=reference_execution_arn,
757+
parameter_value_overrides=invalid_parameter_value_overrides,
758+
)
759+
assert (
760+
f"The following parameter overrides provided: {str(set(invalid_parameter_value_overrides.keys()))} "
761+
+ f"are not present in the pipeline execution: {reference_execution_arn}"
762+
in str(error)
763+
)
764+
assert (
765+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
766+
PipelineExecutionArn=reference_execution_arn
767+
)
768+
)
769+
770+
771+
def test_pipeline_build_parameters_from_execution_with_paginated_result(sagemaker_session_mock):
772+
pipeline = Pipeline(
773+
name="MyPipeline",
774+
sagemaker_session=sagemaker_session_mock,
775+
)
776+
reference_execution_arn = "reference_execution_arn"
777+
next_token = "token"
778+
first_page_response = {
779+
"PipelineParameters": [{"Name": "TestParameterName1", "Value": "TestParameterValue1"}],
780+
"NextToken": next_token,
781+
}
782+
second_page_response = {
783+
"PipelineParameters": [{"Name": "TestParameterName2", "Value": "TestParameterValue2"}],
784+
}
785+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.side_effect = [
786+
first_page_response,
787+
second_page_response,
788+
]
789+
parameters = pipeline.build_parameters_from_execution(
790+
pipeline_execution_arn=reference_execution_arn
791+
)
792+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_has_calls(
793+
[
794+
call(PipelineExecutionArn=reference_execution_arn),
795+
call(PipelineExecutionArn=reference_execution_arn, NextToken=next_token),
796+
]
797+
)
798+
assert len(parameters) == 2
799+
assert parameters["TestParameterName1"] == "TestParameterValue1"
800+
assert parameters["TestParameterName2"] == "TestParameterValue2"
801+
802+
721803
def test_pipeline_execution_basics(sagemaker_session_mock):
722804
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
723805
"PipelineExecutionArn": "my:arn"
724806
}
725807
sagemaker_session_mock.sagemaker_client.list_pipeline_execution_steps.return_value = {
726808
"PipelineExecutionSteps": [Mock()]
727809
}
810+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
811+
"PipelineParameters": [{"Name": "TestParameterName", "Value": "TestParameterValue"}],
812+
"NextToken": "token",
813+
}
728814
pipeline = Pipeline(
729815
name="MyPipeline",
730816
parameters=[ParameterString("alpha", "beta"), ParameterString("gamma", "delta")],
@@ -745,6 +831,17 @@ def test_pipeline_execution_basics(sagemaker_session_mock):
745831
PipelineExecutionArn="my:arn"
746832
)
747833
assert len(steps) == 1
834+
list_parameters_response = execution.list_parameters()
835+
assert (
836+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
837+
PipelineExecutionArn="my:arn"
838+
)
839+
)
840+
parameter_list = list_parameters_response["PipelineParameters"]
841+
assert len(parameter_list) == 1
842+
assert parameter_list[0]["Name"] == "TestParameterName"
843+
assert parameter_list[0]["Value"] == "TestParameterValue"
844+
assert list_parameters_response["NextToken"] == "token"
748845

749846

750847
def _generate_large_pipeline_steps(input_data: object):

0 commit comments

Comments
 (0)