Skip to content

Commit e2a9b70

Browse files
committed
Generate container def for huggingface models
1 parent 35e6c47 commit e2a9b70

File tree

1 file changed

+147
-104
lines changed

1 file changed

+147
-104
lines changed

src/sagemaker/djl_inference/model.py

Lines changed: 147 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030

3131

3232
class DJLEngine(Enum):
    """Serving engines supported by DJL large-model inference.

    Each member's value is a ``(engine_name, default_entry_point_module)``
    pair: the name written to ``serving.properties`` and the djl_python
    module used as the default inference entry point.
    """

    DEEPSPEED = ("DeepSpeed", "djl_python.deepspeed")
    FASTER_TRANSFORMERS = ("FasterTransformers", "djl_python.faster_transformers")
    HUGGINGFACE_ACCELERATE = ("Python", "djl_python.huggingface")
3636

3737

3838
class DJLLargeModelPredictor(Predictor):
@@ -93,6 +93,10 @@ def __init__(
9393
predictor_cls: callable = DJLLargeModelPredictor,
9494
**kwargs,
9595
):
96+
super(DJLLargeModel, self).__init__(
97+
None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
98+
)
99+
self.engine = None
96100
self.uncompressed_model_data = uncompressed_model_data
97101
self.djl_version = djl_version
98102
self.task = task
@@ -104,11 +108,120 @@ def __init__(
104108
self.parallel_loading = parallel_loading
105109
self.model_loading_timeout = model_loading_timeout
106110
self.prediction_timeout = prediction_timeout
107-
super(DJLLargeModel, self).__init__(
108-
None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
109-
)
110111
self.sagemaker_session = self.sagemaker_session or Session()
111112

113+
def package_for_edge(
    self,
    output_path,
    model_name,
    model_version,
    role=None,
    job_name=None,
    resource_key=None,
    s3_kms_key=None,
    tags=None,
):
    """Unsupported for DJL large models.

    Raises:
        NotImplementedError: always; SageMaker Edge packaging is not
            available for this model class.
    """
    raise NotImplementedError("DJLLargeModels do not support Sagemaker Edge")
125+
def compile(
    self,
    target_instance_family,
    input_shape,
    output_path,
    role,
    tags=None,
    job_name=None,
    compile_max_run=15 * 60,
    framework=None,
    framework_version=None,
    target_platform_os=None,
    target_platform_arch=None,
    target_platform_accelerator=None,
    compiler_options=None,
):
    """Unsupported for DJL large models.

    Raises:
        NotImplementedError: always; SageMaker Neo compilation is not
            available for this model class.
    """
    raise NotImplementedError(
        "DJLLargeModels do not currently support compilation with SageMaker Neo"
    )
142+
143+
def prepare_container_def(
    self,
    instance_type=None,
    accelerator_type=None,
    serverless_inference_config=None,
):
    """Build the container definition used to deploy this model.

    Stages any user inference code together with a generated
    ``serving.properties`` file, uploads the bundle to S3, and returns a
    container definition pointing at the uploaded artifact.

    Args:
        instance_type: unused by this implementation; accepted for
            interface compatibility with the base class.
        accelerator_type: must be None — Elastic Inference is unsupported.
        serverless_inference_config: must be None — serverless deployment
            is unsupported.

    Returns:
        A SageMaker container definition dict.

    Raises:
        ValueError: if serverless or Elastic Inference deployment is
            requested.
    """
    if serverless_inference_config is not None:
        raise ValueError("DJLLargeModel does not support serverless deployment")
    if accelerator_type is not None:
        raise ValueError("DJLLargeModel does not support Elastic Inference accelerator")

    # Fall back to the default DJL serving image for the session's region
    # when the user did not pin one explicitly.
    if not self.image_uri:
        region_name = self.sagemaker_session.boto_session.region_name
        self.image_uri = self.serving_image_uri(region_name)

    settings = self.sagemaker_session.settings
    local_download_dir = None if settings is None else settings.local_download_dir

    with _tmpdir(directory=local_download_dir) as tmp:
        # Bundle user-provided code so serving.properties ships alongside it.
        if self.source_dir or self.entry_point:
            _create_or_update_code_dir(
                tmp,
                self.entry_point,
                self.source_dir,
                self.dependencies,
                self.sagemaker_session,
                tmp,
            )
        # Properties derived from constructor kwargs take precedence over an
        # existing serving.properties file found in the staged code.
        merged_properties = _read_existing_serving_properties(tmp)
        merged_properties.update(self.generate_serving_properties())

        with open(os.path.join(tmp, "serving.properties"), "w+") as f:
            f.writelines(f"{key}={val}\n" for key, val in merged_properties.items())

        deploy_key_prefix = fw_utils.model_code_key_prefix(
            self.key_prefix, self.name, self.image_uri
        )
        bucket = self.bucket or self.sagemaker_session.default_bucket()
        uploaded_code = fw_utils.tar_and_upload_dir(
            self.sagemaker_session.boto_session,
            bucket,
            deploy_key_prefix,
            self.entry_point,
            directory=tmp,
            dependencies=self.dependencies,
            kms_key=self.model_kms_key,
        )
        return sagemaker.container_def(
            self.image_uri, model_data_url=uploaded_code.s3_prefix, env=self.env
        )
198+
199+
def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]:
    """Generate DJL ``serving.properties`` entries from this model's config.

    Args:
        serving_properties: existing properties dict to merge into. A new
            dict is created when omitted. (Previously this used a mutable
            ``{}`` default, which is shared across calls in Python.)

    Returns:
        Dict mapping serving.properties keys to their configured values.
        Optional settings are emitted only when truthy.
    """
    if serving_properties is None:
        serving_properties = {}
    # DJLEngine values are (engine_name, entry_point_module) tuples. Enum
    # members themselves are not subscriptable, so `self.engine[0]` raised
    # TypeError — access the tuple through `.value`.
    serving_properties["engine"] = self.engine.value[0]
    serving_properties["option.entryPoint"] = self.engine.value[1]
    serving_properties["option.s3url"] = self.uncompressed_model_data
    if self.tensor_parallel_degree:
        serving_properties["option.tensor_parallel_degree"] = self.tensor_parallel_degree
    if self.entry_point:
        # User-supplied entry point; note this is the "entryPoint" key, not
        # "option.entryPoint" set from the engine default above.
        serving_properties["entryPoint"] = self.entry_point
    if self.task:
        serving_properties["option.task"] = self.task
    if self.data_type:
        serving_properties["option.dtype"] = self.data_type
    if self.min_workers:
        serving_properties["minWorkers"] = self.min_workers
    if self.max_workers:
        serving_properties["maxWorkers"] = self.max_workers
    if self.job_queue_size:
        serving_properties["job_queue_size"] = self.job_queue_size
    if self.parallel_loading:
        serving_properties["option.parallel_loading"] = self.parallel_loading
    if self.model_loading_timeout:
        serving_properties["option.model_loading_timeout"] = self.model_loading_timeout
    if self.prediction_timeout:
        serving_properties["option.prediction_timeout"] = self.prediction_timeout
    return serving_properties
