
Commit 130c367

Merge branch 'master' into triton

2 parents f83f52e + 99afcc8 · commit 130c367

File tree: 9 files changed (+301, -46 lines)


src/sagemaker/local/entities.py

Lines changed: 37 additions & 1 deletion

@@ -201,6 +201,7 @@ def __init__(self, container):
        self.end_time = None
        self.environment = None
        self.training_job_name = ""
+       self.output_data_config = None

    def start(self, input_data_config, output_data_config, hyperparameters, environment, job_name):
        """Starts a local training job.
@@ -248,6 +249,7 @@ def start(self, input_data_config, output_data_config, hyperparameters, environm
        self.end_time = datetime.datetime.now()
        self.state = self._COMPLETED
        self.training_job_name = job_name
+       self.output_data_config = output_data_config

    def describe(self):
        """Placeholder docstring"""
@@ -259,6 +261,11 @@ def describe(self):
            "TrainingStartTime": self.start_time,
            "TrainingEndTime": self.end_time,
            "ModelArtifacts": {"S3ModelArtifacts": self.model_artifacts},
+           "OutputDataConfig": self.output_data_config,
+           "Environment": self.environment,
+           "AlgorithmSpecification": {
+               "ContainerEntrypoint": self.container.container_entrypoint,
+           },
        }
        return response

@@ -668,7 +675,12 @@ def start(self, **kwargs):
        from sagemaker.local.pipeline import LocalPipelineExecutor

        execution_id = str(uuid4())
-       execution = _LocalPipelineExecution(execution_id, self.pipeline, **kwargs)
+       execution = _LocalPipelineExecution(
+           execution_id=execution_id,
+           pipeline=self.pipeline,
+           local_session=self.local_session,
+           **kwargs,
+       )

        self._executions[execution_id] = execution
        print(
@@ -689,13 +701,16 @@ def __init__(
        PipelineParameters=None,
        PipelineExecutionDescription=None,
        PipelineExecutionDisplayName=None,
+       local_session=None,
    ):
        from sagemaker.workflow.pipeline import PipelineGraph
+       from sagemaker import LocalSession

        self.pipeline = pipeline
        self.pipeline_execution_name = execution_id
        self.pipeline_execution_description = PipelineExecutionDescription
        self.pipeline_execution_display_name = PipelineExecutionDisplayName
+       self.local_session = local_session or LocalSession()
        self.status = _LocalExecutionStatus.EXECUTING.value
        self.failure_reason = None
        self.creation_time = datetime.datetime.now().timestamp()
@@ -731,6 +746,27 @@ def list_steps(self):
            ]
        }

+   def result(self, step_name: str):
+       """Retrieves the output of the provided step if it is a ``@step`` decorated function.
+
+       Args:
+           step_name (str): The name of the pipeline step.
+       Returns:
+           The step output.
+
+       Raises:
+           ValueError if the provided step is not a ``@step`` decorated function.
+           RuntimeError if the provided step is not in "Completed" status.
+       """
+       from sagemaker.workflow.pipeline import get_function_step_result
+
+       return get_function_step_result(
+           step_name=step_name,
+           step_list=self.list_steps()["PipelineExecutionSteps"],
+           execution_id=self.pipeline_execution_name,
+           sagemaker_session=self.local_session,
+       )
+
    def update_execution_success(self):
        """Mark execution as succeeded."""
        self.status = _LocalExecutionStatus.SUCCEEDED.value
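
For context, here is a minimal usage sketch of what the new local-mode result() enables. The pipeline name, image URI, and role ARN are placeholders, and the @step settings stand in for whatever the integration test's step_settings supplies; this is not code from the commit itself.

from sagemaker.workflow.function_step import step
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline_context import LocalPipelineSession
from sagemaker.workflow.step_outputs import get_step

local_session = LocalPipelineSession()

# Placeholder image/instance settings; local mode runs the step in a local container.
@step(image_uri="my-byoc-image:latest", instance_type="local")
def add(a, b):
    return a + b

step_output = add(1, 2)

pipeline = Pipeline(
    name="MyLocalPipeline",  # placeholder name
    steps=[step_output],
    sagemaker_session=local_session,
)
pipeline.create(role_arn="arn:aws:iam::111111111111:role/DummyRole")  # placeholder role
execution = pipeline.start()

# New in this commit: a local execution can resolve a @step function's return value,
# using the OutputDataConfig/Environment that the local describe() now surfaces.
assert execution.result(step_name=get_step(step_output).name) == 3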

src/sagemaker/remote_function/core/stored_function.py

Lines changed: 1 addition & 1 deletion

@@ -142,7 +142,7 @@ def load_and_invoke(self) -> Any:

        logger.info(
            "Serializing the function return and uploading to %s",
-           s3_path_join(self.func_upload_path, RESULTS_FOLDER),
+           s3_path_join(self.results_upload_path, RESULTS_FOLDER),
        )
        serialization.serialize_obj_to_s3(
            obj=result,

src/sagemaker/serve/utils/types.py

Lines changed: 3 additions & 0 deletions

@@ -40,3 +40,6 @@ def __str__(self) -> str:

    CPU = 1
    GPU = 2
+   INFERENTIA_1 = 3
+   INFERENTIA_2 = 4
+   GRAVITON = 5

src/sagemaker/serve/validations/check_image_and_hardware_type.py

Lines changed: 31 additions & 1 deletion

@@ -17,6 +17,21 @@
}


+INF1_INSTANCE_FAMILIES = {"ml.inf1"}
+INF2_INSTANCE_FAMILIES = {"ml.inf2"}
+
+GRAVITON_INSTANCE_FAMILIES = {
+    "ml.c7g",
+    "ml.m6g",
+    "ml.m6gd",
+    "ml.c6g",
+    "ml.c6gd",
+    "ml.c6gn",
+    "ml.r6g",
+    "ml.r6gd",
+}
+
+
def validate_image_uri_and_hardware(image_uri: str, instance_type: str, model_server: ModelServer):
    """Placeholder docstring"""
    if "xgboost" in image_uri:
@@ -57,6 +72,12 @@ def detect_hardware_type_of_instance(instance_type: str) -> HardwareType:
    instance_family = instance_type.rsplit(".", 1)[0]
    if instance_family in GPU_INSTANCE_FAMILIES:
        return HardwareType.GPU
+   if instance_family in INF1_INSTANCE_FAMILIES:
+       return HardwareType.INFERENTIA_1
+   if instance_family in INF2_INSTANCE_FAMILIES:
+       return HardwareType.INFERENTIA_2
+   if instance_family in GRAVITON_INSTANCE_FAMILIES:
+       return HardwareType.GRAVITON
    return HardwareType.CPU


@@ -67,4 +88,13 @@ def detect_triton_image_hardware_type(image_uri: str) -> HardwareType:

def detect_torchserve_image_hardware_type(image_uri: str) -> HardwareType:
    """Placeholder docstring"""
-   return HardwareType.CPU if "cpu" in image_uri else HardwareType.GPU
+   if "neuronx" in image_uri:
+       return HardwareType.INFERENTIA_2
+   if "neuron" in image_uri:
+       return HardwareType.INFERENTIA_1
+   if "graviton" in image_uri:
+       return HardwareType.GRAVITON
+   if "cpu" in image_uri:
+       return HardwareType.CPU
+
+   return HardwareType.GPU
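
A side note on the new substring checks: the order matters, since "neuronx" contains "neuron" and must therefore be tested first, or every Inferentia2 image would be classified as Inferentia1. The standalone sketch below simply restates that ordering with made-up image URIs; it is an illustration, not the SDK function.

def sketch_detect_torchserve_hardware(image_uri: str) -> str:
    """Illustrative restatement of the substring checks added above."""
    # "neuronx" is a superstring of "neuron", so it has to be checked first.
    if "neuronx" in image_uri:
        return "INFERENTIA_2"
    if "neuron" in image_uri:
        return "INFERENTIA_1"
    if "graviton" in image_uri:
        return "GRAVITON"
    if "cpu" in image_uri:
        return "CPU"
    return "GPU"

# Made-up URIs, for illustration only.
assert sketch_detect_torchserve_hardware("pytorch-inference-neuronx:1.13") == "INFERENTIA_2"
assert sketch_detect_torchserve_hardware("pytorch-inference-graviton:2.0-cpu") == "GRAVITON"
assert sketch_detect_torchserve_hardware("pytorch-inference:2.0-gpu") == "GPU"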

src/sagemaker/workflow/pipeline.py

Lines changed: 73 additions & 33 deletions

@@ -24,7 +24,7 @@
import pytz
from botocore.exceptions import ClientError, WaiterError

-from sagemaker import s3
+from sagemaker import s3, LocalSession
from sagemaker._studio import _append_project_tags
from sagemaker.config import PIPELINE_ROLE_ARN_PATH, PIPELINE_TAGS_PATH
from sagemaker.remote_function.core.serialization import deserialize_obj_from_s3
@@ -973,41 +973,81 @@ def result(self, step_name: str):
        except WaiterError as e:
            if "Waiter encountered a terminal failure state" not in str(e):
                raise
-       step = next(filter(lambda x: x["StepName"] == step_name, self.list_steps()), None)
-       if not step:
-           raise ValueError(f"Invalid step name {step_name}")
-       step_type = next(iter(step["Metadata"]))
-       step_metadata = next(iter(step["Metadata"].values()))
-       if step_type != "TrainingJob":
-           raise ValueError(
-               "This method can only be used on pipeline steps created using " "@step decorator."
-           )

-       job_arn = step_metadata["Arn"]
-       job_name = job_arn.split("/")[-1]
+       return get_function_step_result(
+           step_name=step_name,
+           step_list=self.list_steps(),
+           execution_id=self.arn.split("/")[-1],
+           sagemaker_session=self.sagemaker_session,
+       )

-       describe_training_job_response = self.sagemaker_session.describe_training_job(job_name)
-       container_args = describe_training_job_response["AlgorithmSpecification"][
-           "ContainerEntrypoint"
-       ]
-       if container_args != JOBS_CONTAINER_ENTRYPOINT:
-           raise ValueError(
-               "This method can only be used on pipeline steps created using @step decorator."
-           )
-       s3_output_path = describe_training_job_response["OutputDataConfig"]["S3OutputPath"]

-       job_status = describe_training_job_response["TrainingJobStatus"]
-       if job_status == "Completed":
-           return deserialize_obj_from_s3(
-               sagemaker_session=self.sagemaker_session,
-               s3_uri=s3_path_join(
-                   s3_output_path, self.arn.split("/")[-1], step_name, RESULTS_FOLDER
-               ),
-               hmac_key=describe_training_job_response["Environment"][
-                   "REMOTE_FUNCTION_SECRET_KEY"
-               ],
-           )
-       raise RemoteFunctionError(f"Pipeline step {step_name} is in {job_status} status.")
+def get_function_step_result(
+   step_name: str,
+   step_list: list,
+   execution_id: str,
+   sagemaker_session: Session,
+):
+   """Helper function to retrieve the output of a ``@step`` decorated function.
+
+   Args:
+       step_name (str): The name of the pipeline step.
+       step_list (list): A list of executed pipeline steps of the specified execution.
+       execution_id (str): The specified id of the pipeline execution.
+       sagemaker_session (Session): Session object which manages interactions
+           with Amazon SageMaker APIs and any other AWS services needed.
+   Returns:
+       The step output.
+
+   Raises:
+       ValueError if the provided step is not a ``@step`` decorated function.
+       RemoteFunctionError if the provided step is not in "Completed" status
+   """
+   _ERROR_MSG_OF_WRONG_STEP_TYPE = (
+       "This method can only be used on pipeline steps created using @step decorator."
+   )
+   _ERROR_MSG_OF_STEP_INCOMPLETE = (
+       f"Unable to retrieve step output as the step {step_name} is not in Completed status."
+   )
+
+   step = next(filter(lambda x: x["StepName"] == step_name, step_list), None)
+   if not step:
+       raise ValueError(f"Invalid step name {step_name}")
+
+   if isinstance(sagemaker_session, LocalSession) and not step.get("Metadata", None):
+       # In local mode, if the training job failed,
+       # it's not tracked in LocalSagemakerClient and it's not describable.
+       # Thus, the step Metadata is not set.
+       raise RuntimeError(_ERROR_MSG_OF_STEP_INCOMPLETE)
+
+   step_type = next(iter(step["Metadata"]))
+   step_metadata = next(iter(step["Metadata"].values()))
+   if step_type != "TrainingJob":
+       raise ValueError(_ERROR_MSG_OF_WRONG_STEP_TYPE)
+
+   job_arn = step_metadata["Arn"]
+   job_name = job_arn.split("/")[-1]
+
+   if isinstance(sagemaker_session, LocalSession):
+       describe_training_job_response = sagemaker_session.sagemaker_client.describe_training_job(
+           job_name
+       )
+   else:
+       describe_training_job_response = sagemaker_session.describe_training_job(job_name)
+   container_args = describe_training_job_response["AlgorithmSpecification"]["ContainerEntrypoint"]
+   if container_args != JOBS_CONTAINER_ENTRYPOINT:
+       raise ValueError(_ERROR_MSG_OF_WRONG_STEP_TYPE)
+   s3_output_path = describe_training_job_response["OutputDataConfig"]["S3OutputPath"]
+
+   job_status = describe_training_job_response["TrainingJobStatus"]
+   if job_status == "Completed":
+       return deserialize_obj_from_s3(
+           sagemaker_session=sagemaker_session,
+           s3_uri=s3_path_join(s3_output_path, execution_id, step_name, RESULTS_FOLDER),
+           hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"],
+       )
+
+   raise RemoteFunctionError(_ERROR_MSG_OF_STEP_INCOMPLETE)


class PipelineGraph:
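
Because get_function_step_result is now a module-level helper in sagemaker.workflow.pipeline, it can in principle be called directly against an existing execution, as in the hedged sketch below (the execution ARN and step name are placeholders; in normal use you would simply call execution.result(step_name)):

from sagemaker import Session
from sagemaker.workflow.pipeline import get_function_step_result

session = Session()
execution_arn = (
    "arn:aws:sagemaker:us-west-2:111111111111:"
    "pipeline/my-pipeline/execution/abc123"  # placeholder ARN
)

# Same shape as _PipelineExecution.list_steps(): the raw PipelineExecutionSteps list
# returned by the ListPipelineExecutionSteps API.
step_list = session.sagemaker_client.list_pipeline_execution_steps(
    PipelineExecutionArn=execution_arn
)["PipelineExecutionSteps"]

output = get_function_step_result(
    step_name="my-step",  # placeholder step name
    step_list=step_list,
    execution_id=execution_arn.split("/")[-1],
    sagemaker_session=session,
)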

tests/integ/sagemaker/workflow/test_step_decorator.py

Lines changed: 1 addition & 1 deletion

@@ -592,7 +592,7 @@ def divide(x, y):
        step_name = execution_steps[0]["StepName"]
        with pytest.raises(RemoteFunctionError) as e:
            execution.result(step_name)
-       assert f"Pipeline step {step_name} is in Failed status." in str(e)
+       assert f"step {step_name} is not in Completed status." in str(e)
    finally:
        try:
            pipeline.delete()

tests/integ/test_local_mode.py

Lines changed: 5 additions & 1 deletion

@@ -24,6 +24,7 @@
import stopit

import tests.integ.lock as lock
+from sagemaker.workflow.step_outputs import get_step
from tests.integ.sagemaker.conftest import _build_container, DOCKERFILE_TEMPLATE
from sagemaker.config import SESSION_DEFAULT_S3_BUCKET_PATH
from sagemaker.utils import resolve_value_from_config
@@ -769,7 +770,7 @@ def test_local_pipeline_with_step_decorator_and_step_dependency(
    )

    @step(**step_settings)
-   def generator() -> tuple:
+   def generator():
        return 3, 4

    @step(**step_settings)
@@ -798,6 +799,9 @@ def sum(a, b):
    pipeline_execution_list_steps_result = execution.list_steps()
    assert len(pipeline_execution_list_steps_result["PipelineExecutionSteps"]) == 2

+   assert execution.result(step_name=get_step(step_output_a).name) == (3, 4)
+   assert execution.result(step_name=get_step(step_output_b).name) == 7
+

@pytest.mark.local_mode
def test_local_pipeline_with_step_decorator_and_pre_exe_script(
