Skip to content

fix: mitigation of xgboost container incompatibility with new version #4298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ def _build_for_torchserve(self) -> Type[Model]:
shared_libs=self.shared_libs,
dependencies=self.dependencies,
session=self.sagemaker_session,
image_uri=self.image_uri,
inference_spec=self.inference_spec,
)

Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/serve/detector/dependency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool =

with open(path, "r") as f:
autodetect_depedencies = f.read().splitlines()
autodetect_depedencies.append("sagemaker")
autodetect_depedencies.append("sagemaker>=2.199")
else:
autodetect_depedencies = ["sagemaker"]
autodetect_depedencies = ["sagemaker>=2.199"]

module_version_dict = _parse_dependency_list(autodetect_depedencies)

Expand Down
14 changes: 11 additions & 3 deletions src/sagemaker/serve/model_server/torchserve/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

from __future__ import absolute_import
import os
from pathlib import Path
import shutil
from typing import List
Expand All @@ -15,7 +16,7 @@
generate_secret_key,
compute_hash,
)

from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
from sagemaker.remote_function.core.serialization import _MetaData


Expand All @@ -24,6 +25,7 @@ def prepare_for_torchserve(
shared_libs: List[str],
dependencies: dict,
session: Session,
image_uri: str,
inference_spec: InferenceSpec = None,
) -> str:
"""This is a one-line summary of the function.
Expand Down Expand Up @@ -51,8 +53,14 @@ def prepare_for_torchserve(

code_dir = model_path.joinpath("code")
code_dir.mkdir(exist_ok=True)

shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)
# https://github.com/aws/sagemaker-python-sdk/issues/4288
if is_1p_image_uri(image_uri=image_uri) and "xgboost" in image_uri:
shutil.copy2(Path(__file__).parent.joinpath("xgboost_inference.py"), code_dir)
os.rename(
str(code_dir.joinpath("xgboost_inference.py")), str(code_dir.joinpath("inference.py"))
)
else:
shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)

shared_libs_dir = model_path.joinpath("shared_libs")
shared_libs_dir.mkdir(exist_ok=True)
Expand Down
136 changes: 136 additions & 0 deletions src/sagemaker/serve/model_server/torchserve/xgboost_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""This module is for SageMaker inference.py."""
from __future__ import absolute_import
import os
import io
import subprocess
import cloudpickle
import shutil
import platform
from pathlib import Path
from functools import partial
import logging

logger = logging.getLogger(__name__)

inference_spec = None
native_model = None
schema_builder = None


def model_fn(model_dir):
    """Load the serialized model from ``serve.pkl`` and return a predict callable.

    The pickle is expected to hold a 2-tuple that is one of:
      * ``(InferenceSpec, schema_builder)`` — a user-supplied inference spec;
      * ``("xgboost", schema_builder)`` — marker telling us to rebuild the
        XGBoost model from the sidecar ``model.json`` (mitigation for the
        container incompatibility in aws/sagemaker-python-sdk#4288);
      * ``(model, schema_builder)`` — a directly pickled model object.

    Side effect: populates the module globals ``inference_spec``,
    ``native_model`` and ``schema_builder`` used by input_fn/output_fn.
    Returns None if neither a native model nor an inference spec was loaded.
    """
    from sagemaker.serve.spec.inference_spec import InferenceSpec
    from sagemaker.serve.detector.image_detector import (
        _detect_framework_and_version,
        _get_model_base,
    )
    from sagemaker.serve.detector.pickler import load_xgboost_from_json

    shared_libs_path = Path(model_dir + "/shared_libs")

    if shared_libs_path.exists():
        # before importing, place dynamic linked libraries in shared lib path
        shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True)

    serve_path = Path(__file__).parent.joinpath("serve.pkl")
    with open(str(serve_path), mode="rb") as file:
        global inference_spec, native_model, schema_builder
        obj = cloudpickle.load(file)
        if isinstance(obj[0], InferenceSpec):
            inference_spec, schema_builder = obj
        elif isinstance(obj[0], str) and obj[0] == "xgboost":
            # XGBoost path: the model itself was saved as JSON next to this
            # file; MODEL_CLASS_NAME tells the loader which class to rebuild.
            model_class_name = os.getenv("MODEL_CLASS_NAME")
            model_save_path = Path(__file__).parent.joinpath("model.json")
            native_model = load_xgboost_from_json(
                model_save_path=str(model_save_path), class_name=model_class_name
            )
            schema_builder = obj[1]
        else:
            native_model, schema_builder = obj
    if native_model:
        framework, _ = _detect_framework_and_version(
            model_base=str(_get_model_base(model=native_model))
        )
        if framework == "pytorch":
            # PyTorch modules must be switched to eval mode before inference.
            native_model.eval()
        # Non-callable models are assumed to expose a .predict method.
        return native_model if callable(native_model) else native_model.predict
    elif inference_spec:
        return partial(inference_spec.invoke, model=inference_spec.load(model_dir))


