Skip to content

Commit da1b642

Browse files
committed
rebase with master
1 parent c450168 commit da1b642

File tree

9 files changed

+24
-13
lines changed

9 files changed

+24
-13
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
get_wildcard_model_version_msg,
4141
get_wildcard_proprietary_model_version_msg,
4242
)
43-
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
4443
from sagemaker.jumpstart.parameters import (
4544
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
4645
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
@@ -453,7 +452,9 @@ def _retrieval_function(
453452
formatted_body, _ = self._get_json_file(id_info, data_type)
454453
model_specs = JumpStartModelSpecs(formatted_body)
455454
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
456-
return JumpStartCachedContentValue(formatted_content=model_specs)
455+
return JumpStartCachedContentValue(
456+
formatted_content=model_specs
457+
)
457458

458459
if data_type == HubContentType.MODEL:
459460
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
@@ -478,6 +479,7 @@ def _retrieval_function(
478479
return JumpStartCachedContentValue(
479480
formatted_content=model_specs
480481
)
482+
481483
if data_type == HubType.HUB:
482484
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
483485
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)

src/sagemaker/jumpstart/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from copy import deepcopy
1616
from enum import Enum
1717
from typing import Any, Dict, List, Optional, Set, Union
18-
from sagemaker.session import Session
1918
from sagemaker.utils import get_instance_type_family, format_tags, Tags
2019
from sagemaker.enums import EndpointType
2120
from sagemaker.model_metrics import ModelMetrics

src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
import dataclasses
2525
import json
2626

27-
import sagemaker
28-
2927

3028
class _UTCFormatter(logging.Formatter):
3129
"""Class that overrides the default local time provider in log formatter."""

src/sagemaker/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import random
2323
import re
2424
import shutil
25-
import sys
2625
import tarfile
2726
import tempfile
2827
import time

tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def test_jumpstart_instance_types_from_hub(patched_get_model_specs):
206206
model_id=model_id,
207207
version=model_version,
208208
s3_client=mock_client,
209+
model_type=JumpStartModelType.OPEN_WEIGHTS,
209210
)
210211

211212
patched_get_model_specs.reset_mock()
@@ -227,6 +228,7 @@ def test_jumpstart_instance_types_from_hub(patched_get_model_specs):
227228
model_id=model_id,
228229
version=model_version,
229230
s3_client=mock_client,
231+
model_type=JumpStartModelType.OPEN_WEIGHTS,
230232
)
231233

232234
patched_get_model_specs.reset_mock()
@@ -253,6 +255,7 @@ def test_jumpstart_instance_types_from_hub(patched_get_model_specs):
253255
model_id=model_id,
254256
version=model_version,
255257
s3_client=mock_client,
258+
model_type=JumpStartModelType.OPEN_WEIGHTS,
256259
)
257260

258261
patched_get_model_specs.reset_mock()
@@ -282,6 +285,7 @@ def test_jumpstart_instance_types_from_hub(patched_get_model_specs):
282285
model_id=model_id,
283286
version=model_version,
284287
s3_client=mock_client,
288+
model_type=JumpStartModelType.OPEN_WEIGHTS,
285289
)
286290

287291
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def test_prepacked(
292292
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_deploy_kwargs")
293293
@mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_fit_kwargs")
294294
@mock.patch("sagemaker.jumpstart.curated_hub.utils.construct_hub_arn_from_name")
295-
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
295+
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
296296
@mock.patch("sagemaker.jumpstart.factory.model.Session")
297297
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
298298
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -309,7 +309,7 @@ def test_hub_model(
309309
mock_get_model_specs: mock.Mock,
310310
mock_session_estimator: mock.Mock,
311311
mock_session_model: mock.Mock,
312-
mock_is_valid_model_id: mock.Mock,
312+
mock_validate_model_id_and_get_type: mock.Mock,
313313
mock_construct_hub_arn_from_name: mock.Mock,
314314
mock_retrieve_estimator_fit_kwargs: mock.Mock,
315315
mock_retrieve_model_deploy_kwargs: mock.Mock,
@@ -321,7 +321,7 @@ def test_hub_model(
321321
mock_get_caller_identity.return_value = "123456789123"
322322
mock_estimator_deploy.return_value = default_predictor
323323

324-
mock_is_valid_model_id.return_value = True
324+
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
325325

326326
model_id, _ = "pytorch-hub-model-1", "*"
327327
hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub"

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
137137
> 0
138138
)
139139

140+
140141
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
141142
def test_jumpstart_models_cache_get_model_specs(mock_cache):
142143
mock_cache.get_specs = Mock()
@@ -147,7 +148,9 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
147148
accessors.JumpStartModelsAccessor.get_model_specs(
148149
region=region, model_id=model_id, version=version
149150
)
150-
mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
151+
mock_cache.get_specs.assert_called_once_with(
152+
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
153+
)
151154
mock_cache.get_hub_model.assert_not_called()
152155

153156
accessors.JumpStartModelsAccessor.get_model_specs(

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
2929
JumpStartModelsCache,
3030
)
31-
from sagemaker.session_settings import SessionSettings
3231
from sagemaker.jumpstart.constants import (
3332
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
3433
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
@@ -559,8 +558,8 @@ def test_jumpstart_proprietary_cache_accepts_input_parameters():
559558
)
560559
assert cache.get_region() == region
561560
assert cache.get_bucket() == bucket
562-
assert cache._s3_cache._max_cache_items == max_s3_cache_items
563-
assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon
561+
assert cache._content_cache._max_cache_items == max_s3_cache_items
562+
assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon
564563
assert (
565564
cache._proprietary_model_id_manifest_key_cache._max_cache_items
566565
== max_semantic_version_cache_items

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ def patched_retrieval_function(
253253
model_type=JumpStartModelType.PROPRIETARY,
254254
)
255255
)
256+
257+
if datatype == HubContentType.MODEL:
258+
_, _, _, model_name, model_version = id_info.split("/")
259+
return JumpStartCachedContentValue(
260+
formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version)
261+
)
262+
256263
# TODO: Implement
257264
if datatype == HubType.HUB:
258265
return None

0 commit comments

Comments
 (0)