Skip to content

Commit 30b1334

Browse files
committed
Add generic DJL Large model and engine specific model classes with constructors
1 parent e2f3888 commit 30b1334

File tree

3 files changed

+212
-0
lines changed

3 files changed

+212
-0
lines changed

src/sagemaker/djl_inference/__init__.py

Whitespace-only changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
15+
16+
# Model architectures (HuggingFace-style ``model_type`` values) for which
# the DeepSpeed engine is the recommended default.
DEEPSPEED_RECOMMENDED_ARCHITECTURES = {
    "bloom",
    "opt",
    "gpt_neox",
    "gptj",
    "gpt_neo",
    "gpt2",
    "xlm-roberta",
    "roberta",
    "bert",
}

# Model architectures the DeepSpeed engine supports at all. Currently the
# same set as the recommended list, but kept as a separate constant so the
# two can diverge as engine support evolves.
DEEPSPEED_SUPPORTED_ARCHITECTURES = {
    "bert",
    "bloom",
    "gpt2",
    "gpt_neo",
    "gpt_neox",
    "gptj",
    "opt",
    "roberta",
    "xlm-roberta",
}

src/sagemaker/djl_inference/model.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
from __future__ import absolute_import
15+
16+
import json
17+
from enum import Enum
18+
from typing import Optional, Union, Dict
19+
20+
from sagemaker import s3, Predictor
21+
from sagemaker.deserializers import JSONDeserializer
22+
from sagemaker.model import FrameworkModel
23+
from sagemaker.serializers import JSONSerializer
24+
from sagemaker.session import Session
25+
import defaults
26+
from sagemaker.workflow.entities import PipelineVariable
27+
28+
29+
class DJLEngine(Enum):
    """Inference engines selectable for a DJL large model.

    Each member's value is the engine-name string; note that the
    HuggingFace Accelerate option maps to the generic "Python" engine.
    """

    DEEPSPEED = "DeepSpeed"
    FASTER_TRANSFORMERS = "FasterTransformers"
    HUGGINGFACE_ACCELERATE = "Python"
34+
35+
class DJLLargeModelPredictor(Predictor):
    """A Predictor for endpoints hosted from a DJLLargeModel.

    Sends invocation payloads as JSON and parses JSON responses by default.
    """

    def __init__(
        self,
        endpoint_name,
        sagemaker_session=None,
        serializer=None,
        deserializer=None,
    ):
        """Initialize a ``DJLLargeModelPredictor``.

        Args:
            endpoint_name (str): Name of the endpoint to invoke.
            sagemaker_session (sagemaker.session.Session): Session object that
                manages interactions with SageMaker APIs (default: None).
            serializer: Serializer for invocation payloads
                (default: None, meaning a fresh ``JSONSerializer``).
            deserializer: Deserializer for endpoint responses
                (default: None, meaning a fresh ``JSONDeserializer``).
        """
        # Build the default (de)serializers per-instance rather than in the
        # signature: defaults in a ``def`` are evaluated once and would be
        # silently shared by every predictor instance.
        super(DJLLargeModelPredictor, self).__init__(
            endpoint_name,
            sagemaker_session,
            serializer=serializer if serializer is not None else JSONSerializer(),
            deserializer=deserializer if deserializer is not None else JSONDeserializer(),
        )
