Skip to content

Commit 201f63e

Browse files
fix: Add volume to partition djl_inference (#3950)
1 parent 6146a73 commit 201f63e

File tree

2 files changed

+36
-13
lines changed

2 files changed

+36
-13
lines changed

src/sagemaker/djl_inference/model.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def _create_estimator(
196196
image_uri: str,
197197
role: str,
198198
sagemaker_session: Optional[Session],
199-
volume_size: int = 30,
199+
volume_size: int,
200200
vpc_config: Optional[
201201
Dict[
202202
str,
@@ -453,7 +453,9 @@ def partition(
453453
self,
454454
instance_type: str,
455455
s3_output_uri: str = None,
456+
s3_output_prefix: str = "aot-partitioned-checkpoints",
456457
job_name: Optional[str] = None,
458+
volume_size: int = 30,
457459
volume_kms_key: Optional[str] = None,
458460
output_kms_key: Optional[str] = None,
459461
use_spot_instances: bool = False,
@@ -469,8 +471,13 @@ def partition(
469471
artifacts and output files). If not specified, results are
470472
stored to a default bucket. If the bucket with the specific name
471473
does not exist, it will be created.
474+
s3_output_prefix (str): Name of the prefix where all the partitioned
475+
checkpoints will be uploaded. If not provided, the default value is
476+
aot-partitioned-checkpoints.
472477
job_name (str): Training job name. If not specified, a unique training job
473478
name will be created.
479+
volume_size (int): Size in GB of the storage volume to use for
480+
storing input and output data during training (default: 30).
474481
volume_kms_key (str): Optional. KMS key ID for encrypting EBS
475482
volume attached to the training instance (default: None).
476483
output_kms_key (str): Optional. KMS key ID for encrypting the
@@ -499,20 +506,19 @@ def partition(
499506
region_name = self.sagemaker_session.boto_session.region_name
500507
self.image_uri = self.serving_image_uri(region_name)
501508

502-
deploy_key_prefix = fw_utils.model_code_key_prefix(
503-
self.key_prefix, self.name, self.image_uri
504-
)
505509
if s3_output_uri is None:
510+
deploy_key_prefix = fw_utils.model_code_key_prefix(
511+
self.key_prefix, self.name, self.image_uri
512+
)
513+
506514
bucket, deploy_key_prefix = s3.determine_bucket_and_prefix(
507515
bucket=self.bucket,
508516
key_prefix=deploy_key_prefix,
509517
sagemaker_session=self.sagemaker_session,
510518
)
511519
s3_output_uri = s3_path_join("s3://", bucket, deploy_key_prefix)
512-
else:
513-
s3_output_uri = s3_path_join(s3_output_uri, deploy_key_prefix)
514520

515-
self.save_mp_checkpoint_path = s3_path_join(s3_output_uri, "aot-partitioned-checkpoints")
521+
self.save_mp_checkpoint_path = s3_path_join(s3_output_uri, s3_output_prefix)
516522

517523
container_def = self._upload_model_to_s3(upload_as_tar=False)
518524
estimator = _create_estimator(
@@ -521,6 +527,7 @@ def partition(
521527
image_uri=self.image_uri,
522528
role=self.role,
523529
sagemaker_session=self.sagemaker_session,
530+
volume_size=volume_size,
524531
vpc_config=self.vpc_config,
525532
volume_kms_key=volume_kms_key,
526533
output_kms_key=output_kms_key,
@@ -924,7 +931,9 @@ def partition(
924931
self,
925932
instance_type: str,
926933
s3_output_uri: str = None,
934+
s3_output_prefix: str = "aot-partitioned-checkpoints",
927935
job_name: Optional[str] = None,
936+
volume_size: int = 30,
928937
volume_kms_key: Optional[str] = None,
929938
output_kms_key: Optional[str] = None,
930939
use_spot_instances: bool = False,
@@ -940,8 +949,13 @@ def partition(
940949
artifacts and output files). If not specified, results are
941950
stored to a default bucket. If the bucket with the specific name
942951
does not exist, it will be created.
952+
s3_output_prefix (str): Name of the prefix where all the partitioned
953+
checkpoints will be uploaded. If not provided, the default value is
954+
aot-partitioned-checkpoints.
943955
job_name (str): Training job name. If not specified, a unique training job
944956
name will be created.
957+
volume_size (int): Size in GB of the storage volume to use for
958+
storing input and output data during training (default: 30).
945959
volume_kms_key (str): Optional. KMS key ID for encrypting EBS
946960
volume attached to the training instance (default: None).
947961
output_kms_key (str): Optional. KMS key ID for encrypting the
@@ -969,7 +983,9 @@ def partition(
969983
super(DeepSpeedModel, self).partition(
970984
instance_type,
971985
s3_output_uri,
972-
job_name,
986+
s3_output_prefix=s3_output_prefix,
987+
job_name=job_name,
988+
volume_size=volume_size,
973989
volume_kms_key=volume_kms_key,
974990
output_kms_key=output_kms_key,
975991
use_spot_instances=use_spot_instances,
@@ -1096,7 +1112,9 @@ def partition(
10961112
self,
10971113
instance_type: str,
10981114
s3_output_uri: str = None,
1115+
s3_output_prefix: str = "aot-partitioned-checkpoints",
10991116
job_name: Optional[str] = None,
1117+
volume_size: int = 30,
11001118
volume_kms_key: Optional[str] = None,
11011119
output_kms_key: Optional[str] = None,
11021120
use_spot_instances: bool = False,
@@ -1112,8 +1130,13 @@ def partition(
11121130
artifacts and output files). If not specified, results are
11131131
stored to a default bucket. If the bucket with the specific name
11141132
does not exist, it will be created.
1133+
s3_output_prefix (str): Name of the prefix where all the partitioned
1134+
checkpoints will be uploaded. If not provided, the default value is
1135+
aot-partitioned-checkpoints.
11151136
job_name (str): Training job name. If not specified, a unique training job
11161137
name will be created.
1138+
volume_size (int): Size in GB of the storage volume to use for
1139+
storing input and output data during training (default: 30).
11171140
volume_kms_key (str): Optional. KMS key ID for encrypting EBS
11181141
volume attached to the training instance (default: None).
11191142
output_kms_key (str): Optional. KMS key ID for encrypting the

tests/unit/test_djl_inference.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def test_partition(
598598
IMAGE_URI, model_data_url="s3prefix", env=expected_env
599599
)
600600

601-
assert model.model_id == f"{s3_output_uri}s3prefix/aot-partitioned-checkpoints"
601+
assert model.model_id == f"{s3_output_uri}aot-partitioned-checkpoints"
602602

603603

604604
@patch("sagemaker.djl_inference.model.fw_utils.model_code_key_prefix")
@@ -741,15 +741,15 @@ def without_user_input(sess):
741741
"s3://code-test-bucket/code-test-prefix/code-test-prefix-2",
742742
"s3://code-test-bucket/code-test-prefix/code-test-prefix-2/image_uri",
743743
"s3://code-test-bucket/code-test-prefix/code-test-prefix-2/image_uri",
744-
"s3://test-bucket/test-prefix/test-prefix-2/code-test-prefix/code-test-prefix-2/image_uri",
745-
"s3://test-bucket/test-prefix/test-prefix-2/code-test-prefix/code-test-prefix-2/image_uri",
744+
"s3://test-bucket/test-prefix/test-prefix-2",
745+
"s3://test-bucket/test-prefix/test-prefix-2",
746746
),
747747
(
748748
None,
749749
f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/image_uri",
750750
f"s3://{DEFAULT_S3_BUCKET_NAME}/image_uri",
751-
"s3://test-bucket/test-prefix/test-prefix-2/image_uri",
752-
"s3://test-bucket/test-prefix/test-prefix-2/image_uri",
751+
"s3://test-bucket/test-prefix/test-prefix-2",
752+
"s3://test-bucket/test-prefix/test-prefix-2",
753753
),
754754
],
755755
)

0 commit comments

Comments
 (0)