Skip to content

Commit 3ed0011

Browse files
SSRraymondRaymond Liu
andauthored
fix: mitigation of xgboost container incompatibility with new version (#4298)
Co-authored-by: Raymond Liu <[email protected]>
1 parent 81ca3a0 commit 3ed0011

File tree

7 files changed

+213
-14
lines changed

7 files changed

+213
-14
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ def _build_for_torchserve(self) -> Type[Model]:
514514
shared_libs=self.shared_libs,
515515
dependencies=self.dependencies,
516516
session=self.sagemaker_session,
517+
image_uri=self.image_uri,
517518
inference_spec=self.inference_spec,
518519
)
519520

src/sagemaker/serve/detector/dependency_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool =
5454

5555
with open(path, "r") as f:
5656
autodetect_depedencies = f.read().splitlines()
57-
autodetect_depedencies.append("sagemaker")
57+
autodetect_depedencies.append("sagemaker>=2.199")
5858
else:
59-
autodetect_depedencies = ["sagemaker"]
59+
autodetect_depedencies = ["sagemaker>=2.199"]
6060

6161
module_version_dict = _parse_dependency_list(autodetect_depedencies)
6262

src/sagemaker/serve/model_server/torchserve/prepare.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
from __future__ import absolute_import
7+
import os
78
from pathlib import Path
89
import shutil
910
from typing import List
@@ -15,7 +16,7 @@
1516
generate_secret_key,
1617
compute_hash,
1718
)
18-
19+
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
1920
from sagemaker.remote_function.core.serialization import _MetaData
2021

2122

@@ -24,6 +25,7 @@ def prepare_for_torchserve(
2425
shared_libs: List[str],
2526
dependencies: dict,
2627
session: Session,
28+
image_uri: str,
2729
inference_spec: InferenceSpec = None,
2830
) -> str:
2931
"""This is a one-line summary of the function.
@@ -51,8 +53,14 @@ def prepare_for_torchserve(
5153

5254
code_dir = model_path.joinpath("code")
5355
code_dir.mkdir(exist_ok=True)
54-
55-
shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)
56+
# https://github.com/aws/sagemaker-python-sdk/issues/4288
57+
if is_1p_image_uri(image_uri=image_uri) and "xgboost" in image_uri:
58+
shutil.copy2(Path(__file__).parent.joinpath("xgboost_inference.py"), code_dir)
59+
os.rename(
60+
str(code_dir.joinpath("xgboost_inference.py")), str(code_dir.joinpath("inference.py"))
61+
)
62+
else:
63+
shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)
5664

5765
shared_libs_dir = model_path.joinpath("shared_libs")
5866
shared_libs_dir.mkdir(exist_ok=True)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""This module is for SageMaker inference.py."""
2+
from __future__ import absolute_import
3+
import os
4+
import io
5+
import subprocess
6+
import cloudpickle
7+
import shutil
8+
import platform
9+
from pathlib import Path
10+
from functools import partial
11+
import logging
12+
13+
logger = logging.getLogger(__name__)
14+
15+
inference_spec = None
16+
native_model = None
17+
schema_builder = None
18+
19+
20+
def model_fn(model_dir):
21+
"""Placeholder docstring"""
22+
from sagemaker.serve.spec.inference_spec import InferenceSpec
23+
from sagemaker.serve.detector.image_detector import (
24+
_detect_framework_and_version,
25+
_get_model_base,
26+
)
27+
from sagemaker.serve.detector.pickler import load_xgboost_from_json
28+
29+
shared_libs_path = Path(model_dir + "/shared_libs")
30+
31+
if shared_libs_path.exists():
32+
# before importing, place dynamic linked libraries in shared lib path
33+
shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True)
34+
35+
serve_path = Path(__file__).parent.joinpath("serve.pkl")
36+
with open(str(serve_path), mode="rb") as file:
37+
global inference_spec, native_model, schema_builder
38+
obj = cloudpickle.load(file)
39+
if isinstance(obj[0], InferenceSpec):
40+
inference_spec, schema_builder = obj
41+
elif isinstance(obj[0], str) and obj[0] == "xgboost":
42+
model_class_name = os.getenv("MODEL_CLASS_NAME")
43+
model_save_path = Path(__file__).parent.joinpath("model.json")
44+
native_model = load_xgboost_from_json(
45+
model_save_path=str(model_save_path), class_name=model_class_name
46+
)
47+
schema_builder = obj[1]
48+
else:
49+
native_model, schema_builder = obj
50+
if native_model:
51+
framework, _ = _detect_framework_and_version(
52+
model_base=str(_get_model_base(model=native_model))
53+
)
54+
if framework == "pytorch":
55+
native_model.eval()
56+
return native_model if callable(native_model) else native_model.predict
57+
elif inference_spec:
58+
return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
59+
60+
61+
def input_fn(input_data, content_type):
62+
"""Placeholder docstring"""
63+
try:
64+
if hasattr(schema_builder, "custom_input_translator"):
65+
return schema_builder.custom_input_translator.deserialize(
66+
io.BytesIO(input_data), content_type
67+
)
68+
else:
69+
return schema_builder.input_deserializer.deserialize(
70+
io.BytesIO(input_data), content_type[0]
71+
)
72+
except Exception as e:
73+
raise Exception("Encountered error in deserialize_request.") from e
74+
75+
76+
def predict_fn(input_data, predict_callable):
77+
"""Placeholder docstring"""
78+
return predict_callable(input_data)
79+
80+
81+
def output_fn(predictions, accept_type):
82+
"""Placeholder docstring"""
83+
try:
84+
if hasattr(schema_builder, "custom_output_translator"):
85+
return schema_builder.custom_output_translator.serialize(predictions, accept_type)
86+
else:
87+
return schema_builder.output_serializer.serialize(predictions)
88+
except Exception as e:
89+
logger.error("Encountered error: %s in serialize_response." % e)
90+
raise Exception("Encountered error in serialize_response.") from e
91+
92+
93+
def _run_preflight_diagnostics():
94+
install_package("sagemaker")
95+
install_package("boto3", "1.17.52")
96+
_py_vs_parity_check()
97+
_pickle_file_integrity_check()
98+
99+
100+
def _py_vs_parity_check():
101+
container_py_vs = platform.python_version()
102+
local_py_vs = os.getenv("LOCAL_PYTHON")
103+
104+
if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
105+
logger.warning(
106+
f"The local python version {local_py_vs} differs from the python version "
107+
f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
108+
)
109+
110+
111+
def _pickle_file_integrity_check():
112+
from sagemaker.serve.validations.check_integrity import perform_integrity_check
113+
114+
with open("/opt/ml/model/code/serve.pkl", "rb") as f:
115+
buffer = f.read()
116+
117+
metadeata_path = Path("/opt/ml/model/code/metadata.json")
118+
perform_integrity_check(buffer=buffer, metadata_path=metadeata_path)
119+
120+
121+
def install_package(package_name, version=None):
122+
"""Placeholder docstring"""
123+
if version:
124+
command = f"pip install {package_name}=={version}"
125+
else:
126+
command = f"pip install {package_name}"
127+
128+
try:
129+
subprocess.check_call(command, shell=True)
130+
print(f"Successfully installed {package_name} using install_package")
131+
except subprocess.CalledProcessError as e:
132+
print(f"Failed to install {package_name}. Error: {e}")
133+
134+
135+
# on import, execute
136+
_run_preflight_diagnostics()

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
149149
mock_detect_fw_version.return_value = framework, version
150150

