Skip to content

Commit 091f6c6

Browse files
Ao Guo authored and Namrata Madan committed
Add error handling for remote function.
1 parent ecd5008 commit 091f6c6

File tree

7 files changed

+560
-33
lines changed

7 files changed

+560
-33
lines changed

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ Jinja2==3.0.3
2222
pandas>=1.3.5,<1.5
2323
scikit-learn==1.0.2
2424
cloudpickle==2.2.0
25+
tblib==1.7.0

src/sagemaker/remote_function/client.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import inspect
2323

2424
from botocore.exceptions import ClientError
25+
from sagemaker.exceptions import UnexpectedStatusException
2526

2627
import sagemaker.remote_function.core.serialization as serialization
28+
from sagemaker.remote_function.errors import RemoteFunctionError, ServiceError
2729

2830
from sagemaker.session import Session
2931
from sagemaker.s3 import s3_path_join
@@ -136,18 +138,47 @@ def wrapper(*args, **kwargs):
136138
volume_size=volume_size,
137139
)
138140
job = _Job.start(job_settings, func, args, kwargs)
139-
job.wait()
141+
142+
try:
143+
job.wait()
144+
except UnexpectedStatusException as usex:
145+
if usex.actual_status == "Failed":
146+
try:
147+
exception = serialization.deserialize_obj_from_s3(
148+
sagemaker_session=job_settings.sagemaker_session,
149+
s3_uri=s3_path_join(
150+
job_settings.s3_root_uri, job.job_name, "exception.pkl"
151+
),
152+
)
153+
except ServiceError as serr:
154+
chained_e = serr.__cause__
155+
if (
156+
isinstance(chained_e, ClientError)
157+
and chained_e.response["Error"]["Code"] # pylint: disable=no-member
158+
== "404"
159+
and chained_e.response["Error"]["Message"] # pylint: disable=no-member
160+
== "Not Found"
161+
):
162+
raise RemoteFunctionError(
163+
"Failed to execute remote function. "
164+
+ "Check corresponding job for details."
165+
)
166+
raise serr
167+
168+
raise exception
169+
170+
raise TimeoutError(
171+
"Job for remote function timed out before reaching a termination status."
172+
)
140173

141174
if job.describe()["TrainingJobStatus"] == "Completed":
142175
return serialization.deserialize_obj_from_s3(
143176
sagemaker_session=job_settings.sagemaker_session,
144177
s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, "results.pkl"),
145178
)
146179

147-
if job.describe()["TrainingJobStatus"] == "Failed":
148-
# TODO: the exception should be constructed based on the job failure reason
149-
# and function output
150-
raise RuntimeError("Job failed")
180+
if job.describe()["TrainingJobStatus"] == "Stopped":
181+
raise RemoteFunctionError("Job for remote function has been aborted.")
151182

152183
return None
153184