def input_fn(input_data, content_type):
    """Deserialize the raw request body via the module-level schema builder."""
    try:
        stream = io.BytesIO(input_data)
        # A user-provided translator, when present, takes precedence.
        if hasattr(schema_builder, "custom_input_translator"):
            return schema_builder.custom_input_translator.deserialize(stream, content_type)
        # Default path: the deserializer receives only the first content-type
        # entry — assumes content_type is indexable here; TODO confirm.
        return schema_builder.input_deserializer.deserialize(stream, content_type[0])
    except Exception as e:
        raise Exception("Encountered error in deserialize_request.") from e


def predict_fn(input_data, predict_callable):
    """Apply the predictor returned by model_fn to the deserialized input."""
    result = predict_callable(input_data)
    return result


def output_fn(predictions, accept_type):
    """Serialize predictions into the response payload via the schema builder."""
    try:
        # Sentinel-based lookup: honor a custom translator even if it is
        # set to a falsy value (matches an attribute-existence check).
        _missing = object()
        translator = getattr(schema_builder, "custom_output_translator", _missing)
        if translator is not _missing:
            return translator.serialize(predictions, accept_type)
        return schema_builder.output_serializer.serialize(predictions)
    except Exception as e:
        logger.error("Encountered error: %s in serialize_response." % e)
        raise Exception("Encountered error in serialize_response.") from e


def _run_preflight_diagnostics():
    """Bootstrap the container: install required packages, then sanity-check.

    Runs at module import time (see bottom of file in the original layout).
    """
    # boto3 is pinned; sagemaker floats to the newest release.
    for package, pinned_version in (("sagemaker", None), ("boto3", "1.17.52")):
        install_package(package, pinned_version)
    _py_vs_parity_check()
    _pickle_file_integrity_check()


def _py_vs_parity_check():
container_py_vs = platform.python_version()
local_py_vs = os.getenv("LOCAL_PYTHON")

if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
logger.warning(
f"The local python version {local_py_vs} differs from the python version "
f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
)


def _pickle_file_integrity_check():
    """Verify the deployed serve.pkl against its recorded hash metadata."""
    from sagemaker.serve.validations.check_integrity import perform_integrity_check

    # Paths are fixed by the SageMaker model-artifact layout under /opt/ml.
    pickle_bytes = Path("/opt/ml/model/code/serve.pkl").read_bytes()
    metadata_path = Path("/opt/ml/model/code/metadata.json")
    perform_integrity_check(buffer=pickle_bytes, metadata_path=metadata_path)


def install_package(package_name, version=None):
    """Install a package with pip, optionally pinned to an exact version.

    Args:
        package_name: Name of the package to install.
        version: Optional exact version; when given the requirement becomes
            ``package_name==version``.

    Failures are reported to stdout and swallowed (best-effort install),
    matching the original behavior.
    """
    # Build an argument list and run with shell=False: the requirement
    # string is never interpreted by a shell, so unusual or untrusted
    # package names/versions cannot inject shell commands or be word-split.
    if version:
        requirement = f"{package_name}=={version}"
    else:
        requirement = package_name
    command = ["pip", "install", requirement]

    try:
        subprocess.check_call(command)
        print(f"Successfully installed {package_name} using install_package")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install {package_name}. Error: {e}")


# on import, execute
# The model server imports this handler module at worker start-up, so the
# preflight diagnostics (dependency install, Python-version parity check,
# pickle integrity check) run once before any request is served.
_run_preflight_diagnostics()
24 changes: 16 additions & 8 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
mock_detect_fw_version.return_value = framework, version

mock_prepare_for_torchserve.side_effect = (
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
if model_path == MODEL_PATH
and shared_libs == []
and dependencies == {"auto": False}
and session == session
and image_uri == mock_image_uri
and inference_spec is None
else None
)
Expand Down Expand Up @@ -248,11 +249,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
mock_detect_fw_version.return_value = framework, version

mock_prepare_for_torchserve.side_effect = (
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
if model_path == MODEL_PATH
and shared_libs == []
and dependencies == {"auto": False}
and session == session
and image_uri == mock_1p_dlc_image_uri
and inference_spec is None
else None
)
Expand Down Expand Up @@ -352,11 +354,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
)

mock_prepare_for_torchserve.side_effect = (
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
if model_path == MODEL_PATH
and shared_libs == []
and dependencies == {"auto": False}
and session == mock_session
and image_uri == mock_image_uri
and inference_spec == mock_inference_spec
else None
)
Expand Down Expand Up @@ -447,11 +450,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
mock_detect_fw_version.return_value = framework, version

mock_prepare_for_torchserve.side_effect = (
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
if model_path == MODEL_PATH
and shared_libs == []
and dependencies == {"auto": False}
and session == session
and image_uri == mock_image_uri
and inference_spec is None
else None
)
Expand Down Expand Up @@ -550,11 +554,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
mock_detect_fw_version.return_value = "xgboost", version

