Skip to content

Commit 35e6c47

Browse files
committed
Generate container definition for deepspeed models
1 parent 3c0a4f4 commit 35e6c47

File tree

3 files changed

+90
-62
lines changed

3 files changed

+90
-62
lines changed

src/sagemaker/djl_inference/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
from sagemaker.djl_inference.model import DJLLargeModel, DJLLargeModelPredictor, DeepSpeedModel, HuggingfaceAccelerateModel
16+
from sagemaker.djl_inference.model import (
17+
DJLLargeModel,
18+
DJLLargeModelPredictor,
19+
DeepSpeedModel,
20+
HuggingfaceAccelerateModel,
21+
)

src/sagemaker/djl_inference/defaults.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"bloom",
1818
"opt",
1919
"gpt_neox",
20-
#"gptj",
20+
# "gptj",
2121
"gpt_neo",
2222
"gpt2",
2323
"xlm-roberta",
@@ -35,4 +35,4 @@
3535
"xlm-roberta",
3636
"roberta",
3737
"bert",
38-
}
38+
}

src/sagemaker/djl_inference/model.py

Lines changed: 82 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@
1515

1616
import json
1717
import os.path
18-
import tempfile
1918
from enum import Enum
2019
from typing import Optional, Union, Dict
2120

21+
import defaults
2222
import sagemaker
2323
from sagemaker import s3, Predictor, image_uris, fw_utils
2424
from sagemaker.deserializers import JSONDeserializer
2525
from sagemaker.model import FrameworkModel
2626
from sagemaker.serializers import JSONSerializer
2727
from sagemaker.session import Session
28-
import defaults
28+
from sagemaker.utils import _tmpdir, _create_or_update_code_dir
2929
from sagemaker.workflow.entities import PipelineVariable
3030

3131

@@ -36,7 +36,6 @@ class DJLEngine(Enum):
3636

3737

