Skip to content

feature: handle use case where endpoint is created outside of python … #3867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/sagemaker/async_inference/async_inference_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,31 @@ def get_result(
return self._result

def _get_result_from_s3(self, output_path, failure_path):
"""Retrieve output based on the presense of failure_path"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lgtm - synced with Praneeth offline as well; we'll track a backlog item to go back to SageMaker Python SDK again and add Integ tests for edge case where customers created Async endpoint without failure path and loaded in Python SDK for invocation usage.

if failure_path is not None:
return self._get_result_from_s3_output_failure_paths(output_path, failure_path)

return self._get_result_from_s3_output_path(output_path)

def _get_result_from_s3_output_path(self, output_path):
"""Get inference result from the output Amazon S3 path"""
bucket, key = parse_s3_url(output_path)
try:
response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
return self.predictor_async.predictor._handle_response(response)
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
raise ObjectNotExistedError(
message="Inference could still be running",
output_path=output_path,
)
raise UnexpectedClientError(
message=ex.response["Error"]["Message"],
)

def _get_result_from_s3_output_failure_paths(self, output_path, failure_path):
"""Get inference result from the output & failure Amazon S3 path"""
bucket, key = parse_s3_url(output_path)
try:
response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
return self.predictor_async.predictor._handle_response(response)
Expand Down
31 changes: 29 additions & 2 deletions src/sagemaker/predictor_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def predict(
self._input_path = input_path
response = self._submit_async_request(input_path, initial_args, inference_id)
output_location = response["OutputLocation"]
failure_location = response["FailureLocation"]
failure_location = response.get("FailureLocation")
result = self._wait_for_output(
output_path=output_location, failure_path=failure_location, waiter_config=waiter_config
)
Expand Down Expand Up @@ -145,7 +145,7 @@ def predict_async(
self._input_path = input_path
response = self._submit_async_request(input_path, initial_args, inference_id)
output_location = response["OutputLocation"]
failure_location = response["FailureLocation"]
failure_location = response.get("FailureLocation")
response_async = AsyncInferenceResponse(
predictor_async=self,
output_path=output_location,
Expand Down Expand Up @@ -216,6 +216,33 @@ def _submit_async_request(
return response

def _wait_for_output(self, output_path, failure_path, waiter_config):
"""Retrieve output based on the presense of failure_path."""
if failure_path is not None:
return self._check_output_and_failure_paths(output_path, failure_path, waiter_config)

return self._check_output_path(output_path, waiter_config)

def _check_output_path(self, output_path, waiter_config):
"""Check the Amazon S3 output path for the output.

Periodically check Amazon S3 output path for async inference result.
Timeout automatically after max attempts reached
"""
bucket, key = parse_s3_url(output_path)
s3_waiter = self.s3_client.get_waiter("object_exists")
try:
s3_waiter.wait(Bucket=bucket, Key=key, WaiterConfig=waiter_config._to_request_dict())
except WaiterError:
raise PollingTimeoutError(
message="Inference could still be running",
output_path=output_path,
seconds=waiter_config.delay * waiter_config.max_attempts,
)
s3_object = self.s3_client.get_object(Bucket=bucket, Key=key)
result = self.predictor._handle_response(response=s3_object)
return result

def _check_output_and_failure_paths(self, output_path, failure_path, waiter_config):
"""Check the Amazon S3 output path for the output.

This method waits for either the output file or the failure file to be found on the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,32 @@ def empty_s3_client():
return s3_client


def empty_s3_client_to_verify_exceptions_for_null_failure_path():
"""
Returns a mocked S3 client with the `get_object` method overridden
to raise different exceptions based on the input.

Exceptions raised:
- `ObjectNotExistedError`
- `UnexpectedClientError`

"""
s3_client = Mock(name="s3-client")

object_error = ObjectNotExistedError("Inference could still be running", DEFAULT_OUTPUT_PATH)

unexpected_error = UnexpectedClientError("some error message")

s3_client.get_object = Mock(
name="get_object",
side_effect=[
object_error,
unexpected_error,
],
)
return s3_client