mock_prepare_for_torchserve.side_effect = (
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
if model_path == MODEL_PATH
and shared_libs == []
and dependencies == {"auto": False}
and session == session
and image_uri == mock_image_uri
and inference_spec is None
else None
)
Expand Down Expand Up @@ -655,11 +660,12 @@ def test_build_happy_path_with_local_container_mode(
)

mock_prepare_for_torchserve.side_effect = (
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
if model_path == MODEL_PATH
and shared_libs == []
and dependencies == {"auto": False}
and session == mock_session
and image_uri == mock_image_uri
and inference_spec == mock_inference_spec
else None
)
Expand Down Expand Up @@ -752,11 +758,12 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
)

mock_prepare_for_torchserve.side_effect = (
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
if model_path == MODEL_PATH
and shared_libs == []
and dependencies == {"auto": False}
and session == mock_session
and image_uri == mock_image_uri
and inference_spec == mock_inference_spec
else None
)
Expand Down Expand Up @@ -887,12 +894,13 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
)

mock_prepare_for_torchserve.side_effect = (
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
lambda model_path, shared_libs, dependencies, image_uri, session, inference_spec: mock_secret_key
if model_path == MODEL_PATH
and shared_libs == []
and dependencies == {"auto": False}
and session == mock_session
and inference_spec is None
and image_uri == mock_image_uri
else None
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_capture_dependencies(self, mock_subprocess, mock_file, mock_path):
call("custom_module==1.2.3\n"),
call("numpy==4.5\n"),
call("boto3=1.28.*\n"),
call("sagemaker\n"),
call("sagemaker>=2.199\n"),
call("other_module@http://some/website.whl\n"),
]
mocked_writes.assert_has_calls(expected_calls)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
SHARED_LIBS = ["/path/to/shared/libs"]
DEPENDENCIES = "dependencies"
INFERENCE_SPEC = Mock()
IMAGE_URI = "mock_image_uri"
XGB_1P_IMAGE_URI = "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.7-1"
INFERENCE_SPEC.prepare = Mock(return_value=None)

SECRET_KEY = "secret-key"
Expand All @@ -29,6 +31,9 @@


class PrepareForTorchServeTests(TestCase):
def setUp(self):
INFERENCE_SPEC.reset_mock()

@patch("builtins.open", new_callable=mock_open, read_data=b"{}")
@patch("sagemaker.serve.model_server.torchserve.prepare._MetaData")
@patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash")
Expand Down Expand Up @@ -58,9 +63,50 @@ def test_prepare_happy(
shared_libs=SHARED_LIBS,
dependencies=DEPENDENCIES,
session=mock_session,
image_uri=IMAGE_URI,
inference_spec=INFERENCE_SPEC,
)

mock_path_instance.mkdir.assert_not_called()
INFERENCE_SPEC.prepare.assert_called_once()
self.assertEqual(secret_key, SECRET_KEY)

@patch("os.rename")
@patch("builtins.open", new_callable=mock_open, read_data=b"{}")
@patch("sagemaker.serve.model_server.torchserve.prepare._MetaData")
@patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash")
@patch("sagemaker.serve.model_server.torchserve.prepare.generate_secret_key")
@patch("sagemaker.serve.model_server.torchserve.prepare.capture_dependencies")
@patch("sagemaker.serve.model_server.torchserve.prepare.shutil")
@patch("sagemaker.serve.model_server.torchserve.prepare.Path")
def test_prepare_happy_xgboost(
    self,
    # Mocks are injected bottom-up relative to the decorator stack above.
    mock_path,
    mock_shutil,
    mock_capture_dependencies,
    mock_generate_secret_key,
    mock_compute_hash,
    mock_metadata,
    mock_open,
    mock_rename,
):
    """prepare_for_torchserve with a 1P XGBoost image renames the handler.

    For an XGBoost first-party image URI the prepare step should copy
    xgboost_inference.py and os.rename it to inference.py (hence the
    mock_rename assertion), while otherwise behaving like the happy path.
    """
    mock_path_instance = mock_path.return_value
    mock_path_instance.exists.return_value = True
    mock_path_instance.joinpath.return_value = Mock()

    mock_generate_secret_key.return_value = SECRET_KEY

    secret_key = prepare_for_torchserve(
        model_path=MODEL_PATH,
        shared_libs=SHARED_LIBS,
        dependencies=DEPENDENCIES,
        session=mock_session,
        image_uri=XGB_1P_IMAGE_URI,
        inference_spec=INFERENCE_SPEC,
    )

    # The XGBoost-specific branch must rename the copied handler exactly once.
    mock_rename.assert_called_once()
    mock_path_instance.mkdir.assert_not_called()
    INFERENCE_SPEC.prepare.assert_called_once()
    self.assertEqual(secret_key, SECRET_KEY)