151151
mock_prepare_for_torchserve.side_effect = (
152-
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
152+
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
153153
if model_path == MODEL_PATH
154154
and shared_libs == []
155155
and dependencies == {"auto": False}
156156
and session == session
157+
and image_uri == mock_image_uri
157158
and inference_spec is None
158159
else None
159160
)
@@ -248,11 +249,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
248249
mock_detect_fw_version.return_value = framework, version
249250

250251
mock_prepare_for_torchserve.side_effect = (
251-
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
252+
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
252253
if model_path == MODEL_PATH
253254
and shared_libs == []
254255
and dependencies == {"auto": False}
255256
and session == session
257+
and image_uri == mock_1p_dlc_image_uri
256258
and inference_spec is None
257259
else None
258260
)
@@ -352,11 +354,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
352354
)
353355

354356
mock_prepare_for_torchserve.side_effect = (
355-
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
357+
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
356358
if model_path == MODEL_PATH
357359
and shared_libs == []
358360
and dependencies == {"auto": False}
359361
and session == mock_session
362+
and image_uri == mock_image_uri
360363
and inference_spec == mock_inference_spec
361364
else None
362365
)
@@ -447,11 +450,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
447450
mock_detect_fw_version.return_value = framework, version
448451

449452
mock_prepare_for_torchserve.side_effect = (
450-
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
453+
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
451454
if model_path == MODEL_PATH
452455
and shared_libs == []
453456
and dependencies == {"auto": False}
454457
and session == session
458+
and image_uri == mock_image_uri
455459
and inference_spec is None
456460
else None
457461
)
@@ -550,11 +554,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
550554
mock_detect_fw_version.return_value = "xgboost", version
551555

552556
mock_prepare_for_torchserve.side_effect = (
553-
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
557+
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
554558
if model_path == MODEL_PATH
555559
and shared_libs == []
556560
and dependencies == {"auto": False}
557561
and session == session
562+
and image_uri == mock_image_uri
558563
and inference_spec is None
559564
else None
560565
)
@@ -655,11 +660,12 @@ def test_build_happy_path_with_local_container_mode(
655660
)
656661