@@ -522,7 +553,10 @@ def result(self, timeout: int = None) -> Any:
522553
Returns:
523554
The Python object returned by the function
524555
"""
525-
self.wait(timeout)
556+
try:
557+
self.wait(timeout)
558+
except UnexpectedStatusException:
559+
pass
526560

527561
with self._condition:
528562
if self._state == _PENDING:
@@ -539,13 +573,36 @@ def result(self, timeout: int = None) -> Any:
539573
self._state = _FINISHED
540574
return self._return
541575
if self._job.describe()["TrainingJobStatus"] == "Failed":
542-
# TODO: the exception should be constructed based on the job failure reason
543-
# and function output
544-
self._exception = RuntimeError()
576+
try:
577+
self._exception = serialization.deserialize_obj_from_s3(
578+
sagemaker_session=self._job_settings.sagemaker_session,
579+
s3_uri=s3_path_join(
580+
self._job_settings.s3_root_uri, self._job.job_name, "exception.pkl"
581+
),
582+
)
583+
except ServiceError as serr:
584+
chained_e = serr.__cause__
585+
if (
586+
isinstance(chained_e, ClientError)
587+
and chained_e.response["Error"]["Code"] # pylint: disable=no-member
588+
== "404"
589+
and chained_e.response["Error"]["Message"] # pylint: disable=no-member
590+
== "Not Found"
591+
):
592+
self._exception = RemoteFunctionError(
593+
"Failed to execute remote function. "
594+
+ "Check corresponding job for details."
595+
)
596+
else:
597+
self._exception = serr
545598
self._state = _FINISHED
546-
raise self._exception
547-
548-
raise RuntimeError()
599+
elif self._job.describe()["TrainingJobStatus"] == "Stopped":
600+
self._state = _CANCELLED
601+
raise RemoteFunctionError("Job for remote function has been aborted.")
602+
else:
603+
raise TimeoutError(
604+
"Job for remote function timed out before reaching a termination status."
605+
)
549606

550607
if self._state == _FINISHED:
551608
if self._exception:

src/sagemaker/remote_function/core/serialization.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import cloudpickle
1717

1818
from typing import Any, Callable
19+
from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
1920
from sagemaker.s3 import S3Downloader, S3Uploader
2021

2122

@@ -32,10 +33,19 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=N
3233
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
3334
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
3435
func: function to be serialized and persisted
36+
Raises:
37+
SerializationError: when fail to serialize function to bytes.
3538
"""
36-
S3Uploader.upload_bytes(
37-
cloudpickle.dumps(func), s3_uri, kms_key=s3_kms_key, sagemaker_session=sagemaker_session
38-
)
39+
try:
40+
bytes_to_upload = cloudpickle.dumps(func)
41+
except Exception as e:
42+
raise SerializationError(
43+
"Error when serializing function [{}]: {}".format(
44+
getattr(func, "__name__", repr(func)), e
45+
)
46+
) from e
47+
48+
_upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
3949

4050

4151
def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
@@ -50,8 +60,17 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
5060
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
5161
Returns :
5262
The deserialized function.
63+
Raises:
64+
DeserializationError: when failing to deserialize the downloaded bytes into a function.
5365
"""
54-
return cloudpickle.loads(S3Downloader.read_bytes(s3_uri, sagemaker_session=sagemaker_session))
66+
bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
67+
68+
try:
69+
return cloudpickle.loads(bytes_to_deserialize)
70+
except Exception as e:
71+
raise DeserializationError(
72+
"Error when deserializing bytes downloaded from {}: {}".format(s3_uri, e)
73+
) from e
5574

5675

5776
def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
@@ -63,10 +82,17 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: st
6382
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
6483
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
6584
obj: object to be serialized and persisted
85+
Raises:
86+
SerializationError: when fail to serialize object to bytes.
6687
"""
67-
S3Uploader.upload_bytes(
68-
cloudpickle.dumps(obj), s3_uri, kms_key=s3_kms_key, sagemaker_session=sagemaker_session
69-
)
88+
try:
89+
bytes_to_upload = cloudpickle.dumps(obj)
90+
except Exception as e:
91+
raise SerializationError(
92+
"Error when serializing object of type [{}]: {}".format(type(obj).__name__, e)
93+
) from e
94+
95+
_upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
7096

7197

