Skip to content

Commit 37f8f6e

Browse files
evakravi authored and benieric committed
feat: s3 prefix model data for JumpStartModel
1 parent 16d1556 commit 37f8f6e

File tree

5 files changed

+217
-8
lines changed

5 files changed

+217
-8
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -206,9 +206,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
206206
def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
207207
"""Sets model data based on default or override, returns full kwargs."""
208208

209-
model_data = kwargs.model_data
210-
211-
kwargs.model_data = model_data or model_uris.retrieve(
209+
model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve(
212210
model_scope=JumpStartScriptScope.INFERENCE,
213211
model_id=kwargs.model_id,
214212
model_version=kwargs.model_version,
@@ -218,6 +216,23 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
218216
sagemaker_session=kwargs.sagemaker_session,
219217
)
220218

219+
if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):
220+
if kwargs.model_data:
221+
JUMPSTART_LOGGER.info(
222+
"S3 prefix model_data detected for JumpStartModel: '%s'. "
223+
"Converting to S3DataSource dictionary.",
224+
model_data,
225+
)
226+
model_data = {
227+
"S3DataSource": {
228+
"S3Uri": model_data,
229+
"S3DataType": "S3Prefix",
230+
"CompressionType": "None",
231+
}
232+
}
233+
234+
kwargs.model_data = model_data
235+
221236
return kwargs
222237

223238

