|
24 | 24 | from botocore.exceptions import ClientError
|
25 | 25 | from mock import ANY, MagicMock, Mock, patch
|
26 | 26 | from sagemaker.huggingface.estimator import HuggingFace
|
27 |
| -from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET |
| 27 | +from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME |
28 | 28 | from sagemaker.jumpstart.enums import JumpStartTag
|
29 | 29 |
|
30 | 30 | import sagemaker.local
|
@@ -3851,3 +3851,142 @@ def test_all_framework_estimators_add_jumpstart_tags(
|
3851 | 3851 | ]
|
3852 | 3852 |
|
3853 | 3853 | sagemaker_session.train.reset_mock()
|
| 3854 | + |
| 3855 | + |
| 3856 | +@patch("time.time", return_value=TIME) |
| 3857 | +@patch("sagemaker.estimator.tar_and_upload_dir") |
| 3858 | +@patch("sagemaker.model.Model._upload_code") |
| 3859 | +def test_script_mode_estimator_uses_jumpstart_base_name_with_js_models( |
| 3860 | + patched_upload_code, patched_tar_and_upload_dir, sagemaker_session |
| 3861 | +): |
| 3862 | + patched_tar_and_upload_dir.return_value = UploadedCode( |
| 3863 | + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" |
| 3864 | + ) |
| 3865 | + sagemaker_session.boto_region_name = REGION |
| 3866 | + |
| 3867 | + instance_type = "ml.p2.xlarge" |
| 3868 | + instance_count = 1 |
| 3869 | + |
| 3870 | + training_data_uri = "s3://bucket/mydata" |
| 3871 | + |
| 3872 | + source_dir = "s3://dsfsdfsd/sdfsdfs/sdfsd" |
| 3873 | + |
| 3874 | + generic_estimator = Estimator( |
| 3875 | + entry_point=SCRIPT_PATH, |
| 3876 | + role=ROLE, |
| 3877 | + region=REGION, |
| 3878 | + sagemaker_session=sagemaker_session, |
| 3879 | + instance_count=instance_count, |
| 3880 | + instance_type=instance_type, |
| 3881 | + source_dir=source_dir, |
| 3882 | + image_uri=IMAGE_URI, |
| 3883 | + model_uri=MODEL_DATA, |
| 3884 | + ) |
| 3885 | + generic_estimator.fit(training_data_uri) |
| 3886 | + |
| 3887 | + assert not sagemaker_session.train.call_args_list[0][1]["job_name"].startswith( |
| 3888 | + JUMPSTART_RESOURCE_BASE_NAME |
| 3889 | + ) |
| 3890 | + sagemaker_session.reset_mock() |
| 3891 | + sagemaker_session.sagemaker_client.describe_training_job.return_value = { |
| 3892 | + "ModelArtifacts": {"S3ModelArtifacts": "some-uri"} |
| 3893 | + } |
| 3894 | + |
| 3895 | + inference_jumpstart_source_dir = ( |
| 3896 | + f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/inference/source.tar.gz" |
| 3897 | + ) |
| 3898 | + |
| 3899 | + generic_estimator.deploy( |
| 3900 | + initial_instance_count=INSTANCE_COUNT, |
| 3901 | + instance_type=INSTANCE_TYPE, |
| 3902 | + image_uri=IMAGE_URI, |
| 3903 | + source_dir=inference_jumpstart_source_dir, |
| 3904 | + entry_point="inference.py", |
| 3905 | + role=ROLE, |
| 3906 | + ) |
| 3907 | + |
| 3908 | + assert sagemaker_session.create_model.call_args_list[0][0][0].startswith( |
| 3909 | + JUMPSTART_RESOURCE_BASE_NAME |
| 3910 | + ) |
| 3911 | + |
| 3912 | + assert sagemaker_session.endpoint_from_production_variants.call_args_list[0].startswith( |
| 3913 | + JUMPSTART_RESOURCE_BASE_NAME |
| 3914 | + ) |
| 3915 | + |
| 3916 | + |
| 3917 | +@patch("time.time", return_value=TIME) |
| 3918 | +@patch("sagemaker.estimator.tar_and_upload_dir") |
| 3919 | +@patch("sagemaker.model.Model._upload_code") |
| 3920 | +@patch("sagemaker.utils.repack_model") |
| 3921 | +def test_all_framework_estimators_add_jumpstart_base_name( |
| 3922 | + patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session |
| 3923 | +): |
| 3924 | + |
| 3925 | + sagemaker_session.boto_region_name = REGION |
| 3926 | + sagemaker_session.sagemaker_client.describe_training_job.return_value = { |
| 3927 | + "ModelArtifacts": {"S3ModelArtifacts": "some-uri"} |
| 3928 | + } |
| 3929 | + |
| 3930 | + patched_tar_and_upload_dir.return_value = UploadedCode( |
| 3931 | + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" |
| 3932 | + ) |
| 3933 | + |
| 3934 | + framework_estimator_classes_to_kwargs = { |
| 3935 | + PyTorch: { |
| 3936 | + "framework_version": "1.5.0", |
| 3937 | + "py_version": "py3", |
| 3938 | + "instance_type": "ml.p2.xlarge", |
| 3939 | + }, |
| 3940 | + TensorFlow: { |
| 3941 | + "framework_version": "2.3", |
| 3942 | + "py_version": "py37", |
| 3943 | + "instance_type": "ml.p2.xlarge", |
| 3944 | + }, |
| 3945 | + HuggingFace: { |
| 3946 | + "pytorch_version": "1.7.1", |
| 3947 | + "py_version": "py36", |
| 3948 | + "transformers_version": "4.6.1", |
| 3949 | + "instance_type": "ml.p2.xlarge", |
| 3950 | + }, |
| 3951 | + MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"}, |
| 3952 | + SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"}, |
| 3953 | + XGBoost: {"framework_version": "1.3-1", "instance_type": "ml.m2.xlarge"}, |
| 3954 | + } |
| 3955 | + jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" |
| 3956 | + jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz" |
| 3957 | + for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items(): |
| 3958 | + estimator = framework_estimator_class( |
| 3959 | + entry_point=ENTRY_POINT, |
| 3960 | + role=ROLE, |
| 3961 | + sagemaker_session=sagemaker_session, |
| 3962 | + model_uri=jumpstart_model_uri, |
| 3963 | + instance_count=INSTANCE_COUNT, |
| 3964 | + **kwargs, |
| 3965 | + ) |
| 3966 | + |
| 3967 | + estimator.fit() |
| 3968 | + |
| 3969 | + assert sagemaker_session.train.call_args_list[0][1]["job_name"].startswith( |
| 3970 | + JUMPSTART_RESOURCE_BASE_NAME |
| 3971 | + ) |
| 3972 | + |
| 3973 | + estimator.deploy( |
| 3974 | + initial_instance_count=INSTANCE_COUNT, |
| 3975 | + instance_type=kwargs["instance_type"], |
| 3976 | + image_uri=IMAGE_URI, |
| 3977 | + source_dir=jumpstart_model_uri_2, |
| 3978 | + entry_point="inference.py", |
| 3979 | + role=ROLE, |
| 3980 | + ) |
| 3981 | + |
| 3982 | + assert sagemaker_session.create_model.call_args_list[0][0][0].startswith( |
| 3983 | + JUMPSTART_RESOURCE_BASE_NAME |
| 3984 | + ) |
| 3985 | + |
| 3986 | + assert sagemaker_session.endpoint_from_production_variants.call_args_list[0].startswith( |
| 3987 | + JUMPSTART_RESOURCE_BASE_NAME |
| 3988 | + ) |
| 3989 | + |
| 3990 | + sagemaker_session.endpoint_from_production_variants.reset_mock() |
| 3991 | + sagemaker_session.create_model.reset_mock() |
| 3992 | + sagemaker_session.train.reset_mock() |
0 commit comments