Skip to content

Commit b7b4b3f

Browse files
committed
feat: s3 prefix model data for JumpStartModel
1 parent 410ab2c commit b7b4b3f

File tree

5 files changed

+217
-8
lines changed

5 files changed

+217
-8
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
206206
def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
207207
"""Sets model data based on default or override, returns full kwargs."""
208208

209-
model_data = kwargs.model_data
210-
211-
kwargs.model_data = model_data or model_uris.retrieve(
209+
model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve(
212210
model_scope=JumpStartScriptScope.INFERENCE,
213211
model_id=kwargs.model_id,
214212
model_version=kwargs.model_version,
@@ -218,6 +216,23 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
218216
sagemaker_session=kwargs.sagemaker_session,
219217
)
220218

219+
if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):
220+
if kwargs.model_data:
221+
JUMPSTART_LOGGER.info(
222+
"S3 prefix model_data detected for JumpStartModel: '%s'. "
223+
"Converting to S3DataSource dictionary.",
224+
model_data,
225+
)
226+
model_data = {
227+
"S3DataSource": {
228+
"S3Uri": model_data,
229+
"S3DataType": "S3Prefix",
230+
"CompressionType": "None",
231+
}
232+
}
233+
234+
kwargs.model_data = model_data
235+
221236
return kwargs
222237

223238