50+
51+
52+
class DJLLargeModel(FrameworkModel):
    """A DJL Serving SageMaker ``Model`` for hosting large language models.

    Constructing ``DJLLargeModel(...)`` directly inspects the model's
    ``config.json`` in S3 and returns an instance of the engine-specific
    subclass recommended for the architecture (``DeepSpeedModel`` or
    ``HuggingfaceAccelerateModel``). Constructing a subclass directly skips
    that dispatch.
    """

    def __new__(
        cls,
        uncompressed_model_data: str,
        *args,
        **kwargs,
    ):
        """Create the engine-specific model instance for the given artifacts.

        Args:
            uncompressed_model_data (str): S3 URI prefix holding the
                uncompressed model artifacts; must contain a ``config.json``
                with a ``model_type`` entry.

        Raises:
            ValueError: If ``uncompressed_model_data`` is not an S3 URI.
        """
        if not uncompressed_model_data.startswith("s3://"):
            raise ValueError("DJLLargeModel only supports loading model artifacts from s3")
        if cls is not DJLLargeModel:
            # A concrete engine subclass was requested explicitly; create it
            # as-is. This avoids a wasted S3 read and fixes the bug where
            # re-dispatching could return an instance of a *sibling* class,
            # which Python would then leave uninitialized (``__init__`` only
            # runs when the returned object is an instance of the class that
            # was called).
            return super(DJLLargeModel, cls).__new__(cls)
        if uncompressed_model_data.endswith("/"):
            config_file = uncompressed_model_data + "config.json"
        else:
            config_file = uncompressed_model_data + "/config.json"

        # HuggingFace-style config: "model_type" identifies the architecture.
        model_type = json.loads(s3.S3Downloader.read_file(config_file)).get("model_type")
        cls_to_create = _determine_engine_for_model_type(model_type)
        return super(DJLLargeModel, cls).__new__(cls_to_create)

    def __init__(
        self,
        uncompressed_model_data: str,
        task: str = None,
        data_type: str = None,
        tensor_parallel_degree: int = None,
        role: str = None,
        entry_point: Optional[str] = None,
        image_uri: Optional[Union[str, PipelineVariable]] = None,
        predictor_cls: callable = DJLLargeModelPredictor,
        **kwargs
    ):
        """Initialize configuration shared by all DJL engine model classes.

        Args:
            uncompressed_model_data (str): S3 URI prefix of the uncompressed
                model artifacts.
            task (str): Inference task for the model; interpreted by the
                serving stack, not validated here (default: None).
            data_type (str): Data type for the model weights — presumably a
                string like "fp16"; not validated here (default: None).
            tensor_parallel_degree (int): Degree of tensor parallelism to
                shard the model with (default: None).
            role (str): IAM role ARN used when deploying (default: None).
            entry_point (str): Path to custom inference code (default: None).
            image_uri (str or PipelineVariable): Serving container image
                (default: None).
            predictor_cls (callable): Callable that builds a predictor from an
                endpoint name (default: ``DJLLargeModelPredictor``).
            **kwargs: Forwarded to ``FrameworkModel``.
        """
        self.uncompressed_model_data = uncompressed_model_data
        self.task = task
        self.data_type = data_type
        # Fixed: a stray trailing comma previously stored a 1-tuple here
        # instead of the int.
        self.tensor_parallel_degree = tensor_parallel_degree
        # model_data is None: the artifacts stay uncompressed in S3 and are
        # referenced via uncompressed_model_data instead.
        super(DJLLargeModel, self).__init__(
            None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
        )
        self.sagemaker_session = self.sagemaker_session or Session()
92+
def _determine_engine_for_model_type(model_type: str):
    """Return the model class recommended for a given architecture.

    DeepSpeed is chosen when the architecture appears in the recommended
    list; every other architecture falls back to HuggingFace Accelerate.
    """
    deepspeed_recommended = model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES
    return DeepSpeedModel if deepspeed_recommended else HuggingfaceAccelerateModel
96+
97+
def _validate_engine_for_model_type(model_type: str, engine: DJLEngine):
    """Raise if ``engine`` cannot serve models of ``model_type``.

    Only the DeepSpeed engine restricts architectures; all other engines
    accept any model type.

    Raises:
        ValueError: If DeepSpeed is requested for an unsupported architecture.
    """
    if engine != DJLEngine.DEEPSPEED:
        return
    if model_type not in defaults.DEEPSPEED_SUPPORTED_ARCHITECTURES:
        raise ValueError(
            f"{model_type} is not supported by DeepSpeed. "
            f"Supported model_types are {defaults.DEEPSPEED_SUPPORTED_ARCHITECTURES}"
        )
