Skip to content

Commit 3c0a4f4

Browse files
committed
Generate serving.properties and upload code to s3
1 parent 30b1334 commit 3c0a4f4

File tree

3 files changed

+151
-3
lines changed

3 files changed

+151
-3
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.djl_inference.model import DJLLargeModel, DJLLargeModelPredictor, DeepSpeedModel, HuggingfaceAccelerateModel

src/sagemaker/djl_inference/defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"bloom",
1818
"opt",
1919
"gpt_neox",
20-
"gptj",
20+
#"gptj",
2121
"gpt_neo",
2222
"gpt2",
2323
"xlm-roberta",

src/sagemaker/djl_inference/model.py

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
from __future__ import absolute_import
1515

1616
import json
17+
import os.path
18+
import tempfile
1719
from enum import Enum
1820
from typing import Optional, Union, Dict
1921

20-
from sagemaker import s3, Predictor
22+
import sagemaker
23+
from sagemaker import s3, Predictor, image_uris, fw_utils
2124
from sagemaker.deserializers import JSONDeserializer
2225
from sagemaker.model import FrameworkModel
2326
from sagemaker.serializers import JSONSerializer
@@ -71,24 +74,49 @@ def __new__(
7174
def __init__(
    self,
    uncompressed_model_data: str,
    djl_version: str = None,
    task: str = None,
    data_type: str = None,
    tensor_parallel_degree: int = None,
    min_workers: int = None,
    max_workers: int = None,
    job_queue_size: int = None,
    parallel_loading: bool = False,
    model_loading_timeout: int = None,
    prediction_timeout: int = None,
    role: str = None,
    entry_point: Optional[str] = None,
    image_uri: Optional[Union[str, PipelineVariable]] = None,
    predictor_cls: callable = DJLLargeModelPredictor,
    **kwargs
):
    """Initialize a DJLLargeModel.

    Args:
        uncompressed_model_data: S3 URI prefix of the uncompressed model artifacts.
        djl_version: DJL serving container version to use; resolved to a default
            later if left as ``None``.
        task: inference task name passed through to serving.properties.
        data_type: dtype string (e.g. "fp32") passed through to serving.properties.
        tensor_parallel_degree: number of devices to shard the model across.
        min_workers: minimum number of DJL serving workers.
        max_workers: maximum number of DJL serving workers.
        job_queue_size: DJL serving request queue size.
        parallel_loading: whether workers load the model in parallel.
        model_loading_timeout: model load timeout, in seconds — TODO confirm unit.
        prediction_timeout: prediction timeout, in seconds — TODO confirm unit.
        role: IAM role ARN used for deployment.
        entry_point: optional custom inference entry point script.
        image_uri: optional explicit serving image URI.
        predictor_cls: predictor class attached to the deployed endpoint.
        **kwargs: forwarded to ``FrameworkModel.__init__``.
    """
    # Artifact location and per-model serving options.
    self.uncompressed_model_data = uncompressed_model_data
    self.djl_version = djl_version
    self.task = task
    self.data_type = data_type
    self.tensor_parallel_degree = tensor_parallel_degree

    # DJL serving worker/queue tuning knobs (later written to serving.properties).
    self.min_workers = min_workers
    self.max_workers = max_workers
    self.job_queue_size = job_queue_size
    self.parallel_loading = parallel_loading
    self.model_loading_timeout = model_loading_timeout
    self.prediction_timeout = prediction_timeout

    # model_data is deliberately None: artifacts are referenced through
    # option.s3url in serving.properties rather than a model tarball.
    super(DJLLargeModel, self).__init__(
        None,
        image_uri,
        role,
        entry_point,
        predictor_cls=predictor_cls,
        **kwargs,
    )

    # Guarantee a usable session even when the caller supplied none.
    self.sagemaker_session = self.sagemaker_session or Session()
91108

109+
def serving_image_uri(self, region_name):
    """Resolve the default DJL serving container image URI for a region.

    Pins ``self.djl_version`` to "0.20.0" when the caller did not set one,
    then looks the image up through ``sagemaker.image_uris.retrieve``.

    Args:
        region_name: AWS region to resolve the ECR image for.

    Returns:
        str: the fully qualified serving image URI.
    """
    # Default to the pinned DJL release when no version was requested.
    self.djl_version = self.djl_version or "0.20.0"

    return image_uris.retrieve(
        self._framework(),
        region_name,
        version=self.djl_version,
    )
118+
119+
92120
def _determine_engine_for_model_type(model_type: str):
93121
if model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES:
94122
return DeepSpeedModel
@@ -102,6 +130,8 @@ def _validate_engine_for_model_type(model_type: str, engine: DJLEngine):
102130

103131
class DeepSpeedModel(DJLLargeModel):
104132

133+
_framework_name = "djl-deepspeed"
134+
105135
def __init__(
106136
self,
107137
uncompressed_model_data: str,
@@ -138,8 +168,92 @@ def __init__(
138168
**kwargs,
139169
)
140170

171+
def prepare_container_def(
    self,
    instance_type=None,
    accelerator_type=None,
    serverless_inference_config=None,
):
    """Prepare deployment artifacts for this model.

    Resolves the serving image, generates a serving.properties file in a
    temporary directory, and determines the S3 location the generated code
    should be uploaded to. The actual upload is not wired up yet (see TODO).

    Args:
        instance_type: unused here; accepted for interface compatibility.
        accelerator_type: must be None — Elastic Inference is unsupported.
        serverless_inference_config: must be None — serverless is unsupported.

    Raises:
        ValueError: if serverless or Elastic Inference deployment is requested.
    """
    if serverless_inference_config is not None:
        raise ValueError("DJLLargeModel does not support serverless deployment")
    if accelerator_type is not None:
        raise ValueError("DJLLargeModel does not support Elastic Inference accelerator")

    deploy_image = self.image_uri
    if not deploy_image:
        # Fall back to the framework default image for the session's region.
        region_name = self.sagemaker_session.boto_session.region_name
        deploy_image = self.serving_image_uri(region_name)

    # BUG FIX: the original f-string was missing a space ("Deploy image is{...}").
    print(f"Deploy image is {deploy_image}")
    tmp_dir = self._validate_and_write_serving_properties()
    deploy_key_prefix = fw_utils.model_code_key_prefix(self.key_prefix, self.name, deploy_image)
    bucket = self.bucket or self.sagemaker_session.default_bucket()
    print(f"bucket to upload code to is {bucket}")
    # TODO(review): upload of the generated serving.properties is not yet
    # enabled — this method currently has no return value either, so callers
    # expecting a container definition will get None. Confirm intended contract.
    # self.uploaded_code = fw_utils.tar_and_upload_dir(
    #     session=self.sagemaker_session.boto_session,
    #     bucket=bucket,
    #     s3_key_prefix=deploy_key_prefix,
    #     directory=tmp_dir,
    #     script=None,
    # )
199+
200+
def _validate_and_write_serving_properties(self):
    """Generate a DJL serving.properties file in a temp directory.

    Collects the DeepSpeed engine settings plus any user-specified tuning
    options into ``key=value`` pairs understood by DJL Serving and writes
    them to ``<tmp_dir>/serving.properties``.

    Returns:
        str: path of the temporary directory containing serving.properties.

    Raises:
        ValueError: if ``enable_cuda_graph`` is combined with tensor
            parallelism greater than 1.
    """
    serving_properties = {
        "engine": "DeepSpeed",
        "option.entryPoint": "djl_python.deepspeed",
        "option.s3url": self.uncompressed_model_data,
    }
    if self.max_tokens:
        serving_properties["option.max_tokens"] = self.max_tokens
    if self.low_cpu_mem_usage:
        serving_properties["option.low_cpu_mem_usage"] = self.low_cpu_mem_usage
    if self.enable_cuda_graph:
        # BUG FIX: guard against tensor_parallel_degree being None before the
        # numeric comparison; the original bare `> 1` raised TypeError for None.
        if self.tensor_parallel_degree and self.tensor_parallel_degree > 1:
            raise ValueError("enable_cuda_graph is not supported when tensor_parallel_degree > 1")
        serving_properties["option.enable_cuda_graph"] = self.enable_cuda_graph
    if self.triangular_masking:
        serving_properties["option.triangular_masking"] = self.triangular_masking
    if self.return_tuple:
        serving_properties["option.return_tuple"] = self.return_tuple
    if self.deepspeed_checkpoint_file:
        serving_properties["option.checkpoint"] = self.deepspeed_checkpoint_file
    if self.tensor_parallel_degree:
        serving_properties["option.tensor_parallel_degree"] = self.tensor_parallel_degree
    if self.entry_point:
        serving_properties["entryPoint"] = self.entry_point
    if self.task:
        serving_properties["option.task"] = self.task
    if self.data_type:
        serving_properties["option.dtype"] = self.data_type
    if self.min_workers:
        serving_properties["minWorkers"] = self.min_workers
    if self.max_workers:
        serving_properties["maxWorkers"] = self.max_workers
    if self.job_queue_size:
        serving_properties["job_queue_size"] = self.job_queue_size
    if self.parallel_loading:
        serving_properties["option.parallel_loading"] = self.parallel_loading
    if self.model_loading_timeout:
        serving_properties["option.model_loading_timeout"] = self.model_loading_timeout
    if self.prediction_timeout:
        serving_properties["option.prediction_timeout"] = self.prediction_timeout

    # BUG FIX: the original conditional was inverted —
    # `None if settings else settings.local_download_dir` returned None exactly
    # when settings existed and raised AttributeError when they did not.
    settings = self.sagemaker_session.settings
    local_dir = settings.local_download_dir if settings else None
    tmp_dir = tempfile.mkdtemp(dir=local_dir)

    with open(os.path.join(tmp_dir, "serving.properties"), "w+") as f:
        for key, value in serving_properties.items():
            f.write(f"{key}={value}\n")

    print(f"wrote serving.properties to {tmp_dir}")

    return tmp_dir
251+
252+
141253
class HuggingfaceAccelerateModel(DJLLargeModel):
142254

255+
_framework_name = "djl-deepspeed"
256+
143257
def __init__(
144258
self,
145259
uncompressed_model_data: str,
@@ -172,3 +286,21 @@ def __init__(
172286
**kwargs
173287
)
174288

289+
if __name__ == "__main__":
    # NOTE(review): ad-hoc manual test scaffolding — this makes real AWS calls
    # when run as a script and should be moved to an integration test (or
    # deleted) before merging.
    session = Session()
    # WARNING(review): hardcoded, account-specific IAM role ARN committed to
    # source; replace with a configurable value before release.
    role = "arn:aws:iam::125045733377:role/AmazonSageMaker-ExecutionRole-djl"
    # max_tokens is a DeepSpeed-specific kwarg; DJLLargeModel.__new__ is
    # presumably dispatching to the engine subclass — verify this path.
    opt_model = DJLLargeModel(
        "s3://dlc-deepspeed-test-temp/opt-2.7b/",
        tensor_parallel_degree=2,
        data_type="fp32",
        task="text-generation",
        max_tokens=2048,
        parallel_loading=True,
        role=role,
        sagemaker_session=session,
    )
    opt_model.prepare_container_def()
    # NOTE(review): "ml.g5.12xl" is not a valid instance type name
    # ("ml.g5.12xlarge") — fix before re-enabling this deploy call.
    # opt_model.deploy(
    #     initial_instance_count=1,
    #     instance_type="ml.g5.12xl"
    # )

0 commit comments

Comments
 (0)