Skip to content

Commit f08be97

Browse files
gwang111EC2 Default User
andauthored
Fix: Updated js mb compression logic - ModelBuilder (#4294)
Co-authored-by: EC2 Default User <[email protected]>
1 parent d756d4d commit f08be97

File tree

8 files changed

+568
-29
lines changed

8 files changed

+568
-29
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def _is_jumpstart_model_id(self) -> bool:
9595
def _create_pre_trained_js_model(self) -> Type[Model]:
9696
"""Placeholder docstring"""
9797
pysdk_model = JumpStartModel(self.model)
98+
pysdk_model.sagemaker_session = self.sagemaker_session
9899

99100
self._original_deploy = pysdk_model.deploy
100101
pysdk_model.deploy = self._js_builder_deploy_wrapper

src/sagemaker/serve/builder/tgi_builder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ def _create_tgi_model(self) -> Type[Model]:
133133
logger.info("Auto detected %s. Proceeding with the the deployment.", self.image_uri)
134134

135135
pysdk_model = HuggingFaceModel(
136-
image_uri=self.image_uri, env=self.env_vars, role=self.role_arn
136+
image_uri=self.image_uri,
137+
env=self.env_vars,
138+
role=self.role_arn,
139+
sagemaker_session=self.sagemaker_session,
137140
)
138141

139142
self._original_deploy = pysdk_model.deploy

src/sagemaker/serve/model_server/djl_serving/prepare.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414

1515
from __future__ import absolute_import
1616
import shutil
17-
import tarfile
18-
import subprocess
1917
import json
18+
import tarfile
2019
import logging
2120
from typing import List
2221
from pathlib import Path
2322

2423
from sagemaker.utils import _tmpdir
24+
from sagemaker.s3 import S3Downloader
2525
from sagemaker.djl_inference import DJLModel
2626
from sagemaker.djl_inference.model import _read_existing_serving_properties
2727
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
@@ -34,27 +34,57 @@
3434

3535

3636
def _has_serving_properties_file(code_dir: Path) -> bool:
37-
"""Placeholder Docstring"""
37+
"""Check for existing serving properties in the directory"""
3838
return code_dir.joinpath(_SERVING_PROPERTIES_FILE).is_file()
3939

4040

41-
def _members(resources: object, depth: int):
42-
"""Placeholder Docstring"""
43-
for member in resources.getmembers():
44-
member.path = member.path.split("/", depth)[-1]
45-
yield member
41+
def _move_to_code_dir(js_model_dir: str, code_dir: Path):
42+
"""Move DJL Jumpstart resources from model to code_dir"""
43+
js_model_resources = Path(js_model_dir).joinpath("model")
44+
for resource in js_model_resources.glob("*"):
45+
try:
46+
shutil.move(resource, code_dir)
47+
except shutil.Error as e:
48+
if "already exists" in str(e):
49+
continue
50+
51+
52+
def _extract_js_resource(js_model_dir: str, js_id: str):
53+
"""Uncompress the jumpstart resource"""
54+
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
55+
with tarfile.open(str(tmp_sourcedir)) as resources:
56+
resources.extractall(path=js_model_dir)
4657

4758

4859
def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path):
49-
"""Placeholder Docstring"""
60+
"""Copy the associated JumpStart Resource into the code directory"""
5061
logger.info("Downloading JumpStart artifacts from S3...")
51-
with _tmpdir(directory=str(code_dir)) as js_model_dir:
52-
subprocess.run(["aws", "s3", "cp", model_data, js_model_dir])
5362