def mock_s3_client():
"""
This function returns a mocked S3 client object that has a get_object method with a side_effect
Expand Down Expand Up @@ -172,3 +198,47 @@ def test_get_result_verify_exceptions():
UnexpectedClientError, match="Encountered unexpected client error: some error message"
):
async_inference_response.get_result()


def test_get_result_with_null_failure_path():
"""
verifies that the result is returned correctly if no errors occur.
"""
# Initialize AsyncInferenceResponse
predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME))
predictor_async.s3_client = mock_s3_client()
async_inference_response = AsyncInferenceResponse(
output_path=DEFAULT_OUTPUT_PATH, predictor_async=predictor_async, failure_path=None
)

result = async_inference_response.get_result()
assert async_inference_response._result == result
assert result == RETURN_VALUE


def test_get_result_verify_exceptions_with_null_failure_path():
"""
Verifies that get_result method raises the expected exception
when an error occurs while fetching the result.
"""
# Initialize AsyncInferenceResponse
predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME))
predictor_async.s3_client = empty_s3_client_to_verify_exceptions_for_null_failure_path()
async_inference_response = AsyncInferenceResponse(
output_path=DEFAULT_OUTPUT_PATH,
predictor_async=predictor_async,
failure_path=None,
)

# Test ObjectNotExistedError
with pytest.raises(
ObjectNotExistedError,
match=f"Object not exist at {DEFAULT_OUTPUT_PATH}. Inference could still be running",
):
async_inference_response.get_result()

# Test UnexpectedClientError
with pytest.raises(
UnexpectedClientError, match="Encountered unexpected client error: some error message"
):
async_inference_response.get_result()
114 changes: 113 additions & 1 deletion tests/unit/test_predictor_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,37 @@ def empty_sagemaker_session():
return ims


def empty_sagemaker_session_with_null_failure_path():
ims = Mock(name="sagemaker_session")
ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
ims.sagemaker_runtime_client = Mock(name="sagemaker_runtime")
ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)

ims.sagemaker_runtime_client.invoke_endpoint_async = Mock(
name="invoke_endpoint_async",
return_value={
"OutputLocation": ASYNC_OUTPUT_LOCATION,
},
)

polling_timeout_error = PollingTimeoutError(
message="Inference could still be running",
output_path=ASYNC_OUTPUT_LOCATION,
seconds=DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts,
)

ims.s3_client = Mock(name="s3_client")
ims.s3_client.get_object = Mock(
name="get_object",
side_effect=[polling_timeout_error],
)

ims.s3_client.put_object = Mock(name="put_object")

return ims


def empty_predictor():
predictor = Mock(name="predictor")
predictor.update_endpoint = Mock(name="update_endpoint")
Expand Down Expand Up @@ -161,6 +192,31 @@ def test_async_predict_call_with_data_and_input_path():
assert result.failure_path == ASYNC_FAILURE_LOCATION


def test_async_predict_call_with_data_and_input_and_null_failure_path():
sagemaker_session = empty_sagemaker_session_with_null_failure_path()
predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))
predictor_async.name = ASYNC_PREDICTOR
data = DUMMY_DATA

result = predictor_async.predict_async(data=data, input_path=ASYNC_INPUT_LOCATION)
assert sagemaker_session.s3_client.put_object.called

assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called
assert sagemaker_session.sagemaker_client.describe_endpoint.not_called
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called

expected_request_args = {
"Accept": DEFAULT_ACCEPT,
"InputLocation": ASYNC_INPUT_LOCATION,
"EndpointName": ENDPOINT,
}

call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.call_args
assert kwargs == expected_request_args
assert result.output_path == ASYNC_OUTPUT_LOCATION
assert result.failure_path is None


def test_async_predict_call_verify_exceptions():
sagemaker_session = empty_sagemaker_session()
predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))
Expand All @@ -185,7 +241,27 @@ def test_async_predict_call_verify_exceptions():
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called


def test_async_predict_call_pass_through_success():
def test_async_predict_call_verify_exceptions_with_null_failure_path():
sagemaker_session = empty_sagemaker_session_with_null_failure_path()
predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))

input_location = "s3://some-input-path"

with pytest.raises(
PollingTimeoutError,
match=f"No result at {ASYNC_OUTPUT_LOCATION} after polling for "
f"{DEFAULT_WAITER_CONFIG.delay*DEFAULT_WAITER_CONFIG.max_attempts}"
f" seconds. Inference could still be running",
):
predictor_async.predict(input_path=input_location, waiter_config=DEFAULT_WAITER_CONFIG)

assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called
assert sagemaker_session.s3_client.get_object.called
assert sagemaker_session.sagemaker_client.describe_endpoint.not_called
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called


def test_async_predict_call_pass_through_output_failure_paths():
sagemaker_session = empty_sagemaker_session()

response_body = Mock("body")
Expand Down Expand Up @@ -222,6 +298,42 @@ def test_async_predict_call_pass_through_success():
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called


def test_async_predict_call_pass_through_with_null_failure_path():
sagemaker_session = empty_sagemaker_session_with_null_failure_path()

response_body = Mock("body")
response_body.read = Mock("read", return_value=RETURN_VALUE)
response_body.close = Mock("close", return_value=None)

sagemaker_session.s3_client = Mock(name="s3_client")
sagemaker_session.s3_client.get_object = Mock(
name="get_object",
return_value={"Body": response_body},
)
sagemaker_session.s3_client.put_object = Mock(name="put_object")

predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))

sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async = Mock(
name="invoke_endpoint_async",
return_value={
"OutputLocation": ASYNC_OUTPUT_LOCATION,
},
)

input_location = "s3://some-input-path"

result = predictor_async.predict(
input_path=input_location,
)

assert result == RETURN_VALUE
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called
assert sagemaker_session.s3_client.get_waiter.called_with("object_exists")
assert sagemaker_session.sagemaker_client.describe_endpoint.not_called
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called


def test_predict_async_call_invalid_input():
sagemaker_session = empty_sagemaker_session()
predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))
Expand Down