Skip to content

Commit 3542bb9

Browse files
committed
chore: add unit tests, fix unused import
1 parent c639b19 commit 3542bb9

File tree

4 files changed

+240
-3
lines changed

4 files changed

+240
-3
lines changed

tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
2424
from sagemaker.predictor import Predictor
25-
from sagemaker.utils import name_from_base
2625
from tests.integ.sagemaker.jumpstart.constants import (
2726
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
2827
JUMPSTART_TAG,

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE,
2121
JUMPSTART_BUCKET_NAME_SET,
2222
JUMPSTART_REGION_NAME_SET,
23+
JUMPSTART_RESOURCE_BASE_NAME,
2324
JumpStartScriptScope,
2425
)
2526
from sagemaker.jumpstart.enums import JumpStartTag
@@ -874,3 +875,11 @@ def make_deprecated_spec(*largs, **kwargs):
874875
"pytorch-eqa-bert-base-cased",
875876
"*",
876877
)
878+
879+
880+
def test_get_jumpstart_base_name_if_jumpstart_model():
881+
uris = [random_jumpstart_s3_uri("random_key") for _ in range(random.randint(1, 10))]
882+
assert JUMPSTART_RESOURCE_BASE_NAME == utils.get_jumpstart_base_name_if_jumpstart_model(*uris)
883+
884+
uris = ["s3://not-jumpstart-bucket/some-key" for _ in range(random.randint(0, 10))]
885+
assert utils.get_jumpstart_base_name_if_jumpstart_model(*uris) is None

tests/unit/sagemaker/model/test_model.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import sagemaker
2020
from sagemaker.model import FrameworkModel, Model
2121
from sagemaker.huggingface.model import HuggingFaceModel
22-
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET
22+
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME
2323
from sagemaker.jumpstart.enums import JumpStartTag
2424
from sagemaker.mxnet.model import MXNetModel
2525
from sagemaker.pytorch.model import PyTorchModel
@@ -551,3 +551,93 @@ def test_all_framework_models_add_jumpstart_tags(
551551

552552
sagemaker_session.create_model.reset_mock()
553553
sagemaker_session.endpoint_from_production_variants.reset_mock()
554+
555+
556+
@patch("sagemaker.utils.repack_model")
557+
def test_script_mode_model_uses_jumpstart_base_name(repack_model, sagemaker_session):
558+
559+
jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz"
560+
t = Model(
561+
entry_point=ENTRY_POINT_INFERENCE,
562+
role=ROLE,
563+
sagemaker_session=sagemaker_session,
564+
source_dir=jumpstart_source_dir,
565+
image_uri=IMAGE_URI,
566+
model_data=MODEL_DATA,
567+
)
568+
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
569+
570+
assert sagemaker_session.create_model.call_args_list[0][0][0].startswith(
571+
JUMPSTART_RESOURCE_BASE_NAME
572+
)
573+
574+
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0].startswith(
575+
JUMPSTART_RESOURCE_BASE_NAME
576+
)
577+
578+
sagemaker_session.create_model.reset_mock()
579+
sagemaker_session.endpoint_from_production_variants.reset_mock()
580+
581+
non_jumpstart_source_dir = "s3://blah/blah/blah"
582+
t = Model(
583+
entry_point=ENTRY_POINT_INFERENCE,
584+
role=ROLE,
585+
sagemaker_session=sagemaker_session,
586+
source_dir=non_jumpstart_source_dir,
587+
image_uri=IMAGE_URI,
588+
model_data=MODEL_DATA,
589+
)
590+
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
591+
592+
assert not sagemaker_session.create_model.call_args_list[0][0][0].startswith(
593+
JUMPSTART_RESOURCE_BASE_NAME
594+
)
595+
596+
assert not sagemaker_session.endpoint_from_production_variants.call_args_list[0][1][
597+
"name"
598+
].startswith(JUMPSTART_RESOURCE_BASE_NAME)
599+
600+
601+
@patch("sagemaker.utils.repack_model")
602+
@patch("sagemaker.fw_utils.tar_and_upload_dir")
603+
def test_all_framework_models_add_jumpstart_base_name(
604+
repack_model, tar_and_uload_dir, sagemaker_session
605+
):
606+
framework_model_classes_to_kwargs = {
607+
PyTorchModel: {"framework_version": "1.5.0", "py_version": "py3"},
608+
TensorFlowModel: {
609+
"framework_version": "2.3",
610+
},
611+
HuggingFaceModel: {
612+
"pytorch_version": "1.7.1",
613+
"py_version": "py36",
614+
"transformers_version": "4.6.1",
615+
},
616+
MXNetModel: {"framework_version": "1.7.0", "py_version": "py3"},
617+
SKLearnModel: {
618+
"framework_version": "0.23-1",
619+
},
620+
XGBoostModel: {
621+
"framework_version": "1.3-1",
622+
},
623+
}
624+
jumpstart_model_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz"
625+
for framework_model_class, kwargs in framework_model_classes_to_kwargs.items():
626+
framework_model_class(
627+
entry_point=ENTRY_POINT_INFERENCE,
628+
role=ROLE,
629+
sagemaker_session=sagemaker_session,
630+
model_data=jumpstart_model_dir,
631+
**kwargs,
632+
).deploy(instance_type="ml.m2.xlarge", initial_instance_count=INSTANCE_COUNT)
633+
634+
assert sagemaker_session.create_model.call_args_list[0][0][0].startswith(
635+
JUMPSTART_RESOURCE_BASE_NAME
636+
)
637+
638+
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0].startswith(
639+
JUMPSTART_RESOURCE_BASE_NAME
640+
)
641+
642+
sagemaker_session.create_model.reset_mock()
643+
sagemaker_session.endpoint_from_production_variants.reset_mock()