7298
def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -78,5 +104,32 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
78104
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
79105
Returns :
80106
List of deserialized python objects.
107+
Raises:
108+
DeserializationError: when failing to deserialize the downloaded bytes into an object.
81109
"""
82-
return cloudpickle.loads(S3Downloader.read_bytes(s3_uri, sagemaker_session=sagemaker_session))
110+
bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
111+
112+
try:
113+
return cloudpickle.loads(bytes_to_deserialize)
114+
except Exception as e:
115+
raise DeserializationError(
116+
"Error when deserializing bytes downloaded from {}: {}".format(s3_uri, e)
117+
) from e
118+
119+
120+
def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
    """Upload serialized bytes to S3, translating any failure into ServiceError.

    Args:
        bytes: serialized payload handed to ``S3Uploader.upload_bytes``.
        s3_uri (str): S3 uri the payload is uploaded to.
        s3_kms_key (str): KMS key forwarded to the uploader as ``kms_key``.
        sagemaker_session: session forwarded to the uploader.
    Raises:
        ServiceError: when the upload raises any exception; the original
            exception is chained as the cause.
    """
    failure_template = "Failed to upload serialized bytes to {}: {}"
    upload_options = {"kms_key": s3_kms_key, "sagemaker_session": sagemaker_session}

    try:
        S3Uploader.upload_bytes(bytes, s3_uri, **upload_options)
    except Exception as upload_err:
        # Re-raise as the remote-function error type so callers only need to
        # handle ServiceError, while keeping the root cause attached.
        raise ServiceError(failure_template.format(s3_uri, upload_err)) from upload_err
128+
129+
130+
def _read_bytes_from_s3(s3_uri, sagemaker_session):
    """Download the bytes stored at ``s3_uri``, translating any failure into ServiceError.

    Args:
        s3_uri (str): S3 uri to read from.
        sagemaker_session: session forwarded to ``S3Downloader.read_bytes``.
    Returns:
        Whatever ``S3Downloader.read_bytes`` returns for the given uri.
    Raises:
        ServiceError: when the download raises any exception; the original
            exception is chained as the cause.
    """
    failure_template = "Failed to read serialized bytes from {}: {}"

    try:
        downloaded = S3Downloader.read_bytes(s3_uri, sagemaker_session=sagemaker_session)
    except Exception as download_err:
        # Re-raise as the remote-function error type so callers only need to
        # handle ServiceError, while keeping the root cause attached.
        raise ServiceError(failure_template.format(s3_uri, download_err)) from download_err
    return downloaded

src/sagemaker/remote_function/errors.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
FAILURE_REASON_PATH = "/opt/ml/output/failure"
2525

2626

27+
@pickling_support.install
2728
class RemoteFunctionError(Exception):
2829
"""The base exception class for remote function exceptions"""
2930

@@ -32,18 +33,22 @@ def __init__(self, message):
3233
super().__init__(self.message)
3334

3435

36+
@pickling_support.install
3537
class ServiceError(RemoteFunctionError):
3638
"""Raised when errors encountered during interaction with SageMaker, S3 service APIs"""
3739

3840

41+
@pickling_support.install
3942
class RuntimeEnvironmentError(RemoteFunctionError):
4043
"""Raised when errors encountered during remote function runtime environment setup"""
4144

4245

46+
@pickling_support.install
4347
class SerializationError(RemoteFunctionError):
4448
"""Raised when errors encountered during serialization of remote function objects"""
4549

4650

51+
@pickling_support.install
4752
class DeserializationError(RemoteFunctionError):
4853
"""Raised when errors encountered during deserialization of remote function objects"""
4954

@@ -80,7 +85,7 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int:
8085
s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded.
8186
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
8287
Returns :
83-
List of deserialized python objects.
88+
exit_code (int): Exit code to terminate current job.
8489
"""
8590
pickling_support.install()
8691
failure_reason = repr(error)

src/sagemaker/remote_function/job_driver.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626

2727
SUCCESS_EXIT_CODE = 0
28+
DEFAULT_FAILURE_CODE = 1
2829

2930

3031
def _parse_agrs():
@@ -85,7 +86,7 @@ def main():
8586
logging_config.basic_config()
8687
logger = logging_config.get_logger()
8788

88-
exit_code = SUCCESS_EXIT_CODE
89+
exit_code = DEFAULT_FAILURE_CODE
8990
try:
9091
args = _parse_agrs()
9192
region = args.region
@@ -109,7 +110,7 @@ def main():
109110
)
110111

111112
_execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key)
112-
113+
exit_code = SUCCESS_EXIT_CODE
113114
except Exception as e: # pylint: disable=broad-except
114115
logger.exception("Error encountered when invoking the remote function.")
115116
exit_code = handle_error(e, sagemaker_session, s3_base_uri, s3_kms_key)

0 commit comments

Comments
 (0)