@@ -496,7 +511,7 @@ def get_init_kwargs(
496511
instance_type: Optional[str] = None,
497512
region: Optional[str] = None,
498513
image_uri: Optional[Union[str, PipelineVariable]] = None,
499-
model_data: Optional[Union[str, PipelineVariable]] = None,
514+
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
500515
role: Optional[str] = None,
501516
predictor_cls: Optional[callable] = None,
502517
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,

src/sagemaker/jumpstart/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
region: Optional[str] = None,
5454
instance_type: Optional[str] = None,
5555
image_uri: Optional[Union[str, PipelineVariable]] = None,
56-
model_data: Optional[Union[str, PipelineVariable]] = None,
56+
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
5757
role: Optional[str] = None,
5858
predictor_cls: Optional[callable] = None,
5959
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
@@ -95,8 +95,8 @@ def __init__(
9595
instance_type (Optional[str]): The EC2 instance type to use when provisioning a hosting
9696
endpoint. (Default: None).
9797
image_uri (Optional[Union[str, PipelineVariable]]): A Docker image URI. (Default: None).
98-
model_data (Optional[Union[str, PipelineVariable]]): The S3 location of a SageMaker
99-
model data ``.tar.gz`` file. (Default: None).
98+
model_data (Optional[Union[str, PipelineVariable, dict]]): Location
99+
of SageMaker model data. (Default: None).
100100
role (Optional[str]): An AWS IAM role (either name or full ARN). The Amazon
101101
SageMaker training jobs and APIs that create Amazon SageMaker
102102
endpoints use this role to access training data and model

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -752,7 +752,7 @@ def __init__(
752752
region: Optional[str] = None,
753753
instance_type: Optional[str] = None,
754754
image_uri: Optional[Union[str, Any]] = None,
755-
model_data: Optional[Union[str, Any]] = None,
755+
model_data: Optional[Union[str, Any, dict]] = None,
756756
role: Optional[str] = None,
757757
predictor_cls: Optional[callable] = None,
758758
env: Optional[Dict[str, Union[str, Any]]] = None,

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,6 +1708,93 @@
17081708
"default_accept_type": "application/json",
17091709
},
17101710
},
1711+
"model_data_s3_prefix_model": {
1712+
"model_id": "huggingface-text2text-flan-t5-xxl-fp16",
1713+
"url": "https://huggingface.co/google/flan-t5-xxl",
1714+
"version": "1.0.1",
1715+
"min_sdk_version": "2.130.0",
1716+
"training_supported": False,
1717+
"incremental_training_supported": False,
1718+
"hosting_ecr_specs": {
1719+
"framework": "pytorch",
1720+
"framework_version": "1.12.0",
1721+
"py_version": "py38",
1722+
"huggingface_transformers_version": "4.17.0",
1723+
},
1724+
"hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz",
1725+
"hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.3/sourcedir.tar.gz",
1726+
"hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/",
1727+
"hosting_prepacked_artifact_version": "1.0.1",
1728+
"inference_vulnerable": False,
1729+
"inference_dependencies": [
1730+
"accelerate==0.16.0",
1731+
"bitsandbytes==0.37.0",
1732+
"filelock==3.9.0",
1733+
"huggingface_hub==0.12.0",
1734+
"regex==2022.7.9",
1735+
"tokenizers==0.13.2",
1736+
"transformers==4.26.0",
1737+
],
1738+
"inference_vulnerabilities": [],
1739+
"training_vulnerable": False,
1740+
"training_dependencies": [],
1741+
"training_vulnerabilities": [],
1742+
"deprecated": False,
1743+
"inference_environment_variables": [
1744+
{
1745+
"name": "SAGEMAKER_PROGRAM",
1746+
"type": "text",
1747+
"default": "inference.py",
1748+
"scope": "container",
1749+
},
1750+
{
1751+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
1752+
"type": "text",
1753+
"default": "/opt/ml/model/code",
1754+
"scope": "container",
1755+
},
1756+
{
1757+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
1758+
"type": "text",
1759+
"default": "20",
1760+
"scope": "container",
1761+
},
1762+
{
1763+
"name": "MODEL_CACHE_ROOT",
1764+
"type": "text",
1765+
"default": "/opt/ml/model",
1766+
"scope": "container",
1767+
},
1768+
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
1769+
{
1770+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
1771+
"type": "text",
1772+
"default": "1",
1773+
"scope": "container",
1774+
},
1775+
{
1776+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
1777+
"type": "text",
1778+
"default": "3600",
1779+
"scope": "container",
1780+
},
1781+
],
1782+
"metrics": [],
1783+
"default_inference_instance_type": "ml.g5.12xlarge",
1784+
"supported_inference_instance_types": [
1785+
"ml.g5.12xlarge",
1786+
"ml.g5.24xlarge",
1787+
"ml.p3.8xlarge",
1788+
"ml.p3.16xlarge",
1789+
"ml.g4dn.12xlarge",
1790+
],
1791+
"predictor_specs": {
1792+
"supported_content_types": ["application/x-text"],
1793+
"supported_accept_types": ["application/json;verbose", "application/json"],
1794+
"default_content_type": "application/x-text",
1795+
"default_accept_type": "application/json",
1796+
},
1797+
},
17111798
"no-supported-instance-types-model": {
17121799
"model_id": "pytorch-ic-mobilenet-v2",
17131800
"url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,113 @@ def test_jumpstart_model_package_arn_unsupported_region(
678678
"us-east-2. Please try one of the following regions: us-west-2, us-east-1."
679679
)
680680

681+
@mock.patch("sagemaker.utils.sagemaker_timestamp")
682+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
683+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
684+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
685+
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
686+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
687+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
688+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER.info")
689+
def test_model_data_s3_prefix_override(
690+
self,
691+
mock_js_info_logger: mock.Mock,
692+
mock_model_deploy: mock.Mock,
693+
mock_model_init: mock.Mock,
694+
mock_get_model_specs: mock.Mock,
695+
mock_session: mock.Mock,
696+
mock_is_valid_model_id: mock.Mock,
697+
mock_sagemaker_timestamp: mock.Mock,
698+
):
699+
mock_model_deploy.return_value = default_predictor
700+
701+
mock_sagemaker_timestamp.return_value = "7777"
702+
703+
mock_is_valid_model_id.return_value = True
704+
model_id, _ = "js-trainable-model", "*"
705+
706+
mock_get_model_specs.side_effect = get_special_model_spec
707+
708+
mock_session.return_value = sagemaker_session
709+
710+
JumpStartModel(model_id=model_id, model_data="s3://some-bucket/path/to/prefix/")
711+
712+
mock_model_init.assert_called_once_with(
713+
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/"
714+
"autogluon-inference:0.4.3-gpu-py38",
715+
model_data={
716+
"S3DataSource": {
717+
"S3Uri": "s3://some-bucket/path/to/prefix/",
718+
"S3DataType": "S3Prefix",
719+
"CompressionType": "None",
720+
}
721+
},
722+
source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-"
723+
"tarballs/autogluon/inference/classification/v1.0.0/sourcedir.tar.gz",
724+
entry_point="inference.py",
725+
env={
726+
"SAGEMAKER_PROGRAM": "inference.py",
727+
"ENDPOINT_SERVER_TIMEOUT": "3600",
728+
"MODEL_CACHE_ROOT": "/opt/ml/model",
729+
"SAGEMAKER_ENV": "1",
730+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
731+
},
732+
predictor_cls=Predictor,
733+
role=execution_role,
734+
sagemaker_session=sagemaker_session,
735+
enable_network_isolation=False,
736+
name="blahblahblah-7777",
737+
)
738+
739+
mock_js_info_logger.assert_called_with(
740+
"S3 prefix model_data detected for JumpStartModel: '%s'. Converting to S3DataSource dictionary.",
741+
"s3://some-bucket/path/to/prefix/",
742+
)
743+
744+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
745+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
746+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
747+
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
748+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
749+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
750+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER.info")
751+
def test_model_data_s3_prefix_model(
752+
self,
753+
mock_js_info_logger: mock.Mock,
754+
mock_model_deploy: mock.Mock,
755+
mock_model_init: mock.Mock,
756+
mock_get_model_specs: mock.Mock,
757+
mock_session: mock.Mock,
758+
mock_is_valid_model_id: mock.Mock,
759+
):
760+
mock_model_deploy.return_value = default_predictor
761+
762+
mock_is_valid_model_id.return_value = True
763+
model_id, _ = "model_data_s3_prefix_model", "*"
764+
765+
mock_get_model_specs.side_effect = get_special_model_spec
766+
767+
mock_session.return_value = sagemaker_session
768+
769+
JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge")
770+
771+
mock_model_init.assert_called_once_with(
772+
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-gpu-py38",
773+
model_data={
774+
"S3DataSource": {
775+
"S3Uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/",
776+
"S3DataType": "S3Prefix",
777+
"CompressionType": "None",
778+
}
779+
},
780+
predictor_cls=Predictor,
781+
role=execution_role,
782+
sagemaker_session=sagemaker_session,
783+
enable_network_isolation=False,
784+
)
785+
786+
mock_js_info_logger.assert_not_called()
787+
681788

682789
def test_jumpstart_model_requires_model_id():
683790
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)