54-
logger.info("Uncompressing JumpStart artifacts for faster loading...")
55-
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
56-
with tarfile.open(str(tmp_sourcedir)) as resources:
57-
resources.extractall(path=code_dir, members=_members(resources, 1))
63+
s3_downloader = S3Downloader()
64+
invalid_model_data_format = False
65+
with _tmpdir(directory=str(code_dir)) as js_model_dir:
66+
if isinstance(model_data, str):
67+
if model_data.endswith(".tar.gz"):
68+
logger.info("Uncompressing JumpStart artifacts for faster loading...")
69+
s3_downloader.download(model_data, js_model_dir)
70+
_extract_js_resource(js_model_dir, js_id)
71+
else:
72+
logger.info("Copying uncompressed JumpStart artifacts...")
73+
s3_downloader.download(model_data, js_model_dir)
74+
elif (
75+
isinstance(model_data, dict)
76+
and model_data.get("S3DataSource")
77+
and model_data.get("S3DataSource").get("S3Uri")
78+
):
79+
logger.info("Copying uncompressed JumpStart artifacts...")
80+
s3_downloader.download(model_data.get("S3DataSource").get("S3Uri"), js_model_dir)
81+
else:
82+
invalid_model_data_format = True
83+
if not invalid_model_data_format:
84+
_move_to_code_dir(js_model_dir, code_dir)
85+
86+
if invalid_model_data_format:
87+
raise ValueError("JumpStart model data compression format is unsupported: %s", model_data)
5888

5989
existing_properties = _read_existing_serving_properties(code_dir)
6090
config_json_file = code_dir.joinpath("config.json")
@@ -70,7 +100,7 @@ def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path):
70100
def _generate_properties_file(
71101
model: DJLModel, code_dir: Path, overwrite_props_from_file: bool, manual_set_props: dict
72102
):
73-
"""Placeholder Docstring"""
103+
"""Construct serving properties file taking into account of overrides or manual specs"""
74104
if _has_serving_properties_file(code_dir):
75105
existing_properties = _read_existing_serving_properties(code_dir)
76106
else:

src/sagemaker/serve/model_server/tgi/prepare.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,66 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
113
"""Prepare TgiModel for Deployment"""
214

315
from __future__ import absolute_import
416
import tarfile
5-
import subprocess
617
import logging
718
from typing import List
819
from pathlib import Path
920

1021
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
1122
from sagemaker.utils import _tmpdir
23+
from sagemaker.s3 import S3Downloader
1224

1325
logger = logging.getLogger(__name__)
1426

1527

28+
def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
29+
"""Uncompress the jumpstart resource"""
30+
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
31+
with tarfile.open(str(tmp_sourcedir)) as resources:
32+
resources.extractall(path=code_dir)
33+
34+
1635
def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:
17-
"""Placeholder Docstring"""
36+
"""Copy the associated JumpStart Resource into the code directory"""
1837
logger.info("Downloading JumpStart artifacts from S3...")
19-
with _tmpdir(directory=str(code_dir)) as js_model_dir:
20-
js_model_data_loc = model_data.get("S3DataSource").get("S3Uri")
21-
# TODO: leave this check here until we are sure every js model has moved to uncompressed
22-
if js_model_data_loc.endswith("tar.gz"):
23-
subprocess.run(["aws", "s3", "cp", js_model_data_loc, js_model_dir])
38+
39+
s3_downloader = S3Downloader()
40+
if isinstance(model_data, str):
41+
if model_data.endswith(".tar.gz"):
2442
logger.info("Uncompressing JumpStart artifacts for faster loading...")
25-
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
26-
with tarfile.open(str(tmp_sourcedir)) as resources:
27-
resources.extractall(path=code_dir)
43+
with _tmpdir(directory=str(code_dir)) as js_model_dir:
44+
s3_downloader.download(model_data, js_model_dir)
45+
_extract_js_resource(js_model_dir, code_dir, js_id)
2846
else:
29-
subprocess.run(["aws", "s3", "cp", js_model_data_loc, js_model_dir, "--recursive"])
47+
logger.info("Copying uncompressed JumpStart artifacts...")
48+
s3_downloader.download(model_data, code_dir)
49+
elif (
50+
isinstance(model_data, dict)
51+
and model_data.get("S3DataSource")
52+
and model_data.get("S3DataSource").get("S3Uri")
53+
):
54+
logger.info("Copying uncompressed JumpStart artifacts...")
55+
s3_downloader.download(model_data.get("S3DataSource").get("S3Uri"), code_dir)
56+
else:
57+
raise ValueError("JumpStart model data compression format is unsupported: %s", model_data)
58+
3059
return True
3160

3261

3362
def _create_dir_structure(model_path: str) -> tuple:
34-
"""Placeholder Docstring"""
63+
"""Create the expected model directory structure for the TGI server"""
3564
model_path = Path(model_path)
3665
if not model_path.exists():
3766
model_path.mkdir(parents=True)