224+
112225
def serving_image_uri(self, region_name):
113226
if not self.djl_version:
114227
self.djl_version = "0.20.0"
@@ -167,12 +280,6 @@ def __init__(
167280
predictor_cls: callable = DJLLargeModelPredictor,
168281
**kwargs,
169282
):
170-
self.max_tokens = max_tokens
171-
self.low_cpu_mem_usage = low_cpu_mem_usage
172-
self.enable_cuda_graph = enable_cuda_graph
173-
self.triangular_masking = triangular_masking
174-
self.return_tuple = return_tuple
175-
self.deepspeed_checkpoint_file = deepspeed_checkpoint_file
176283
super(DeepSpeedModel, self).__init__(
177284
uncompressed_model_data,
178285
role=role,
@@ -184,75 +291,16 @@ def __init__(
184291
predictor_cls=predictor_cls,
185292
**kwargs,
186293
)
294+
self.engine = DJLEngine.DEEPSPEED
295+
self.max_tokens = max_tokens
296+
self.low_cpu_mem_usage = low_cpu_mem_usage
297+
self.enable_cuda_graph = enable_cuda_graph
298+
self.triangular_masking = triangular_masking
299+
self.return_tuple = return_tuple
300+
self.deepspeed_checkpoint_file = deepspeed_checkpoint_file
187301

188-
def prepare_container_def(
189-
self,
190-
instance_type=None,
191-
accelerator_type=None,
192-
serverless_inference_config=None,
193-
):
194-
if serverless_inference_config is not None:
195-
raise ValueError("DJLLargeModel does not support serverless deployment")
196-
if accelerator_type is not None:
197-
raise ValueError("DJLLargeModel does not support Elastic Inference accelerator")
198-
199-
deploy_image = self.image_uri
200-
if not deploy_image:
201-
region_name = self.sagemaker_session.boto_session.region_name
202-
deploy_image = self.serving_image_uri(region_name)
203-
204-
print(f"Deploy image is{deploy_image}")
205-
local_download_dir = (
206-
None
207-
if self.sagemaker_session.settings is None
208-
or self.sagemaker_session.settings.local_download_dir is None
209-
else self.sagemaker_session.settings.local_download_dir
210-
)
211-
with _tmpdir(directory=local_download_dir) as tmp:
212-
# Check to see if we need to bundle user provided code with serving.properties and re upload
213-
if self.source_dir or self.entry_point:
214-
_create_or_update_code_dir(
215-
tmp,
216-
self.entry_point,
217-
self.source_dir,
218-
self.dependencies,
219-
self.sagemaker_session,
220-
tmp,
221-
)
222-
existing_serving_properties = _read_existing_serving_properties(tmp)
223-
provided_serving_properties = self._generate_serving_properties()
224-
# provided kwargs take precedence over existing serving.properties file
225-
existing_serving_properties.update(provided_serving_properties)
226-
# self._validate_serving_properties(existing_serving_properties)
227-
228-
with open(os.path.join(tmp, "serving.properties"), "w+") as f:
229-
for key, val in existing_serving_properties.items():
230-
f.write(f"{key}={val}\n")
231-
232-
deploy_key_prefix = fw_utils.model_code_key_prefix(
233-
self.key_prefix, self.name, deploy_image
234-
)
235-
bucket = self.bucket or self.sagemaker_session.default_bucket()
236-
print(f"bucket to upload code to is {bucket}")
237-
uploaded_code = fw_utils.tar_and_upload_dir(
238-
self.sagemaker_session.boto_session,
239-
bucket,
240-
deploy_key_prefix,
241-
self.entry_point,
242-
directory=tmp,
243-
dependencies=self.dependencies,
244-
kms_key=self.model_kms_key,
245-
)
246-
return sagemaker.container_def(
247-
deploy_image, model_data_url=uploaded_code.s3_prefix, env=self.env
248-
)
249-
250-
def _generate_serving_properties(self):
251-
serving_properties = {
252-
"engine": "DeepSpeed",
253-
"option.entryPoint": "djl_python.deepspeed",
254-
"option.s3url": self.uncompressed_model_data,
255-
}
302+
def generate_serving_properties(self, serving_properties={}) -> Dict[str, str]:
303+
serving_properties = super(DeepSpeedModel, self).generate_serving_properties()
256304
if self.max_tokens:
257305
serving_properties["option.max_tokens"] = self.max_tokens
258306
if self.low_cpu_mem_usage:
@@ -269,26 +317,6 @@ def _generate_serving_properties(self):
269317
serving_properties["option.return_tuple"] = self.return_tuple
270318
if self.deepspeed_checkpoint_file:
271319
serving_properties["option.checkpoint"] = self.deepspeed_checkpoint_file
272-
if self.tensor_parallel_degree:
273-
serving_properties["option.tensor_parallel_degree"] = self.tensor_parallel_degree
274-
if self.entry_point:
275-
serving_properties["entryPoint"] = self.entry_point
276-
if self.task:
277-
serving_properties["option.task"] = self.task
278-
if self.data_type:
279-
serving_properties["option.dtype"] = self.data_type
280-
if self.min_workers:
281-
serving_properties["minWorkers"] = self.min_workers
282-
if self.max_workers:
283-
serving_properties["maxWorkers"] = self.max_workers
284-
if self.job_queue_size:
285-
serving_properties["job_queue_size"] = self.job_queue_size
286-
if self.parallel_loading:
287-
serving_properties["option.parallel_loading"] = self.parallel_loading
288-
if self.model_loading_timeout:
289-
serving_properties["option.model_loading_timeout"] = self.model_loading_timeout
290-
if self.prediction_timeout:
291-
serving_properties["option.prediction_timeout"] = self.prediction_timeout
292320

293321
return serving_properties
294322

@@ -312,10 +340,6 @@ def __init__(
312340
predictor_cls: callable = DJLLargeModelPredictor,
313341
**kwargs,
314342
):
315-
self.device_id = device_id
316-
self.device_map = device_map
317-
self.load_in_8bit = (load_in_8bit,)
318-
self.low_cpu_mem_usage = (low_cpu_mem_usage,)
319343
super(HuggingfaceAccelerateModel, self).__init__(
320344
uncompressed_model_data,
321345
role=role,
@@ -327,3 +351,22 @@ def __init__(
327351
predictor_cls=predictor_cls,
328352
**kwargs,
329353
)
354+
self.engine = DJLEngine.HUGGINGFACE_ACCELERATE
355+
self.device_id = device_id
356+
self.device_map = device_map
357+
self.load_in_8bit = (load_in_8bit,)
358+
self.low_cpu_mem_usage = (low_cpu_mem_usage,)
359+
360+
def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]:
    """Generate ``serving.properties`` entries for HuggingFace Accelerate.

    Args:
        serving_properties: existing properties dict to merge into. A new
            dict is created when omitted. (Previously this used a mutable
            ``{}`` default AND silently discarded the argument by never
            forwarding it to the parent implementation.)

    Returns:
        Dict of serving.properties keys including the base-class entries
        plus Accelerate-specific device/memory options.

    Raises:
        ValueError: if load_in_8bit is requested without data_type="int8".
    """
    # Forward the caller's dict to the base implementation instead of
    # dropping it; pass a fresh dict when omitted.
    serving_properties = super(HuggingfaceAccelerateModel, self).generate_serving_properties(
        serving_properties=serving_properties if serving_properties is not None else {}
    )
    # 0 is a valid CUDA device id; an explicit None check keeps it from
    # being dropped by truthiness.
    if self.device_id is not None:
        serving_properties["option.device_id"] = self.device_id
    if self.device_map:
        serving_properties["option.device_map"] = self.device_map
    # NOTE(review): __init__ appears to store load_in_8bit/low_cpu_mem_usage
    # as 1-tuples ("(load_in_8bit,)"), which are always truthy — confirm and
    # fix at the assignment site.
    if self.load_in_8bit:
        if self.data_type != "int8":
            raise ValueError("Set data_type='int8' to use load_in_8bit")
        serving_properties["option.load_in_8bit"] = self.load_in_8bit
    if self.low_cpu_mem_usage:
        serving_properties["option.low_cpu_mem_usage"] = self.low_cpu_mem_usage
    return serving_properties

0 commit comments

Comments
 (0)