tests/unit/test_estimator.py

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from botocore.exceptions import ClientError
2525
from mock import ANY, MagicMock, Mock, patch
2626
from sagemaker.huggingface.estimator import HuggingFace
27-
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET
27+
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME
2828
from sagemaker.jumpstart.enums import JumpStartTag
2929

3030
import sagemaker.local
@@ -3851,3 +3851,142 @@ def test_all_framework_estimators_add_jumpstart_tags(
38513851
]
38523852

38533853
sagemaker_session.train.reset_mock()
3854+
3855+
3856+
@patch("time.time", return_value=TIME)
3857+
@patch("sagemaker.estimator.tar_and_upload_dir")
3858+
@patch("sagemaker.model.Model._upload_code")
3859+
def test_script_mode_estimator_uses_jumpstart_base_name_with_js_models(
3860+
patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
3861+
):
3862+
patched_tar_and_upload_dir.return_value = UploadedCode(
3863+
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
3864+
)
3865+
sagemaker_session.boto_region_name = REGION
3866+
3867+
instance_type = "ml.p2.xlarge"
3868+
instance_count = 1
3869+
3870+
training_data_uri = "s3://bucket/mydata"
3871+
3872+
source_dir = "s3://dsfsdfsd/sdfsdfs/sdfsd"
3873+
3874+
generic_estimator = Estimator(
3875+
entry_point=SCRIPT_PATH,
3876+
role=ROLE,
3877+
region=REGION,
3878+
sagemaker_session=sagemaker_session,
3879+
instance_count=instance_count,
3880+
instance_type=instance_type,
3881+
source_dir=source_dir,
3882+
image_uri=IMAGE_URI,
3883+
model_uri=MODEL_DATA,
3884+
)
3885+
generic_estimator.fit(training_data_uri)
3886+
3887+
assert not sagemaker_session.train.call_args_list[0][1]["job_name"].startswith(
3888+
JUMPSTART_RESOURCE_BASE_NAME
3889+
)
3890+
sagemaker_session.reset_mock()
3891+
sagemaker_session.sagemaker_client.describe_training_job.return_value = {
3892+
"ModelArtifacts": {"S3ModelArtifacts": "some-uri"}
3893+
}
3894+
3895+
inference_jumpstart_source_dir = (
3896+
f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/inference/source.tar.gz"
3897+
)
3898+
3899+
generic_estimator.deploy(
3900+
initial_instance_count=INSTANCE_COUNT,
3901+
instance_type=INSTANCE_TYPE,
3902+
image_uri=IMAGE_URI,
3903+
source_dir=inference_jumpstart_source_dir,
3904+
entry_point="inference.py",
3905+
role=ROLE,
3906+
)
3907+
3908+
assert sagemaker_session.create_model.call_args_list[0][0][0].startswith(
3909+
JUMPSTART_RESOURCE_BASE_NAME
3910+
)
3911+
3912+
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0].startswith(
3913+
JUMPSTART_RESOURCE_BASE_NAME
3914+
)
3915+
3916+
3917+
@patch("time.time", return_value=TIME)
3918+
@patch("sagemaker.estimator.tar_and_upload_dir")
3919+
@patch("sagemaker.model.Model._upload_code")
3920+
@patch("sagemaker.utils.repack_model")
3921+
def test_all_framework_estimators_add_jumpstart_base_name(
3922+
patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
3923+
):
3924+
3925+
sagemaker_session.boto_region_name = REGION
3926+
sagemaker_session.sagemaker_client.describe_training_job.return_value = {
3927+
"ModelArtifacts": {"S3ModelArtifacts": "some-uri"}
3928+
}
3929+
3930+
patched_tar_and_upload_dir.return_value = UploadedCode(
3931+
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
3932+
)
3933+
3934+
framework_estimator_classes_to_kwargs = {
3935+
PyTorch: {
3936+
"framework_version": "1.5.0",
3937+
"py_version": "py3",
3938+
"instance_type": "ml.p2.xlarge",
3939+
},
3940+
TensorFlow: {
3941+
"framework_version": "2.3",
3942+
"py_version": "py37",
3943+
"instance_type": "ml.p2.xlarge",
3944+
},
3945+
HuggingFace: {
3946+
"pytorch_version": "1.7.1",
3947+
"py_version": "py36",
3948+
"transformers_version": "4.6.1",
3949+
"instance_type": "ml.p2.xlarge",
3950+
},
3951+
MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"},
3952+
SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"},
3953+
XGBoost: {"framework_version": "1.3-1", "instance_type": "ml.m2.xlarge"},
3954+
}
3955+
jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz"
3956+
jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz"
3957+
for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items():
3958+
estimator = framework_estimator_class(
3959+
entry_point=ENTRY_POINT,
3960+
role=ROLE,
3961+
sagemaker_session=sagemaker_session,
3962+
model_uri=jumpstart_model_uri,
3963+
instance_count=INSTANCE_COUNT,
3964+
**kwargs,
3965+
)
3966+
3967+
estimator.fit()
3968+
3969+
assert sagemaker_session.train.call_args_list[0][1]["job_name"].startswith(
3970+
JUMPSTART_RESOURCE_BASE_NAME
3971+
)
3972+
3973+
estimator.deploy(
3974+
initial_instance_count=INSTANCE_COUNT,
3975+
instance_type=kwargs["instance_type"],
3976+
image_uri=IMAGE_URI,
3977+
source_dir=jumpstart_model_uri_2,
3978+
entry_point="inference.py",
3979+
role=ROLE,
3980+
)
3981+
3982+
assert sagemaker_session.create_model.call_args_list[0][0][0].startswith(
3983+
JUMPSTART_RESOURCE_BASE_NAME
3984+
)
3985+
3986+
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0].startswith(
3987+
JUMPSTART_RESOURCE_BASE_NAME
3988+
)
3989+
3990+
sagemaker_session.endpoint_from_production_variants.reset_mock()
3991+
sagemaker_session.create_model.reset_mock()
3992+
sagemaker_session.train.reset_mock()

0 commit comments

Comments
 (0)