tests/unit/sagemaker/serve/builder/test_djl_builder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,12 @@ def test_build_deploy_for_djl_local_container(
114114
mode=Mode.LOCAL_CONTAINER,
115115
model_server=ModelServer.DJL_SERVING,
116116
)
117+
117118
builder._prepare_for_mode = MagicMock()
118119
builder._prepare_for_mode.side_effect = None
119120

120121
model = builder.build()
122+
builder.serve_settings.telemetry_opt_out = True
121123

122124
assert isinstance(model, HuggingFaceAccelerateModel)
123125
assert (
@@ -176,6 +178,7 @@ def test_build_for_djl_local_container_faster_transformer(
176178
model_server=ModelServer.DJL_SERVING,
177179
)
178180
model = builder.build()
181+
builder.serve_settings.telemetry_opt_out = True
179182

180183
assert isinstance(model, FasterTransformerModel)
181184
assert (
@@ -211,6 +214,7 @@ def test_build_for_djl_local_container_deepspeed(
211214
model_server=ModelServer.DJL_SERVING,
212215
)
213216
model = builder.build()
217+
builder.serve_settings.telemetry_opt_out = True
214218

215219
assert isinstance(model, DeepSpeedModel)
216220
assert model.generate_serving_properties() == mock_expected_deepspeed_serving_properties
@@ -268,6 +272,7 @@ def test_tune_for_djl_local_container(
268272
builder._djl_model_builder_deploy_wrapper = MagicMock()
269273

270274
model = builder.build()
275+
builder.serve_settings.telemetry_opt_out = True
271276
tuned_model = model.tune()
272277
assert tuned_model.generate_serving_properties() == mock_most_performant_serving_properties
273278

@@ -317,6 +322,7 @@ def test_tune_for_djl_local_container_deep_ping_ex(
317322
builder._prepare_for_mode.side_effect = None
318323

319324
model = builder.build()
325+
builder.serve_settings.telemetry_opt_out = True
320326
tuned_model = model.tune()
321327
assert (
322328
tuned_model.generate_serving_properties()
@@ -369,6 +375,7 @@ def test_tune_for_djl_local_container_load_ex(
369375
builder._prepare_for_mode.side_effect = None
370376

371377
model = builder.build()
378+
builder.serve_settings.telemetry_opt_out = True
372379
tuned_model = model.tune()
373380
assert (
374381
tuned_model.generate_serving_properties()
@@ -421,6 +428,7 @@ def test_tune_for_djl_local_container_oom_ex(
421428
builder._prepare_for_mode.side_effect = None
422429

423430
model = builder.build()
431+
builder.serve_settings.telemetry_opt_out = True
424432
tuned_model = model.tune()
425433
assert (
426434
tuned_model.generate_serving_properties()
@@ -473,6 +481,7 @@ def test_tune_for_djl_local_container_invoke_ex(
473481
builder._prepare_for_mode.side_effect = None
474482

475483
model = builder.build()
484+
builder.serve_settings.telemetry_opt_out = True
476485
tuned_model = model.tune()
477486
assert (
478487
tuned_model.generate_serving_properties()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
MOCK_MODEL_PATH = "/path/to/mock/model/dir"
16+
MOCK_CODE_DIR = "/path/to/mock/model/dir/code"
17+
MOCK_JUMPSTART_ID = "mock_llm_js_id"
18+
MOCK_TMP_DIR = "tmp123456"
19+
MOCK_COMPRESSED_MODEL_DATA_STR = (
20+
"s3://jumpstart-cache/to/infer-prepack-huggingface-llm-falcon-7b-bf16.tar.gz"
21+
)
22+
MOCK_UNCOMPRESSED_MODEL_DATA_STR = "s3://jumpstart-cache/to/artifacts/inference-prepack/v1.0.1/"
23+
MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT = (
24+
"s3://jumpstart-cache/to/artifacts/inference-prepack/v1.0.1/dict/"
25+
)
26+
MOCK_UNCOMPRESSED_MODEL_DATA_DICT = {
27+
"S3DataSource": {
28+
"S3Uri": MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT,
29+
"S3DataType": "S3Prefix",
30+
"CompressionType": "None",
31+
}
32+
}
33+
MOCK_INVALID_MODEL_DATA_DICT = {}

0 commit comments

Comments
 (0)