3838
class DJLLargeModelPredictor(Predictor):
39-
4039
def __init__(
4140
self,
4241
endpoint_name,
@@ -53,7 +52,6 @@ def __init__(
5352

5453

5554
class DJLLargeModel(FrameworkModel):
56-
5755
def __new__(
5856
cls,
5957
uncompressed_model_data: str,
@@ -62,6 +60,11 @@ def __new__(
6260
):
6361
if not uncompressed_model_data.startswith("s3://"):
6462
raise ValueError("DJLLargeModel only supports loading model artifacts from s3")
63+
if uncompressed_model_data.endswith("tar.gz"):
64+
raise ValueError(
65+
"DJLLargeModel does not support model artifacts in tar.gz format."
66+
"Please store the model in uncompressed format and provide the s3 uri of the folder"
67+
)
6568
if uncompressed_model_data.endswith("/"):
6669
config_file = uncompressed_model_data + "config.json"
6770
else:
@@ -88,7 +91,7 @@ def __init__(
8891
entry_point: Optional[str] = None,
8992
image_uri: Optional[Union[str, PipelineVariable]] = None,
9093
predictor_cls: callable = DJLLargeModelPredictor,
91-
**kwargs
94+
**kwargs,
9295
):
9396
self.uncompressed_model_data = uncompressed_model_data
9497
self.djl_version = djl_version
@@ -102,7 +105,7 @@ def __init__(
102105
self.model_loading_timeout = model_loading_timeout
103106
self.prediction_timeout = prediction_timeout
104107
super(DJLLargeModel, self).__init__(
105-
None, image_uri, role, entry_point, predictor_cls=predictor_cls,**kwargs
108+
None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
106109
)
107110
self.sagemaker_session = self.sagemaker_session or Session()
108111

@@ -122,14 +125,28 @@ def _determine_engine_for_model_type(model_type: str):
122125
return DeepSpeedModel
123126
return HuggingfaceAccelerateModel
124127

128+
125129
def _validate_engine_for_model_type(model_type: str, engine: DJLEngine):
126130
if engine == DJLEngine.DEEPSPEED:
127131
if model_type not in defaults.DEEPSPEED_SUPPORTED_ARCHITECTURES:
128-
raise ValueError(f"{model_type} is not supported by DeepSpeed. " \
129-
f"Supported model_types are {defaults.DEEPSPEED_SUPPORTED_ARCHITECTURES}")
132+
raise ValueError(
133+
f"{model_type} is not supported by DeepSpeed. "
134+
f"Supported model_types are {defaults.DEEPSPEED_SUPPORTED_ARCHITECTURES}"
135+
)
136+
137+
138+
def _read_existing_serving_properties(directory: str):
139+
serving_properties_path = os.path.join(directory, "serving.properties")
140+
properties = {}
141+
if os.path.exists(serving_properties_path):
142+
with open(serving_properties_path, "r") as f:
143+
for line in f:
144+
key, val = line.split("=", 1)
145+
properties[key] = val
146+
return properties
130147

131-
class DeepSpeedModel(DJLLargeModel):
132148

149+
class DeepSpeedModel(DJLLargeModel):
133150
_framework_name = "djl-deepspeed"
134151

135152
def __init__(
@@ -139,8 +156,8 @@ def __init__(
139156
low_cpu_mem_usage: bool = True,
140157
enable_cuda_graph: bool = False,
141158
triangular_masking: bool = True,
142-
return_tuple = True,
143-
deepspeed_checkpoint_file = None,
159+
return_tuple=True,
160+
deepspeed_checkpoint_file=None,
144161
task: str = None,
145162
data_type: str = None,
146163
tensor_parallel_degree: int = None,
@@ -185,19 +202,52 @@ def prepare_container_def(
185202
deploy_image = self.serving_image_uri(region_name)
186203

187204
print(f"Deploy image is{deploy_image}")
188-
tmp_dir = self._validate_and_write_serving_properties()
189-
deploy_key_prefix = fw_utils.model_code_key_prefix(self.key_prefix, self.name, deploy_image)
190-
bucket = self.bucket or self.sagemaker_session.default_bucket()
191-
print(f"bucket to upload code to is {bucket}")
192-
# self.uploaded_code = fw_utils.tar_and_upload_dir(
193-
# session=self.sagemaker_session.boto_session,
194-
# bucket=bucket,
195-
# s3_key_prefix=deploy_key_prefix,
196-
# directory=tmp_dir,
197-
# script=None,
198-
# )
199-
200-
def _validate_and_write_serving_properties(self):
205+
local_download_dir = (
206+
None
207+
if self.sagemaker_session.settings is None
208+
or self.sagemaker_session.settings.local_download_dir is None
209+
else self.sagemaker_session.settings.local_download_dir
210+
)
211+
with _tmpdir(directory=local_download_dir) as tmp:
212+
# Check to see if we need to bundle user provided code with serving.properties and re upload
213+
if self.source_dir or self.entry_point:
214+
_create_or_update_code_dir(
215+
tmp,
216+
self.entry_point,
217+
self.source_dir,
218+
self.dependencies,
219+
self.sagemaker_session,
220+
tmp,
221+
)
222+
existing_serving_properties = _read_existing_serving_properties(tmp)
223+
provided_serving_properties = self._generate_serving_properties()
224+
# provided kwargs take precedence over existing serving.properties file
225+
existing_serving_properties.update(provided_serving_properties)
226+
# self._validate_serving_properties(existing_serving_properties)
227+
228+
with open(os.path.join(tmp, "serving.properties"), "w+") as f:
229+
for key, val in existing_serving_properties.items():
230+
f.write(f"{key}={val}\n")
231+
232+
deploy_key_prefix = fw_utils.model_code_key_prefix(
233+
self.key_prefix, self.name, deploy_image
234+
)
235+
bucket = self.bucket or self.sagemaker_session.default_bucket()
236+
print(f"bucket to upload code to is {bucket}")
237+
uploaded_code = fw_utils.tar_and_upload_dir(
238+
self.sagemaker_session.boto_session,
239+
bucket,
240+
deploy_key_prefix,
241+
self.entry_point,
242+
directory=tmp,
243+
dependencies=self.dependencies,
244+
kms_key=self.model_kms_key,
245+
)
246+
return sagemaker.container_def(
247+
deploy_image, model_data_url=uploaded_code.s3_prefix, env=self.env
248+
)
249+
250+
def _generate_serving_properties(self):
201251
serving_properties = {
202252
"engine": "DeepSpeed",
203253
"option.entryPoint": "djl_python.deepspeed",
@@ -209,7 +259,9 @@ def _validate_and_write_serving_properties(self):
209259
serving_properties["option.low_cpu_mem_usage"] = self.low_cpu_mem_usage
210260
if self.enable_cuda_graph:
211261
if self.tensor_parallel_degree > 1:
212-
raise ValueError("enable_cuda_graph is not supported when tensor_parallel_degree > 1")
262+
raise ValueError(
263+
"enable_cuda_graph is not supported when tensor_parallel_degree > 1"
264+
)
213265
serving_properties["option.enable_cuda_graph"] = self.enable_cuda_graph
214266
if self.triangular_masking:
215267
serving_properties["option.triangular_masking"] = self.triangular_masking
@@ -238,20 +290,10 @@ def _validate_and_write_serving_properties(self):
238290
if self.prediction_timeout:
239291
serving_properties["option.prediction_timeout"] = self.prediction_timeout
240292

241-
local_dir = None if self.sagemaker_session.settings else self.sagemaker_session.settings.local_download_dir
242-
tmp_dir = tempfile.mkdtemp(dir=local_dir)
243-
244-
with open(os.path.join(tmp_dir, "serving.properties"), 'w+') as f:
245-
for key, value in serving_properties.items():
246-
f.write(f"{key}={value}\n")
247-
248-
print(f"wrote serving.properties to {tmp_dir}")
249-
250-
return tmp_dir
293+
return serving_properties
251294

252295

253296
class HuggingfaceAccelerateModel(DJLLargeModel):
254-
255297
_framework_name = "djl-deepspeed"
256298

257299
def __init__(
@@ -268,12 +310,12 @@ def __init__(
268310
entry_point: str = None,
269311
image_uri: Optional[Union[str, PipelineVariable]] = None,
270312
predictor_cls: callable = DJLLargeModelPredictor,
271-
**kwargs
313+
**kwargs,
272314
):
273315
self.device_id = device_id
274316
self.device_map = device_map
275-
self.load_in_8bit = load_in_8bit,
276-
self.low_cpu_mem_usage = low_cpu_mem_usage,
317+
self.load_in_8bit = (load_in_8bit,)
318+
self.low_cpu_mem_usage = (low_cpu_mem_usage,)
277319
super(HuggingfaceAccelerateModel, self).__init__(
278320
uncompressed_model_data,
279321
role=role,
@@ -283,24 +325,5 @@ def __init__(
283325
entry_point=entry_point,
284326
image_uri=image_uri,
285327
predictor_cls=predictor_cls,
286-
**kwargs
328+
**kwargs,
287329
)
288-
289-
if __name__ == "__main__":
290-
session = Session()
291-
role = "arn:aws:iam::125045733377:role/AmazonSageMaker-ExecutionRole-djl"
292-
opt_model = DJLLargeModel(
293-
"s3://dlc-deepspeed-test-temp/opt-2.7b/",
294-
tensor_parallel_degree=2,
295-
data_type="fp32",
296-
task="text-generation",
297-
max_tokens=2048,
298-
parallel_loading=True,
299-
role=role,
300-
sagemaker_session=session,
301-
)
302-
opt_model.prepare_container_def()
303-
# opt_model.deploy(
304-
# initial_instance_count=1,
305-
# instance_type="ml.g5.12xl"
306-
# )

0 commit comments

Comments
 (0)