Skip to content

Commit cea9cc7

Browse files
committed
feature: Partition support for DJLModel using SM Training job
1 parent 46215ae commit cea9cc7

File tree

3 files changed

+307
-23
lines changed

3 files changed

+307
-23
lines changed

doc/frameworks/djl/using_djl.rst

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,31 @@ see the `DJL Serving Documentation on Python Mode. <https://docs.djl.ai/docs/ser
209209

210210
For more information about DJL Serving, see the `DJL Serving documentation. <https://docs.djl.ai/docs/serving/index.html>`_
211211

212+
**************************
213+
Ahead of time partitioning
214+
**************************
215+
216+
To optimize the deployment of large models that do not fit in a single GPU, the model’s tensor weights are partitioned at
217+
runtime and each partition is loaded onto an individual GPU. But runtime partitioning takes a significant amount of time and
218+
memory on model loading. So, DJLModel offers an ahead of time partitioning capability for DeepSpeed and FasterTransformer
219+
engines, which lets you partition your model weights and save them before deployment. HuggingFace does not support
220+
tensor parallelism, so ahead of time partitioning cannot be done for it. In our experiment with GPT-J model, loading
221+
this model with partitioned checkpoints reduced the model loading time by 40%.
222+
223+
The `partition` method invokes an Amazon SageMaker Training job to partition the model and upload the partitioned
224+
checkpoints to an S3 bucket. You can either provide your desired S3 bucket for the partitioned checkpoints, or they will be
225+
uploaded to the default SageMaker S3 bucket. Please note that this S3 bucket will be remembered for deployment. When you
226+
call the `deploy` method after partitioning, DJLServing downloads the partitioned model checkpoints directly from the uploaded
227+
S3 URL, if available.
228+
229+
.. code::
230+
231+
# partitions the model using Amazon SageMaker Training Job.
232+
djl_model.partition("ml.g5.12xlarge")
233+
234+
predictor = djl_model.deploy("ml.g5.12xlarge",
235+
initial_instance_count=1)
236+
212237
***********************
213238
SageMaker DJL Classes
214239
***********************

src/sagemaker/djl_inference/model.py

Lines changed: 208 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from json import JSONDecodeError
2121
from urllib.error import HTTPError, URLError
2222
from enum import Enum
23-
from typing import Optional, Union, Dict, Any
23+
from typing import Optional, Union, Dict, Any, List
2424

2525
import sagemaker
2626
from sagemaker import s3, Predictor, image_uris, fw_utils
@@ -31,6 +31,8 @@
3131
from sagemaker.session import Session
3232
from sagemaker.utils import _tmpdir, _create_or_update_code_dir
3333
from sagemaker.workflow.entities import PipelineVariable
34+
from sagemaker.estimator import Estimator
35+
from sagemaker.s3 import S3Uploader
3436

3537
logger = logging.getLogger("sagemaker")
3638

@@ -180,6 +182,45 @@ def _get_model_config_properties_from_hf(model_id: str):
180182
return model_config
181183

182184

185+
def _create_estimator(
    instance_type: str,
    s3_output_uri: str,
    image_uri: str,
    role: str,
    sagemaker_session: Optional[Session],
    volume_size: int = 30,
    vpc_config: Optional[Dict[str, List[str]]] = None,
    volume_kms_key=None,
    output_kms_key=None,
    use_spot_instances: bool = False,
    max_wait: int = None,
    enable_network_isolation: bool = False,
):
    """Build the single-instance ``Estimator`` used to run a partitioning job.

    Args:
        instance_type (str): Instance type passed to the ``Estimator``.
        s3_output_uri (str): Passed to the ``Estimator`` as ``output_path``.
        image_uri (str): Container image passed to the ``Estimator``.
        role (str): Role passed to the ``Estimator``.
        sagemaker_session (Optional[Session]): Session passed to the ``Estimator``.
        volume_size (int): Passed through as ``volume_size`` (default: 30).
        vpc_config (Optional[Dict[str, List[str]]]): Optional mapping from which
            the ``"Subnets"`` and ``"SecurityGroupIds"`` entries are read. When it
            is ``None`` or empty, both values are passed as ``None``.
        volume_kms_key: Passed through as ``volume_kms_key`` (default: None).
        output_kms_key: Passed through as ``output_kms_key`` (default: None).
        use_spot_instances (bool): Passed through unchanged (default: False).
        max_wait (int): Passed through unchanged (default: None).
        enable_network_isolation (bool): Passed through unchanged (default: False).

    Returns:
        Estimator: An estimator configured with ``instance_count=1``.
    """
    # A missing or empty VPC config leaves both network settings as None.
    network_settings = vpc_config or {}

    return Estimator(
        image_uri=image_uri,
        role=role,
        instance_count=1,
        instance_type=instance_type,
        volume_size=volume_size,
        volume_kms_key=volume_kms_key,
        output_path=s3_output_uri,
        output_kms_key=output_kms_key,
        sagemaker_session=sagemaker_session,
        subnets=network_settings.get("Subnets"),
        security_group_ids=network_settings.get("SecurityGroupIds"),
        use_spot_instances=use_spot_instances,
        max_wait=max_wait,
        enable_network_isolation=enable_network_isolation,
    )
