Skip to content

Commit 5aeaa35

Browse files
committed
feat: add proprietary manifest/specs parsing
add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models
1 parent 5158749 commit 5aeaa35

File tree

40 files changed

+1360
-411
lines changed

40 files changed

+1360
-411
lines changed

src/sagemaker/accept_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
1818
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
from sagemaker.jumpstart.enums import JumpStartModelType
1920
from sagemaker.session import Session
2021

2122

@@ -72,6 +73,7 @@ def retrieve_default(
7273
region: Optional[str] = None,
7374
model_id: Optional[str] = None,
7475
model_version: Optional[str] = None,
76+
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE,
7577
tolerate_vulnerable_model: bool = False,
7678
tolerate_deprecated_model: bool = False,
7779
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,

src/sagemaker/base_predictor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@
5858
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME
5959

6060
from sagemaker.lineage.context import EndpointContext
61-
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
61+
from sagemaker.compute_resource_requirements.resource_requirements import (
62+
ResourceRequirements,
63+
)
6264

6365
LOGGER = logging.getLogger("sagemaker")
6466

src/sagemaker/content_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
1818
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
from sagemaker.jumpstart.enums import JumpStartModelType
1920
from sagemaker.session import Session
2021

2122

@@ -72,6 +73,7 @@ def retrieve_default(
7273
region: Optional[str] = None,
7374
model_id: Optional[str] = None,
7475
model_version: Optional[str] = None,
76+
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE,
7577
tolerate_vulnerable_model: bool = False,
7678
tolerate_deprecated_model: bool = False,
7779
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,

src/sagemaker/deserializers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
3737
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
38+
from sagemaker.jumpstart.enums import JumpStartModelType
3839
from sagemaker.session import Session
3940

4041

@@ -95,6 +96,7 @@ def retrieve_default(
9596
tolerate_vulnerable_model: bool = False,
9697
tolerate_deprecated_model: bool = False,
9798
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
99+
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE,
98100
) -> BaseDeserializer:
99101
"""Retrieves the default deserializer for the model matching the given arguments.
100102

src/sagemaker/jumpstart/accessors.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from sagemaker.deprecations import deprecated
2020
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
21+
from sagemaker.jumpstart.enums import JumpStartModelType
2122
from sagemaker.jumpstart import cache
2223
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
2324

@@ -197,7 +198,9 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
197198

198199
@staticmethod
199200
def _get_manifest(
200-
region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None
201+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
202+
s3_client: Optional[boto3.client] = None,
203+
model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE,
201204
) -> List[JumpStartModelHeader]:
202205
"""Return entire JumpStart models manifest.
203206
@@ -215,13 +218,19 @@ def _get_manifest(
215218
additional_kwargs.update({"s3_client": s3_client})
216219

217220
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
218-
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region
221+
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs},
222+
region,
219223
)
220224
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
221-
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore
225+
return JumpStartModelsAccessor._cache.get_manifest(model_type) # type: ignore
222226

223227
@staticmethod
224-
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
228+
def get_model_header(
229+
region: str,
230+
model_id: str,
231+
version: str,
232+
model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE,
233+
) -> JumpStartModelHeader:
225234
"""Returns model header from JumpStart models cache.
226235
227236
Args:
@@ -234,12 +243,18 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
234243
)
235244
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
236245
return JumpStartModelsAccessor._cache.get_header( # type: ignore
237-
model_id=model_id, semantic_version_str=version
246+
model_id=model_id,
247+
semantic_version_str=version,
248+
model_type=model_type,
238249
)
239250

240251
@staticmethod
241252
def get_model_specs(
242-
region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None
253+
region: str,
254+
model_id: str,
255+
version: str,
256+
s3_client: Optional[boto3.client] = None,
257+
model_type=JumpStartModelType.OPEN_SOURCE,
243258
) -> JumpStartModelSpecs:
244259
"""Returns model specs from JumpStart models cache.
245260
@@ -260,7 +275,7 @@ def get_model_specs(
260275
)
261276
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
262277
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
263-
model_id=model_id, semantic_version_str=version
278+
model_id=model_id, version_str=version, model_type=model_type
264279
)
265280

266281
@staticmethod

0 commit comments

Comments
 (0)