Skip to content

Commit d7c4307

Browse files
committed
feat: jsch jumpstart estimator support (#4439)
1 parent 29718b4 commit d7c4307

File tree

19 files changed

+65
-35
lines changed

19 files changed

+65
-35
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def get_model_specs(
257257
hub_arn: Optional[str] = None,
258258
s3_client: Optional[boto3.client] = None,
259259
model_type=JumpStartModelType.OPEN_WEIGHTS,
260+
hub_arn: Optional[str] = None,
260261
) -> JumpStartModelSpecs:
261262
"""Returns model specs from JumpStart models cache.
262263

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
get_wildcard_model_version_msg,
4040
get_wildcard_proprietary_model_version_msg,
4141
)
42+
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
4243
from sagemaker.jumpstart.parameters import (
4344
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
4445
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,

src/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@
172172
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
173173
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
174174

175-
# works cross-partition
176-
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
175+
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
177176
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
178177

179178
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"

src/sagemaker/jumpstart/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ def _validate_model_id_and_get_type_hook():
534534
model_version=model_version,
535535
hub_arn=hub_arn,
536536
model_type=self.model_type,
537+
hub_arn=hub_arn,
537538
tolerate_vulnerable_model=tolerate_vulnerable_model,
538539
tolerate_deprecated_model=tolerate_deprecated_model,
539540
role=role,

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def get_init_kwargs(
8181
model_version: Optional[str] = None,
8282
hub_arn: Optional[str] = None,
8383
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
84+
hub_arn: Optional[str] = None,
8485
tolerate_vulnerable_model: Optional[bool] = None,
8586
tolerate_deprecated_model: Optional[bool] = None,
8687
region: Optional[str] = None,
@@ -140,6 +141,7 @@ def get_init_kwargs(
140141
model_version=model_version,
141142
hub_arn=hub_arn,
142143
model_type=model_type,
144+
hub_arn=hub_arn,
143145
role=role,
144146
region=region,
145147
instance_count=instance_count,

src/sagemaker/jumpstart/factory/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,7 @@ def get_deploy_kwargs(
550550
model_version: Optional[str] = None,
551551
hub_arn: Optional[str] = None,
552552
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
553+
hub_arn: Optional[str] = None,
553554
region: Optional[str] = None,
554555
initial_instance_count: Optional[int] = None,
555556
instance_type: Optional[str] = None,
@@ -584,6 +585,7 @@ def get_deploy_kwargs(
584585
model_version=model_version,
585586
hub_arn=hub_arn,
586587
model_type=model_type,
588+
hub_arn=hub_arn,
587589
region=region,
588590
initial_instance_count=initial_instance_count,
589591
instance_type=instance_type,

src/sagemaker/jumpstart/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
14201420
"model_version",
14211421
"hub_arn",
14221422
"model_type",
1423+
"hub_arn",
14231424
"initial_instance_count",
14241425
"instance_type",
14251426
"region",
@@ -1453,6 +1454,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
14531454
"model_version",
14541455
"hub_arn",
14551456
"model_type",
1457+
"hub_arn",
14561458
"region",
14571459
"tolerate_deprecated_model",
14581460
"tolerate_vulnerable_model",
@@ -1499,6 +1501,7 @@ def __init__(
14991501
self.model_version = model_version
15001502
self.hub_arn = hub_arn
15011503
self.model_type = model_type
1504+
self.hub_arn = hub_arn
15021505
self.initial_instance_count = initial_instance_count
15031506
self.instance_type = instance_type
15041507
self.region = region
@@ -1535,6 +1538,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
15351538
"model_version",
15361539
"hub_arn",
15371540
"model_type",
1541+
"hub_arn",
15381542
"instance_type",
15391543
"instance_count",
15401544
"region",
@@ -1596,6 +1600,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
15961600
"model_version",
15971601
"hub_arn",
15981602
"model_type",
1603+
"hub_arn",
15991604
}
16001605

16011606
def __init__(
@@ -1725,6 +1730,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
17251730
"model_version",
17261731
"hub_arn",
17271732
"model_type",
1733+
"hub_arn",
17281734
"region",
17291735
"inputs",
17301736
"wait",
@@ -1741,6 +1747,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
17411747
"model_version",
17421748
"hub_arn",
17431749
"model_type",
1750+
"hub_arn",
17441751
"region",
17451752
"tolerate_deprecated_model",
17461753
"tolerate_vulnerable_model",
@@ -1769,6 +1776,7 @@ def __init__(
17691776
self.model_version = model_version
17701777
self.hub_arn = hub_arn
17711778
self.model_type = model_type
1779+
self.hub_arn = hub_arn
17721780
self.region = region
17731781
self.inputs = inputs
17741782
self.wait = wait

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from __future__ import absolute_import
1515
import logging
1616
import os
17-
from typing import Any, Dict, List, Set, Optional, Tuple, Union
1817
import re
18+
from typing import Any, Dict, List, Set, Optional, Tuple, Union
1919
from urllib.parse import urlparse
2020
import boto3
2121
from packaging.version import Version

tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ def add_options_to_hyperparameter(*largs, **kwargs):
453453
s3_client=mock_client,
454454
hub_arn=None,
455455
model_type=JumpStartModelType.OPEN_WEIGHTS,
456+
hub_arn=None,
456457
)
457458

458459
patched_get_model_specs.reset_mock()
@@ -516,6 +517,7 @@ def test_jumpstart_validate_all_hyperparameters(
516517
s3_client=mock_client,
517518
hub_arn=None,
518519
model_type=JumpStartModelType.OPEN_WEIGHTS,
520+
hub_arn=None,
519521
)
520522

521523
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/image_uris/jumpstart/test_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def test_jumpstart_common_image_uri(
5656
s3_client=mock_client,
5757
hub_arn=None,
5858
model_type=JumpStartModelType.OPEN_WEIGHTS,
59+
hub_arn=None,
5960
)
6061
patched_verify_model_region_and_return_specs.assert_called_once()
6162

@@ -78,6 +79,7 @@ def test_jumpstart_common_image_uri(
7879
s3_client=mock_client,
7980
hub_arn=None,
8081
model_type=JumpStartModelType.OPEN_WEIGHTS,
82+
hub_arn=None,
8183
)
8284
patched_verify_model_region_and_return_specs.assert_called_once()
8385

@@ -100,6 +102,7 @@ def test_jumpstart_common_image_uri(
100102
s3_client=mock_client,
101103
hub_arn=None,
102104
model_type=JumpStartModelType.OPEN_WEIGHTS,
105+
hub_arn=None,
103106
)
104107
patched_verify_model_region_and_return_specs.assert_called_once()
105108

@@ -122,6 +125,7 @@ def test_jumpstart_common_image_uri(
122125
s3_client=mock_client,
123126
hub_arn=None,
124127
model_type=JumpStartModelType.OPEN_WEIGHTS,
128+
hub_arn=None,
125129
)
126130
patched_verify_model_region_and_return_specs.assert_called_once()
127131

tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode
123123
region=region,
124124
model_id=model_id,
125125
version=model_version,
126+
model_type=JumpStartModelType.OPEN_WEIGHTS,
126127
hub_arn=None,
127128
s3_client=mock_client,
128-
model_type=JumpStartModelType.OPEN_WEIGHTS,
129129
)
130130

131131
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
from unittest.mock import Mock
1616
from sagemaker.jumpstart.types import HubArnExtractedInfo
17-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
1817
from sagemaker.jumpstart.curated_hub import utils
18+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
19+
from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo
1920

2021

2122
def test_get_info_from_hub_resource_arn():

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,31 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
137137
> 0
138138
)
139139

140+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
141+
def test_jumpstart_models_cache_get_model_specs(mock_cache):
142+
mock_cache.get_specs = Mock()
143+
mock_cache.get_hub_model = Mock()
144+
model_id, version = "pytorch-ic-mobilenet-v2", "*"
145+
region = "us-west-2"
146+
147+
accessors.JumpStartModelsAccessor.get_model_specs(
148+
region=region, model_id=model_id, version=version
149+
)
150+
mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
151+
mock_cache.get_hub_model.assert_not_called()
152+
153+
accessors.JumpStartModelsAccessor.get_model_specs(
154+
region=region,
155+
model_id=model_id,
156+
version=version,
157+
hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub",
158+
)
159+
mock_cache.get_hub_model.assert_called_once_with(
160+
hub_model_arn=(
161+
f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}"
162+
)
163+
)
164+
140165
# necessary because accessors is a static module
141166
reload(accessors)
142167

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,4 +751,5 @@ def test_get_model_url(
751751
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
752752
hub_arn=None,
753753
model_type=JumpStartModelType.OPEN_WEIGHTS,
754+
hub_arn=None,
754755
)

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,35 +1214,6 @@ def test_mime_type_enum_from_str():
12141214
assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type
12151215

12161216

1217-
def test_extract_info_from_hub_content_arn():
1218-
model_arn = (
1219-
"arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2"
1220-
)
1221-
assert utils.extract_info_from_hub_content_arn(model_arn) == (
1222-
"MockHub",
1223-
"us-west-2",
1224-
"my-mock-model",
1225-
"1.0.2",
1226-
)
1227-
1228-
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub"
1229-
assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None)
1230-
1231-
invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123"
1232-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1233-
1234-
invalid_arn = "nonsense-string"
1235-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1236-
1237-
invalid_arn = ""
1238-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1239-
1240-
invalid_arn = (
1241-
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0"
1242-
)
1243-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1244-
1245-
12461217
class TestIsValidModelId(TestCase):
12471218
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
12481219
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25-
HubDataType,
25+
HubContentType,
2626
JumpStartCachedContentKey,
2727
JumpStartCachedContentValue,
2828
JumpStartModelSpecs,
@@ -253,6 +253,9 @@ def patched_retrieval_function(
253253
model_type=JumpStartModelType.PROPRIETARY,
254254
)
255255
)
256+
# TODO: Implement
257+
if datatype == HubContentType.HUB:
258+
return None
256259

257260
raise ValueError(f"Bad value for datatype: {datatype}")
258261

tests/unit/sagemaker/model_uris/jumpstart/test_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_jumpstart_common_model_uri(
5454
s3_client=mock_client,
5555
hub_arn=None,
5656
model_type=JumpStartModelType.OPEN_WEIGHTS,
57+
hub_arn=None,
5758
)
5859
patched_verify_model_region_and_return_specs.assert_called_once()
5960

@@ -73,6 +74,7 @@ def test_jumpstart_common_model_uri(
7374
s3_client=mock_client,
7475
hub_arn=None,
7576
model_type=JumpStartModelType.OPEN_WEIGHTS,
77+
hub_arn=None,
7678
)
7779
patched_verify_model_region_and_return_specs.assert_called_once()
7880

@@ -93,6 +95,7 @@ def test_jumpstart_common_model_uri(
9395
s3_client=mock_client,
9496
hub_arn=None,
9597
model_type=JumpStartModelType.OPEN_WEIGHTS,
98+
hub_arn=None,
9699
)
97100
patched_verify_model_region_and_return_specs.assert_called_once()
98101

@@ -113,6 +116,7 @@ def test_jumpstart_common_model_uri(
113116
s3_client=mock_client,
114117
hub_arn=None,
115118
model_type=JumpStartModelType.OPEN_WEIGHTS,
119+
hub_arn=None,
116120
)
117121
patched_verify_model_region_and_return_specs.assert_called_once()
118122

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_jumpstart_resource_requirements(
5757
s3_client=mock_client,
5858
hub_arn=None,
5959
model_type=JumpStartModelType.OPEN_WEIGHTS,
60+
hub_arn=None,
6061
)
6162
patched_get_model_specs.reset_mock()
6263

tests/unit/sagemaker/script_uris/jumpstart/test_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_jumpstart_common_script_uri(
5454
s3_client=mock_client,
5555
hub_arn=None,
5656
model_type=JumpStartModelType.OPEN_WEIGHTS,
57+
hub_arn=None,
5758
)
5859
patched_verify_model_region_and_return_specs.assert_called_once()
5960

@@ -73,6 +74,7 @@ def test_jumpstart_common_script_uri(
7374
s3_client=mock_client,
7475
hub_arn=None,
7576
model_type=JumpStartModelType.OPEN_WEIGHTS,
77+
hub_arn=None,
7678
)
7779
patched_verify_model_region_and_return_specs.assert_called_once()
7880

@@ -93,6 +95,7 @@ def test_jumpstart_common_script_uri(
9395
s3_client=mock_client,
9496
hub_arn=None,
9597
model_type=JumpStartModelType.OPEN_WEIGHTS,
98+
hub_arn=None,
9699
)
97100
patched_verify_model_region_and_return_specs.assert_called_once()
98101

@@ -113,6 +116,7 @@ def test_jumpstart_common_script_uri(
113116
s3_client=mock_client,
114117
hub_arn=None,
115118
model_type=JumpStartModelType.OPEN_WEIGHTS,
119+
hub_arn=None,
116120
)
117121
patched_verify_model_region_and_return_specs.assert_called_once()
118122

0 commit comments

Comments
 (0)