657662
mock_prepare_for_torchserve.side_effect = (
658-
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
663+
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
659664
if model_path == MODEL_PATH
660665
and shared_libs == []
661666
and dependencies == {"auto": False}
662667
and session == mock_session
668+
and image_uri == mock_image_uri
663669
and inference_spec == mock_inference_spec
664670
else None
665671
)
@@ -752,11 +758,12 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
752758
)
753759

754760
mock_prepare_for_torchserve.side_effect = (
755-
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
761+
lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key
756762
if model_path == MODEL_PATH
757763
and shared_libs == []
758764
and dependencies == {"auto": False}
759765
and session == mock_session
766+
and image_uri == mock_image_uri
760767
and inference_spec == mock_inference_spec
761768
else None
762769
)
@@ -887,12 +894,13 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
887894
)
888895

889896
mock_prepare_for_torchserve.side_effect = (
890-
lambda model_path, shared_libs, dependencies, session, inference_spec: mock_secret_key
897+
lambda model_path, shared_libs, dependencies, image_uri, session, inference_spec: mock_secret_key
891898
if model_path == MODEL_PATH
892899
and shared_libs == []
893900
and dependencies == {"auto": False}
894901
and session == mock_session
895902
and inference_spec is None
903+
and image_uri == mock_image_uri
896904
else None
897905
)
898906

tests/unit/sagemaker/serve/detector/test_dependency_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_capture_dependencies(self, mock_subprocess, mock_file, mock_path):
9999
call("custom_module==1.2.3\n"),
100100
call("numpy==4.5\n"),
101101
call("boto3=1.28.*\n"),
102-
call("sagemaker\n"),
102+
call("sagemaker>=2.199\n"),
103103
call("other_module@http://some/website.whl\n"),
104104
]
105105
mocked_writes.assert_has_calls(expected_calls)

tests/unit/sagemaker/serve/model_server/torchserve/test_prepare.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
SHARED_LIBS = ["/path/to/shared/libs"]
2222
DEPENDENCIES = "dependencies"
2323
INFERENCE_SPEC = Mock()
24+
IMAGE_URI = "mock_image_uri"
25+
XGB_1P_IMAGE_URI = "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.7-1"
2426
INFERENCE_SPEC.prepare = Mock(return_value=None)
2527

2628
SECRET_KEY = "secret-key"
@@ -29,6 +31,9 @@
2931

3032

3133
class PrepareForTorchServeTests(TestCase):
34+
def setUp(self):
35+
INFERENCE_SPEC.reset_mock()
36+
3237
@patch("builtins.open", new_callable=mock_open, read_data=b"{}")
3338
@patch("sagemaker.serve.model_server.torchserve.prepare._MetaData")
3439
@patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash")
@@ -58,9 +63,50 @@ def test_prepare_happy(
5863
shared_libs=SHARED_LIBS,
5964
dependencies=DEPENDENCIES,
6065
session=mock_session,
66+
image_uri=IMAGE_URI,
67+
inference_spec=INFERENCE_SPEC,
68+
)
69+
70+
mock_path_instance.mkdir.assert_not_called()
71+
INFERENCE_SPEC.prepare.assert_called_once()
72+
self.assertEqual(secret_key, SECRET_KEY)
73+
74+
@patch("os.rename")
75+
@patch("builtins.open", new_callable=mock_open, read_data=b"{}")
76+
@patch("sagemaker.serve.model_server.torchserve.prepare._MetaData")
77+
@patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash")
78+
@patch("sagemaker.serve.model_server.torchserve.prepare.generate_secret_key")
79+
@patch("sagemaker.serve.model_server.torchserve.prepare.capture_dependencies")
80+
@patch("sagemaker.serve.model_server.torchserve.prepare.shutil")
81+
@patch("sagemaker.serve.model_server.torchserve.prepare.Path")
82+
def test_prepare_happy_xgboost(
83+
self,
84+
mock_path,
85+
mock_shutil,
86+
mock_capture_dependencies,
87+
mock_generate_secret_key,
88+
mock_compute_hash,
89+
mock_metadata,
90+
mock_open,
91+
mock_rename,
92+
):
93+
94+
mock_path_instance = mock_path.return_value
95+
mock_path_instance.exists.return_value = True
96+
mock_path_instance.joinpath.return_value = Mock()
97+
98+
mock_generate_secret_key.return_value = SECRET_KEY
99+
100+
secret_key = prepare_for_torchserve(
101+
model_path=MODEL_PATH,
102+
shared_libs=SHARED_LIBS,
103+
dependencies=DEPENDENCIES,
104+
session=mock_session,
105+
image_uri=XGB_1P_IMAGE_URI,
61106
inference_spec=INFERENCE_SPEC,
62107
)
63108

109+
mock_rename.assert_called_once()
64110
mock_path_instance.mkdir.assert_not_called()
65111
INFERENCE_SPEC.prepare.assert_called_once()
66112
self.assertEqual(secret_key, SECRET_KEY)

0 commit comments

Comments
 (0)