Skip to content

Commit ae0177b

Browse files
committed
Refactor api to make parallelism options and corresponding engines more clear
1 parent 51b989f commit ae0177b

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

src/sagemaker/djl_inference/model.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,12 @@ def __init__(
8888
)
8989

9090

91-
def _determine_engine_for_model_type(model_type: str):
91+
def _determine_engine_for_model(model_type: str, tensor_parallel_degree: int):
9292
"""Placeholder docstring"""
9393

94+
if tensor_parallel_degree and tensor_parallel_degree > 1:
95+
return DeepSpeedModel
96+
9497
if model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES:
9598
return DeepSpeedModel
9699
return HuggingFaceAccelerateModel
@@ -126,6 +129,7 @@ class DJLLargeModel(FrameworkModel):
126129
def __new__(
127130
cls,
128131
uncompressed_model_data: str,
132+
tensor_parallel_degree: int = None,
129133
*args,
130134
**kwargs,
131135
):
@@ -143,14 +147,16 @@ def __new__(
143147

144148
model_type = json.loads(s3.S3Downloader.read_file(config_file)).get("model_type")
145149
cls_to_create = (
146-
cls if cls is not DJLLargeModel else _determine_engine_for_model_type(model_type)
150+
cls
151+
if cls is not DJLLargeModel
152+
else _determine_engine_for_model(model_type, tensor_parallel_degree)
147153
)
148154
return super(DJLLargeModel, cls).__new__(cls_to_create)
149155

150156
def __init__(
151157
self,
152158
uncompressed_model_data: str,
153-
role: str,
159+
role: str = None,
154160
djl_version: str = None,
155161
task: str = None,
156162
data_type: str = "fp32",
@@ -190,9 +196,10 @@ def __init__(
190196
than or equal to the number of gpus available on the instance. Defaults to None.
191197
If not provided, no tensor parallel sharding is done. If the provided value is
192198
greater than 1, DeepSpeed will be used as the backend.
193-
data_parallel_degree (int): The number of copies of the model to instantiate. It should be
199+
data_parallel_degree (int): The number of replicas of the model to instantiate. It should be
194200
less than or equal to the number of gpus available on the instance. Defaults to None.
195-
If not provided, all available gpus will be used.
201+
If not provided, all available gpus will be used. If tensor_parallel_degree is set,
202+
data_parallel_degree will be computed by DJL Serving based on the number of available GPUs.
196203
min_workers (int): The minimum number of worker processes. DJL Serving will auto detect
197204
the minimum workers if not specified. Defaults to None.
198205
max_workers (int): The maximum number of worker processes. DJL Serving will auto detect
@@ -536,8 +543,6 @@ def generate_serving_properties(self, serving_properties={}) -> Dict[str, str]:
536543
serving_properties["engine"] = self.engine.value[0]
537544
serving_properties["option.entryPoint"] = self.engine.value[1]
538545
serving_properties["option.s3url"] = self.uncompressed_model_data
539-
if self.tensor_parallel_degree:
540-
serving_properties["option.tensor_parallel_degree"] = self.tensor_parallel_degree
541546
if self.entry_point:
542547
serving_properties["option.entryPoint"] = self.entry_point
543548
if self.task:
@@ -667,6 +672,8 @@ def generate_serving_properties(self, serving_properties={}) -> Dict[str, str]:
667672
dict: The model server configuration to use when deploying this model to SageMaker.
668673
"""
669674
serving_properties = super(DeepSpeedModel, self).generate_serving_properties()
675+
if self.tensor_parallel_degree:
676+
serving_properties["option.tensor_parallel_degree"] = self.tensor_parallel_degree
670677
if self.max_tokens:
671678
serving_properties["option.max_tokens"] = self.max_tokens
672679
if self.low_cpu_mem_usage:
@@ -681,8 +688,6 @@ def generate_serving_properties(self, serving_properties={}) -> Dict[str, str]:
681688
serving_properties["option.triangular_masking"] = self.triangular_masking
682689
if self.return_tuple:
683690
serving_properties["option.return_tuple"] = self.return_tuple
684-
if self.deepspeed_checkpoint_file:
685-
serving_properties["option.checkpoint"] = self.deepspeed_checkpoint_file
686691

687692
return serving_properties
688693

@@ -759,7 +764,11 @@ def generate_serving_properties(self, serving_properties={}) -> Dict[str, str]:
759764
dict: The model server configuration to use when deploying this model to SageMaker.
760765
"""
761766
serving_properties = super(HuggingFaceAccelerateModel, self).generate_serving_properties()
767+
if self.data_parallel_degree:
768+
serving_properties["option.tensor_parallel_degree"] = self.data_parallel_degree
762769
if self.device_id:
770+
if self.data_parallel_degree > 1:
771+
raise ValueError("device_id cannot be set when data_parallel_degree is > 1")
763772
serving_properties["option.device_id"] = self.device_id
764773
if self.device_map:
765774
serving_properties["option.device_map"] = self.device_map

0 commit comments

Comments
 (0)