|
24 | 24 | import pytz
|
25 | 25 | from botocore.exceptions import ClientError, WaiterError
|
26 | 26 |
|
27 |
| -from sagemaker import s3 |
| 27 | +from sagemaker import s3, LocalSession |
28 | 28 | from sagemaker._studio import _append_project_tags
|
29 | 29 | from sagemaker.config import PIPELINE_ROLE_ARN_PATH, PIPELINE_TAGS_PATH
|
30 | 30 | from sagemaker.remote_function.core.serialization import deserialize_obj_from_s3
|
@@ -973,41 +973,81 @@ def result(self, step_name: str):
|
973 | 973 | except WaiterError as e:
|
974 | 974 | if "Waiter encountered a terminal failure state" not in str(e):
|
975 | 975 | raise
|
976 |
| - step = next(filter(lambda x: x["StepName"] == step_name, self.list_steps()), None) |
977 |
| - if not step: |
978 |
| - raise ValueError(f"Invalid step name {step_name}") |
979 |
| - step_type = next(iter(step["Metadata"])) |
980 |
| - step_metadata = next(iter(step["Metadata"].values())) |
981 |
| - if step_type != "TrainingJob": |
982 |
| - raise ValueError( |
983 |
| - "This method can only be used on pipeline steps created using " "@step decorator." |
984 |
| - ) |
985 | 976 |
|
986 |
| - job_arn = step_metadata["Arn"] |
987 |
| - job_name = job_arn.split("/")[-1] |
| 977 | + return get_function_step_result( |
| 978 | + step_name=step_name, |
| 979 | + step_list=self.list_steps(), |
| 980 | + execution_id=self.arn.split("/")[-1], |
| 981 | + sagemaker_session=self.sagemaker_session, |
| 982 | + ) |
988 | 983 |
|
989 |
| - describe_training_job_response = self.sagemaker_session.describe_training_job(job_name) |
990 |
| - container_args = describe_training_job_response["AlgorithmSpecification"][ |
991 |
| - "ContainerEntrypoint" |
992 |
| - ] |
993 |
| - if container_args != JOBS_CONTAINER_ENTRYPOINT: |
994 |
| - raise ValueError( |
995 |
| - "This method can only be used on pipeline steps created using @step decorator." |
996 |
| - ) |
997 |
| - s3_output_path = describe_training_job_response["OutputDataConfig"]["S3OutputPath"] |
998 | 984 |
|
999 |
| - job_status = describe_training_job_response["TrainingJobStatus"] |
1000 |
| - if job_status == "Completed": |
1001 |
| - return deserialize_obj_from_s3( |
1002 |
| - sagemaker_session=self.sagemaker_session, |
1003 |
| - s3_uri=s3_path_join( |
1004 |
| - s3_output_path, self.arn.split("/")[-1], step_name, RESULTS_FOLDER |
1005 |
| - ), |
1006 |
| - hmac_key=describe_training_job_response["Environment"][ |
1007 |
| - "REMOTE_FUNCTION_SECRET_KEY" |
1008 |
| - ], |
1009 |
| - ) |
1010 |
| - raise RemoteFunctionError(f"Pipeline step {step_name} is in {job_status} status.") |
def get_function_step_result(
    step_name: str,
    step_list: list,
    execution_id: str,
    sagemaker_session: Session,
):
    """Helper function to retrieve the output of a ``@step`` decorated function.

    The step is looked up by name in ``step_list``, the training job backing it
    is described, and, if that job is in "Completed" status, the serialized
    return value is loaded from
    ``<S3OutputPath>/<execution_id>/<step_name>/<RESULTS_FOLDER>``.

    Args:
        step_name (str): The name of the pipeline step.
        step_list (list): A list of executed pipeline steps of the specified execution.
        execution_id (str): The specified id of the pipeline execution.
        sagemaker_session (Session): Session object which manages interactions
            with Amazon SageMaker APIs and any other AWS services needed.
    Returns:
        The step output.

    Raises:
        ValueError if no step named ``step_name`` is found in ``step_list``, or if
            the provided step is not a ``@step`` decorated function.
        RuntimeError if running with a ``LocalSession`` and the step carries no
            metadata.
        RemoteFunctionError if the provided step is not in "Completed" status, or
            (outside local mode) the step carries no metadata to locate its job.
    """
    _ERROR_MSG_OF_WRONG_STEP_TYPE = (
        "This method can only be used on pipeline steps created using @step decorator."
    )
    _ERROR_MSG_OF_STEP_INCOMPLETE = (
        f"Unable to retrieve step output as the step {step_name} is not in Completed status."
    )

    step = next((item for item in step_list if item["StepName"] == step_name), None)
    if not step:
        raise ValueError(f"Invalid step name {step_name}")

    is_local_session = isinstance(sagemaker_session, LocalSession)
    metadata_by_type = step.get("Metadata", None)
    if not metadata_by_type:
        if is_local_session:
            # In local mode, if the training job failed,
            # it's not tracked in LocalSagemakerClient and it's not describable.
            # Thus, the step Metadata is not set.
            raise RuntimeError(_ERROR_MSG_OF_STEP_INCOMPLETE)
        # Without metadata there is no job ARN to describe. Report this as an
        # incomplete step instead of leaking a bare KeyError/StopIteration.
        raise RemoteFunctionError(_ERROR_MSG_OF_STEP_INCOMPLETE)

    # Only the first metadata entry is inspected: its key is the step type and
    # its value is that type's payload.
    step_type, step_metadata = next(iter(metadata_by_type.items()))
    if step_type != "TrainingJob":
        raise ValueError(_ERROR_MSG_OF_WRONG_STEP_TYPE)

    # The job name is the last path segment of the job ARN.
    job_arn = step_metadata["Arn"]
    job_name = job_arn.split("/")[-1]

    if is_local_session:
        # For a local session the job is described through the session's
        # sagemaker_client rather than the session itself.
        describe_training_job_response = sagemaker_session.sagemaker_client.describe_training_job(
            job_name
        )
    else:
        describe_training_job_response = sagemaker_session.describe_training_job(job_name)

    # Jobs whose entrypoint is not JOBS_CONTAINER_ENTRYPOINT are rejected as
    # not coming from the @step decorator.
    container_args = describe_training_job_response["AlgorithmSpecification"]["ContainerEntrypoint"]
    if container_args != JOBS_CONTAINER_ENTRYPOINT:
        raise ValueError(_ERROR_MSG_OF_WRONG_STEP_TYPE)
    s3_output_path = describe_training_job_response["OutputDataConfig"]["S3OutputPath"]

    job_status = describe_training_job_response["TrainingJobStatus"]
    if job_status == "Completed":
        return deserialize_obj_from_s3(
            sagemaker_session=sagemaker_session,
            s3_uri=s3_path_join(s3_output_path, execution_id, step_name, RESULTS_FOLDER),
            hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"],
        )

    raise RemoteFunctionError(_ERROR_MSG_OF_STEP_INCOMPLETE)
1011 | 1051 |
|
1012 | 1052 |
|
1013 | 1053 | class PipelineGraph:
|
|
0 commit comments