Skip to content

Commit 3b1aa8a

Browse files
committed
change: add method to Model class to check if repack is needed
1 parent bfc63d2 commit 3b1aa8a

File tree

3 files changed

+126
-8
lines changed

3 files changed

+126
-8
lines changed

src/sagemaker/model.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -623,14 +623,7 @@ def prepare_container_def(
623623
)
624624
deploy_env = copy.deepcopy(self.env)
625625
if self.source_dir or self.dependencies or self.entry_point or self.git_config:
626-
is_repack = (
627-
self.source_dir
628-
and self.entry_point
629-
and not (
630-
(self.key_prefix and issubclass(type(self), FrameworkModel)) or self.git_config
631-
)
632-
)
633-
self._upload_code(deploy_key_prefix, repack=is_repack)
626+
self._upload_code(deploy_key_prefix, repack=self.is_repack())
634627
deploy_env.update(self._script_mode_env_vars())
635628

636629
return sagemaker.container_def(
@@ -640,6 +633,15 @@ def prepare_container_def(
640633
image_config=self.image_config,
641634
)
642635

636+
def is_repack(self) -> bool:
637+
"""Helper function to check if the source code need to be
638+
repacked before uploading to S3.
639+
640+
Returns:
641+
bool: if the source need to be repacked or not
642+
"""
643+
return self.source_dir and self.entry_point and not self.git_config
644+
643645
def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
644646
"""Uploads code to S3 to be used with script mode with SageMaker inference.
645647
@@ -1773,6 +1775,15 @@ def __init__(
17731775
**kwargs,
17741776
)
17751777

1778+
def is_repack(self) -> bool:
1779+
"""Helper function to check if the source code need to be
1780+
repacked before uploading to S3.
1781+
1782+
Returns:
1783+
bool: if the source need to be repacked or not
1784+
"""
1785+
return self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
1786+
17761787

17771788
# works for MODEL_PACKAGE_ARN with or without version info.
17781789
MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"

tests/unit/sagemaker/model/test_framework_model.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
CODECOMMIT_REPO_SSH = "ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/"
4949
CODECOMMIT_BRANCH = "master"
5050
REPO_DIR = "/tmp/repo_dir"
51+
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
5152

5253

5354
class DummyFrameworkModel(FrameworkModel):
@@ -471,3 +472,66 @@ def test_git_support_codecommit_ssh_passphrase_required(
471472
)
472473
model.prepare_container_def(instance_type=INSTANCE_TYPE)
473474
assert "returned non-zero exit status" in str(error.value)
475+
476+
477+
@patch("sagemaker.utils.repack_model")
478+
def test_not_repack_code_location_with_key_prefix(repack_model, sagemaker_session):
479+
480+
code_location = "s3://my-bucket/code/location/"
481+
482+
t = FrameworkModel(
483+
entry_point=ENTRY_POINT,
484+
role=ROLE,
485+
sagemaker_session=sagemaker_session,
486+
source_dir="s3://codebucket/someprefix/sourcedir.tar.gz",
487+
image_uri=IMAGE_URI,
488+
model_data=MODEL_DATA,
489+
code_location=code_location,
490+
)
491+
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
492+
493+
repack_model.assert_not_called()
494+
495+
496+
@patch("sagemaker.utils.repack_model")
497+
def test_is_repack_with_code_location(repack_model, sagemaker_session):
498+
499+
code_location = "s3://my-bucket/code/location/"
500+
501+
model = FrameworkModel(
502+
entry_point=ENTRY_POINT,
503+
role=ROLE,
504+
sagemaker_session=sagemaker_session,
505+
source_dir="s3://codebucket/someprefix/sourcedir.tar.gz",
506+
image_uri=IMAGE_URI,
507+
model_data=MODEL_DATA,
508+
code_location=code_location,
509+
)
510+
511+
assert not model.is_repack()
512+
513+
514+
@patch("sagemaker.git_utils.git_clone_repo")
515+
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
516+
def test_is_repack_with_git_config(tar_and_upload_dir, git_clone_repo, sagemaker_session):
517+
git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: {
518+
"entry_point": "entry_point",
519+
"source_dir": "/tmp/repo_dir/source_dir",
520+
"dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"],
521+
}
522+
523+
entry_point = "entry_point"
524+
source_dir = "source_dir"
525+
dependencies = ["foo", "bar"]
526+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
527+
model = FrameworkModel(
528+
sagemaker_session=sagemaker_session,
529+
entry_point=entry_point,
530+
source_dir=source_dir,
531+
dependencies=dependencies,
532+
git_config=git_config,
533+
image_uri=IMAGE_URI,
534+
model_data=MODEL_DATA,
535+
)
536+
537+
assert not model.is_repack()

tests/unit/sagemaker/model/test_model.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,49 @@ def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session):
704704
repack_model.assert_called_once()
705705

706706

707+
@patch("sagemaker.utils.repack_model")
708+
def test_is_repack_with_code_location(repack_model, sagemaker_session):
709+
710+
code_location = "s3://my-bucket/code/location/"
711+
712+
model = Model(
713+
entry_point=ENTRY_POINT_INFERENCE,
714+
role=ROLE,
715+
sagemaker_session=sagemaker_session,
716+
source_dir=SCRIPT_URI,
717+
image_uri=IMAGE_URI,
718+
model_data=MODEL_DATA,
719+
code_location=code_location,
720+
)
721+
722+
assert model.is_repack()
723+
724+
725+
@patch("sagemaker.git_utils.git_clone_repo")
726+
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
727+
def test_is_repack_with_git_config(tar_and_upload_dir, git_clone_repo, sagemaker_session):
728+
git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: {
729+
"entry_point": "entry_point",
730+
"source_dir": "/tmp/repo_dir/source_dir",
731+
"dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"],
732+
}
733+
734+
entry_point = "entry_point"
735+
source_dir = "source_dir"
736+
dependencies = ["foo", "bar"]
737+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
738+
model = Model(
739+
sagemaker_session=sagemaker_session,
740+
entry_point=entry_point,
741+
source_dir=source_dir,
742+
dependencies=dependencies,
743+
git_config=git_config,
744+
image_uri=IMAGE_URI,
745+
)
746+
747+
assert not model.is_repack()
748+
749+
707750
@patch("sagemaker.utils.repack_model")
708751
@patch("sagemaker.fw_utils.tar_and_upload_dir")
709752
def test_all_framework_models_add_jumpstart_base_name(

0 commit comments

Comments
 (0)