222+
223+
183224
class DJLModel(FrameworkModel):
184225
"""A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
185226

@@ -381,6 +422,91 @@ def right_size(self, checkpoint_data_type: str):
381422
"DJLModels do not currently support Inference Recommendation Jobs"
382423
)
383424

425+
def partition(
    self,
    instance_type: str,
    s3_output_uri: str = None,
    job_name: Optional[str] = None,
    volume_kms_key: Optional[str] = None,
    output_kms_key: Optional[str] = None,
    use_spot_instances: bool = False,
    max_wait: int = None,
    enable_network_isolation: bool = False,
):
    """Partitions the model using a SageMaker Training Job.

    This is a synchronous API call: the training job is started with
    ``wait=True``. On success ``self.model_id`` is pointed at the partitioned
    checkpoints so they are used from then on.

    Args:
        instance_type (str): The EC2 instance type to partition this Model.
            For example, 'ml.p4d.24xlarge'.
        s3_output_uri (str): S3 location for saving the training result (model
            artifacts and output files). The model's code key prefix is
            appended to it; a trailing ``/`` is ignored. If not specified,
            ``self.bucket`` or the session's default bucket is used.
        job_name (str): Training job name. If not specified, a unique training job
            name will be created.
        volume_kms_key (str): Optional. KMS key ID for encrypting EBS
            volume attached to the training instance (default: None).
        output_kms_key (str): Optional. KMS key ID for encrypting the
            training output (default: None).
        use_spot_instances (bool): Specifies whether to use SageMaker
            Managed Spot instances for training. If enabled then the
            ``max_wait`` arg should also be set.

            More information:
            https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
            (default: ``False``).
        max_wait (int): Timeout in seconds waiting for spot training
            job. After this amount of time Amazon SageMaker will stop
            waiting for managed spot training job to complete
            (default: None).
        enable_network_isolation (bool): Specifies whether container will
            run in network isolation mode (default: ``False``). Network
            isolation mode restricts the container access to outside networks
            (such as the Internet). The container does not make any inbound or
            outbound network calls. Also known as Internet-free mode.

    Returns:
        None
    """

    deploy_key_prefix = fw_utils.model_code_key_prefix(
        self.key_prefix, self.name, self.image_uri
    )
    if s3_output_uri is None:
        bucket = self.bucket or self.sagemaker_session.default_bucket()
        s3_output_uri = f"s3://{bucket}/{deploy_key_prefix}"
    else:
        # Drop any trailing slash so the joined URI never contains "//".
        s3_output_uri = f"{s3_output_uri.rstrip('/')}/{deploy_key_prefix}"

    checkpoint_path = f"{s3_output_uri}/aot-partitioned-checkpoints"

    # Set only for the duration of the partition job; it must be assigned
    # before the model is uploaded and cleared again afterwards.
    self.save_mp_checkpoint_path = checkpoint_path
    try:
        container_def = self._upload_model_to_s3(upload_as_tar=False)
        estimator = _create_estimator(
            instance_type=instance_type,
            s3_output_uri=s3_output_uri,
            image_uri=self.image_uri,
            role=self.role,
            sagemaker_session=self.sagemaker_session,
            vpc_config=self.vpc_config,
            volume_kms_key=volume_kms_key,
            output_kms_key=output_kms_key,
            use_spot_instances=use_spot_instances,
            max_wait=max_wait,
            enable_network_isolation=enable_network_isolation,
        )

        # creates a training job to do partitions
        estimator.fit(
            inputs=container_def["ModelDataUrl"],
            wait=True,
            logs="All",
            job_name=job_name,
            experiment_config=None,
        )
    finally:
        # Reset even when the upload or the job fails, so a later call does
        # not inherit a stale save path from an unfinished partition.
        self.save_mp_checkpoint_path = None

    # Only reached on success: switch the model to the partitioned checkpoints.
    self.model_id = checkpoint_path
509+
384510
def deploy(
385511
self,
386512
instance_type,
@@ -477,18 +603,10 @@ def deploy(
477603
container_startup_health_check_timeout=container_startup_health_check_timeout,
478604
)
479605

480-
def prepare_container_def(
481-
self,
482-
instance_type=None,
483-
accelerator_type=None,
484-
serverless_inference_config=None,
485-
): # pylint: disable=unused-argument
486-
"""A container definition with framework configuration set in model environment variables.
487-
488-
Returns:
489-
dict[str, str]: A container definition object usable with the
490-
CreateModel API.
491-
"""
606+
def _upload_model_to_s3(self,
607+
upload_as_tar: bool = True
608+
):
609+
"""Placeholder docstring"""
492610

493611
if not self.image_uri:
494612
region_name = self.sagemaker_session.boto_session.region_name
@@ -528,19 +646,42 @@ def prepare_container_def(
528646
self.key_prefix, self.name, self.image_uri
529647
)
530648
bucket = self.bucket or self.sagemaker_session.default_bucket()
531-
uploaded_code = fw_utils.tar_and_upload_dir(
532-
self.sagemaker_session.boto_session,
533-
bucket,
534-
deploy_key_prefix,
535-
self.entry_point,
536-
directory=tmp_code_dir,
537-
dependencies=self.dependencies,
538-
kms_key=self.model_kms_key,
539-
)
649+
if upload_as_tar:
650+
uploaded_code = fw_utils.tar_and_upload_dir(
651+
self.sagemaker_session.boto_session,
652+
bucket,
653+
deploy_key_prefix,
654+
self.entry_point,
655+
directory=tmp_code_dir,
656+
dependencies=self.dependencies,
657+
kms_key=self.model_kms_key,
658+
)
659+
model_data_url = uploaded_code.s3_prefix
660+
else:
661+
key_prefix = f"{deploy_key_prefix}/aot-model"
662+
model_data_url = S3Uploader.upload(tmp_code_dir,
663+
"s3://%s/%s" % (bucket, key_prefix),
664+
self.model_kms_key,
665+
self.sagemaker_session)
540666
return sagemaker.container_def(
541-
self.image_uri, model_data_url=uploaded_code.s3_prefix, env=environment
667+
self.image_uri, model_data_url=model_data_url, env=environment
542668
)
543669

670+
def prepare_container_def(
    self,
    instance_type=None,
    accelerator_type=None,
    serverless_inference_config=None,
):  # pylint: disable=unused-argument
    """Return the container definition for this model.

    Delegates to ``_upload_model_to_s3`` with ``upload_as_tar=True``. The
    ``instance_type``, ``accelerator_type`` and ``serverless_inference_config``
    arguments are accepted but not used.

    Returns:
        dict[str, str]: A container definition object usable with the
            CreateModel API.
    """
    container_definition = self._upload_model_to_s3(upload_as_tar=True)
    return container_definition
684+
544685
def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]:
545686
"""Generates the DJL Serving configuration to use for the model.
546687
@@ -699,6 +840,8 @@ def __init__(
699840
self.enable_cuda_graph = enable_cuda_graph
700841
self.triangular_masking = triangular_masking
701842
self.return_tuple = return_tuple
843+
self.save_mp_checkpoint_path = None
844+
self.checkpoint = None
702845

703846
def generate_serving_properties(self, serving_properties=None) -> Dict[str, Any]:
704847
"""Generates the DJL Serving configuration to use for the model.
@@ -733,9 +876,35 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, Any]
733876
serving_properties["option.triangular_masking"] = self.triangular_masking
734877
if self.return_tuple:
735878
serving_properties["option.return_tuple"] = self.return_tuple
879+
if self.save_mp_checkpoint_path:
880+
serving_properties["option.save_mp_checkpoint_path"] = self.save_mp_checkpoint_path
881+
if self.checkpoint:
882+
serving_properties["option.checkpoint"] = self.checkpoint
736883

737884
return serving_properties
738885

886+
def partition(
    self,
    instance_type: str,
    s3_output_uri: str = None,
    job_name: Optional[str] = None,
    volume_kms_key: Optional[str] = None,
    output_kms_key: Optional[str] = None,
    use_spot_instances: bool = False,
    max_wait: int = None,
    enable_network_isolation: bool = False
):
    """Partition the model via the parent class, then record the checkpoint file.

    All arguments are forwarded unchanged to the parent ``partition``
    implementation. Once that call returns, ``self.checkpoint`` is set to
    ``"ds_inference_config.json"``.

    Returns:
        None
    """
    super().partition(
        instance_type=instance_type,
        s3_output_uri=s3_output_uri,
        job_name=job_name,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        use_spot_instances=use_spot_instances,
        max_wait=max_wait,
        enable_network_isolation=enable_network_isolation,
    )

    # Assigned only after the parent call has returned without raising.
    self.checkpoint = "ds_inference_config.json"
907+
739908

740909
class HuggingFaceAccelerateModel(DJLModel):
741910
"""A DJL Hugging Face SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
@@ -846,3 +1015,19 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]
8461015
serving_properties["option.dtype"] = "auto"
8471016
serving_properties.pop("option.load_in_8bit", None)
8481017
return serving_properties
1018+
1019+
def partition(
    self,
    instance_type: str,
    s3_output_uri: str = None,
    job_name: Optional[str] = None,
    volume_kms_key: Optional[str] = None,
    output_kms_key: Optional[str] = None,
    use_spot_instances: bool = False,
    max_wait: int = None,
    enable_network_isolation: bool = False
):
    """Log a warning and do nothing else.

    The arguments mirror the other ``partition`` methods in this module but
    are not used here: this override only emits a warning that ahead of time
    partitioning is skipped.

    Returns:
        None
    """
    skip_notice = (
        "HuggingFace engine does not currently support tensor parallelism. "
        "Hence ahead of time partitioning is skipped"
    )
    logger.warning(skip_notice)

0 commit comments

Comments
 (0)