102+
103+
class DeepSpeedModel(DJLLargeModel):
    """A ``DJLLargeModel`` that runs inference with the DeepSpeed engine."""

    def __init__(
        self,
        uncompressed_model_data: str,
        max_tokens: int = None,
        low_cpu_mem_usage: bool = True,
        enable_cuda_graph: bool = False,
        triangular_masking: bool = True,
        return_tuple=True,
        deepspeed_checkpoint_file=None,
        task: str = None,
        data_type: str = None,
        tensor_parallel_degree: int = None,
        role: str = None,
        entry_point: Optional[str] = None,
        image_uri: Optional[Union[str, PipelineVariable]] = None,
        predictor_cls: callable = DJLLargeModelPredictor,
        **kwargs,
    ):
        """Initialize a DeepSpeed-engine model.

        Args:
            uncompressed_model_data (str): S3 URI prefix of the uncompressed
                model artifacts.
            max_tokens (int): DeepSpeed token budget — presumably the maximum
                number of tokens the engine pre-allocates for; TODO confirm
                (default: None).
            low_cpu_mem_usage (bool): Whether to limit CPU memory use while
                loading the model (default: True).
            enable_cuda_graph (bool): Whether to enable CUDA graph execution
                in DeepSpeed (default: False).
            triangular_masking (bool): Whether to apply triangular attention
                masking (default: True).
            return_tuple: Whether the engine should return tuples
                (default: True).
            deepspeed_checkpoint_file: Optional DeepSpeed checkpoint file
                (default: None).
            task, data_type, tensor_parallel_degree, role, entry_point,
                image_uri, predictor_cls: See ``DJLLargeModel``.
            **kwargs: Forwarded to ``DJLLargeModel``.
        """
        # DeepSpeed-specific knobs; everything shared lives on the base class.
        self.max_tokens = max_tokens
        self.low_cpu_mem_usage = low_cpu_mem_usage
        self.enable_cuda_graph = enable_cuda_graph
        self.triangular_masking = triangular_masking
        self.return_tuple = return_tuple
        self.deepspeed_checkpoint_file = deepspeed_checkpoint_file
        super().__init__(
            uncompressed_model_data,
            role=role,
            task=task,
            data_type=data_type,
            tensor_parallel_degree=tensor_parallel_degree,
            entry_point=entry_point,
            image_uri=image_uri,
            predictor_cls=predictor_cls,
            **kwargs,
        )
140+
141+
class HuggingfaceAccelerateModel(DJLLargeModel):
    """A ``DJLLargeModel`` that runs inference with HuggingFace Accelerate
    (DJL's generic "Python" engine)."""

    def __init__(
        self,
        uncompressed_model_data: str,
        device_id: int = None,
        device_map: Union[str, Dict[str, str]] = None,
        load_in_8bit: bool = False,
        low_cpu_mem_usage: bool = True,
        task: str = None,
        data_type: str = None,
        tensor_parallel_degree: int = None,
        role: str = None,
        entry_point: str = None,
        image_uri: Optional[Union[str, PipelineVariable]] = None,
        predictor_cls: callable = DJLLargeModelPredictor,
        **kwargs
    ):
        """Initialize a HuggingFace-Accelerate-engine model.

        Args:
            uncompressed_model_data (str): S3 URI prefix of the uncompressed
                model artifacts.
            device_id (int): Device to load the model on (default: None).
            device_map (str or dict): Accelerate device-map specification —
                presumably either a strategy name like "auto" or an explicit
                module-to-device mapping; TODO confirm (default: None).
            load_in_8bit (bool): Whether to load model weights in 8-bit
                (default: False).
            low_cpu_mem_usage (bool): Whether to limit CPU memory use while
                loading the model (default: True).
            task, data_type, tensor_parallel_degree, role, entry_point,
                image_uri, predictor_cls: See ``DJLLargeModel``.
            **kwargs: Forwarded to ``DJLLargeModel``.
        """
        self.device_id = device_id
        self.device_map = device_map
        # Fixed: stray trailing commas previously stored these two flags as
        # 1-tuples (always truthy) instead of the boolean values.
        self.load_in_8bit = load_in_8bit
        self.low_cpu_mem_usage = low_cpu_mem_usage
        super(HuggingfaceAccelerateModel, self).__init__(
            uncompressed_model_data,
            role=role,
            task=task,
            data_type=data_type,
            tensor_parallel_degree=tensor_parallel_degree,
            entry_point=entry_point,
            image_uri=image_uri,
            predictor_cls=predictor_cls,
            **kwargs
        )
174+

0 commit comments

Comments
 (0)