@@ -494,7 +509,7 @@ def get_init_kwargs(
494509
instance_type: Optional[str] = None,
495510
region: Optional[str] = None,
496511
image_uri: Optional[Union[str, PipelineVariable]] = None,
497-
model_data: Optional[Union[str, PipelineVariable]] = None,
512+
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
498513
role: Optional[str] = None,
499514
predictor_cls: Optional[callable] = None,
500515
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,

src/sagemaker/jumpstart/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
region: Optional[str] = None,
5454
instance_type: Optional[str] = None,
5555
image_uri: Optional[Union[str, PipelineVariable]] = None,
56-
model_data: Optional[Union[str, PipelineVariable]] = None,
56+
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
5757
role: Optional[str] = None,
5858
predictor_cls: Optional[callable] = None,
5959
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
@@ -95,8 +95,8 @@ def __init__(
9595
instance_type (Optional[str]): The EC2 instance type to use when provisioning a hosting
9696
endpoint. (Default: None).
9797
image_uri (Optional[Union[str, PipelineVariable]]): A Docker image URI. (Default: None).
98-
model_data (Optional[Union[str, PipelineVariable]]): The S3 location of a SageMaker
99-
model data ``.tar.gz`` file. (Default: None).
98+
model_data (Optional[Union[str, PipelineVariable, dict]]): Location
99+
of SageMaker model data. (Default: None).
100100
role (Optional[str]): An AWS IAM role (either name or full ARN). The Amazon
101101
SageMaker training jobs and APIs that create Amazon SageMaker
102102
endpoints use this role to access training data and model

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def __init__(
626626
region: Optional[str] = None,
627627
instance_type: Optional[str] = None,
628628
image_uri: Optional[Union[str, Any]] = None,
629-
model_data: Optional[Union[str, Any]] = None,
629+
model_data: Optional[Union[str, Any, dict]] = None,
630630
role: Optional[str] = None,
631631
predictor_cls: Optional[callable] = None,
632632
env: Optional[Dict[str, Union[str, Any]]] = None,

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,93 @@
10471047
"default_accept_type": "application/json",
10481048
},
10491049
},
1050+
"model_data_s3_prefix_model": {
1051+
"model_id": "huggingface-text2text-flan-t5-xxl-fp16",
1052+
"url": "https://huggingface.co/google/flan-t5-xxl",
1053+
"version": "1.0.1",
1054+
"min_sdk_version": "2.130.0",
1055+
"training_supported": False,
1056+
"incremental_training_supported": False,
1057+
"hosting_ecr_specs": {
1058+
"framework": "pytorch",
1059+
"framework_version": "1.12.0",
1060+
"py_version": "py38",
1061+
"huggingface_transformers_version": "4.17.0",
1062+
},
1063+
"hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz",
1064+
"hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.3/sourcedir.tar.gz",
1065+
"hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/",
1066+
"hosting_prepacked_artifact_version": "1.0.1",
1067+
"inference_vulnerable": False,
1068+
"inference_dependencies": [
1069+
"accelerate==0.16.0",
1070+
"bitsandbytes==0.37.0",
1071+
"filelock==3.9.0",
1072+
"huggingface_hub==0.12.0",
1073+
"regex==2022.7.9",
1074+
"tokenizers==0.13.2",
1075+
"transformers==4.26.0",
1076+
],
1077+
"inference_vulnerabilities": [],
1078+
"training_vulnerable": False,
1079+
"training_dependencies": [],
1080+
"training_vulnerabilities": [],
1081+
"deprecated": False,
1082+
"inference_environment_variables": [
1083+
{
1084+
"name": "SAGEMAKER_PROGRAM",
1085+
"type": "text",
1086+
"default": "inference.py",
1087+
"scope": "container",
1088+
},
1089+
{
1090+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
1091+
"type": "text",
1092+
"default": "/opt/ml/model/code",
1093+
"scope": "container",
1094+
},
1095+
{
1096+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
1097+
"type": "text",
1098+
"default": "20",
1099+
"scope": "container",
1100+
},
1101+
{
1102+
"name": "MODEL_CACHE_ROOT",
1103+
"type": "text",
1104+
"default": "/opt/ml/model",
1105+
"scope": "container",
1106+
},
1107+
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
1108+
{
1109+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
1110+
"type": "text",
1111+
"default": "1",
1112+
"scope": "container",
1113+
},
1114+
{
1115+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
1116+
"type": "text",
1117+
"default": "3600",
1118+
"scope": "container",
1119+
},
1120+
],
1121+
"metrics": [],
1122+
"default_inference_instance_type": "ml.g5.12xlarge",
1123+
"supported_inference_instance_types": [
1124+
"ml.g5.12xlarge",
1125+
"ml.g5.24xlarge",
1126+
"ml.p3.8xlarge",
1127+
"ml.p3.16xlarge",
1128+
"ml.g4dn.12xlarge",
1129+
],
1130+
"predictor_specs": {
1131+
"supported_content_types": ["application/x-text"],
1132+
"supported_accept_types": ["application/json;verbose", "application/json"],
1133+
"default_content_type": "application/x-text",
1134+
"default_accept_type": "application/json",
1135+
},
1136+
},
10501137
"no-supported-instance-types-model": {
10511138
"model_id": "pytorch-ic-mobilenet-v2",
10521139
"url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,113 @@ def test_jumpstart_model_package_arn_unsupported_region(
678678
"us-east-2. Please try one of the following regions: us-west-2, us-east-1."
679679
)
680680

681+
@mock.patch("sagemaker.utils.sagemaker_timestamp")
682+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
683+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
684+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
685+
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
686+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
687+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
688+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER.info")
689+
def test_model_data_s3_prefix_override(
690+
self,
691+
mock_js_info_logger: mock.Mock,
692+
mock_model_deploy: mock.Mock,
693+
mock_model_init: mock.Mock,
694+
mock_get_model_specs: mock.Mock,
695+
mock_session: mock.Mock,
696+
mock_is_valid_model_id: mock.Mock,
697+
mock_sagemaker_timestamp: mock.Mock,
698+
):
699+
mock_model_deploy.return_value = default_predictor
700+
701+
mock_sagemaker_timestamp.return_value = "7777"
702+
703+
mock_is_valid_model_id.return_value = True
704+
model_id, _ = "js-trainable-model", "*"
705+
706+
mock_get_model_specs.side_effect = get_special_model_spec
707+
708+
mock_session.return_value = sagemaker_session
709+
710+
JumpStartModel(model_id=model_id, model_data="s3://some-bucket/path/to/prefix/")
711+
712+
mock_model_init.assert_called_once_with(
713+
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/"
714+
"autogluon-inference:0.4.3-gpu-py38",
715+
model_data={
716+
"S3DataSource": {
717+
"S3Uri": "s3://some-bucket/path/to/prefix/",
718+
"S3DataType": "S3Prefix",
719+
"CompressionType": "None",
720+
}
721+
},
722+
source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-"
723+
"tarballs/autogluon/inference/classification/v1.0.0/sourcedir.tar.gz",
724+
entry_point="inference.py",
725+
env={
726+
"SAGEMAKER_PROGRAM": "inference.py",
727+
"ENDPOINT_SERVER_TIMEOUT": "3600",
728+
"MODEL_CACHE_ROOT": "/opt/ml/model",
729+
"SAGEMAKER_ENV": "1",
730+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
731+
},
732+
predictor_cls=Predictor,
733+
role=execution_role,
734+
sagemaker_session=sagemaker_session,
735+
enable_network_isolation=False,
736+
name="blahblahblah-7777",
737+
)
738+
739+
mock_js_info_logger.assert_called_with(
740+
"S3 prefix model_data detected for JumpStartModel: '%s'. Converting to S3DataSource dictionary.",
741+
"s3://some-bucket/path/to/prefix/",
742+
)
743+
744+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
745+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
746+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
747+
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
748+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
749+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
750+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER.info")
751+
def test_model_data_s3_prefix_model(
752+
self,
753+
mock_js_info_logger: mock.Mock,
754+
mock_model_deploy: mock.Mock,
755+
mock_model_init: mock.Mock,
756+
mock_get_model_specs: mock.Mock,
757+
mock_session: mock.Mock,
758+
mock_is_valid_model_id: mock.Mock,
759+
):
760+
mock_model_deploy.return_value = default_predictor
761+
762+
mock_is_valid_model_id.return_value = True
763+
model_id, _ = "model_data_s3_prefix_model", "*"
764+
765+
mock_get_model_specs.side_effect = get_special_model_spec
766+
767+
mock_session.return_value = sagemaker_session
768+
769+
JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge")
770+
771+
mock_model_init.assert_called_once_with(
772+
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-gpu-py38",
773+
model_data={
774+
"S3DataSource": {
775+
"S3Uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/",
776+
"S3DataType": "S3Prefix",
777+
"CompressionType": "None",
778+
}
779+
},
780+
predictor_cls=Predictor,
781+
role=execution_role,
782+
sagemaker_session=sagemaker_session,
783+
enable_network_isolation=False,
784+
)
785+
786+
mock_js_info_logger.assert_not_called()
787+
681788

682789
def test_jumpstart_model_requires_model_id():
683790
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)