Skip to content

Commit 29a0740

Browse files
committed
fix: more linting :/
1 parent 6b107a2 commit 29a0740

File tree

3 files changed

+14
-46
lines changed

3 files changed

+14
-46
lines changed

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
from sagemaker.jumpstart.enums import JumpStartScriptScope
1919
from sagemaker.jumpstart.curated_hub import utils
2020
from unittest.mock import patch
21-
from sagemaker.jumpstart.curated_hub.types import CuratedHubUnsupportedFlag, HubContentSummary
21+
from sagemaker.jumpstart.curated_hub.types import (
22+
CuratedHubUnsupportedFlag,
23+
HubContentSummary,
24+
summary_from_list_api_response,
25+
summary_list_from_list_api_response,
26+
)
2227
from sagemaker.jumpstart.types import HubContentType
2328

2429

@@ -337,7 +342,7 @@ def test_find_all_tags_for_jumpstart_model_filters_non_jumpstart_models(mock_spe
337342

338343
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
339344
def test_summary_from_list_api_response(mock_spec_util):
340-
test = utils.summary_from_list_api_response(
345+
test = summary_from_list_api_response(
341346
{
342347
"HubContentArn": "test_arn",
343348
"HubContentName": "test_name",
@@ -366,7 +371,7 @@ def test_summary_from_list_api_response(mock_spec_util):
366371

367372
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
368373
def test_summaries_from_list_api_response(mock_spec_util):
369-
test = utils.summary_list_from_list_api_response(
374+
test = summary_list_from_list_api_response(
370375
{
371376
"HubContentSummaries": [
372377
{

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
8181
accessors.JumpStartModelsAccessor.get_model_specs(
8282
region=region, model_id=model_id, version=version
8383
)
84-
mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
84+
mock_cache.get_specs.assert_called_once_with(
85+
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
86+
)
8587
mock_cache.get_hub_model.assert_not_called()
8688

8789
accessors.JumpStartModelsAccessor.get_model_specs(
@@ -96,6 +98,9 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
9698
)
9799
)
98100

101+
# necessary because accessors is a static module
102+
reload(accessors)
103+
99104

100105
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
101106
def test_jumpstart_proprietary_models_cache_get(mock_cache):
@@ -136,37 +141,6 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
136141
)
137142

138143

139-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
140-
def test_jumpstart_models_cache_get_model_specs_open_weights(mock_cache):
141-
mock_cache.get_specs = Mock()
142-
mock_cache.get_hub_model = Mock()
143-
model_id, version = "pytorch-ic-mobilenet-v2", "*"
144-
region = "us-west-2"
145-
146-
accessors.JumpStartModelsAccessor.get_model_specs(
147-
region=region, model_id=model_id, version=version
148-
)
149-
mock_cache.get_specs.assert_called_once_with(
150-
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
151-
)
152-
mock_cache.get_hub_model.assert_not_called()
153-
154-
accessors.JumpStartModelsAccessor.get_model_specs(
155-
region=region,
156-
model_id=model_id,
157-
version=version,
158-
hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub",
159-
)
160-
mock_cache.get_hub_model.assert_called_once_with(
161-
hub_model_arn=(
162-
f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}"
163-
)
164-
)
165-
166-
# necessary because accessors is a static module
167-
reload(accessors)
168-
169-
170144
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache")
171145
def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock):
172146
# test change of region resets cache

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def test_get_jumpstart_gated_content_bucket_override():
9999

100100

101101
def test_get_jumpstart_launched_regions_message():
102-
103102
with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}):
104103
assert (
105104
utils.get_jumpstart_launched_regions_message()
@@ -147,7 +146,6 @@ def test_get_formatted_manifest():
147146

148147

149148
def test_parse_sagemaker_version():
150-
151149
with patch("sagemaker.__version__", "1.2.3"):
152150
assert utils.parse_sagemaker_version() == "1.2.3"
153151

@@ -188,7 +186,6 @@ def test_get_sagemaker_version(patched_parse_sm_version: Mock):
188186

189187

190188
def test_is_jumpstart_model_uri():
191-
192189
assert not utils.is_jumpstart_model_uri("fdsfdsf")
193190
assert not utils.is_jumpstart_model_uri("s3://not-jumpstart-bucket/sdfsdfds")
194191
assert not utils.is_jumpstart_model_uri("some/actual/localfile")
@@ -689,7 +686,6 @@ def test_add_jumpstart_uri_tags_training():
689686

690687

691688
def test_update_inference_tags_with_jumpstart_training_script_tags():
692-
693689
random_tag_1 = {"Key": "tag-key-1", "Value": "tag-val-1"}
694690
random_tag_2 = {"Key": "tag-key-2", "Value": "tag-val-2"}
695691

@@ -750,7 +746,6 @@ def test_update_inference_tags_with_jumpstart_training_script_tags():
750746

751747

752748
def test_update_inference_tags_with_jumpstart_training_model_tags():
753-
754749
random_tag_1 = {"Key": "tag-key-1", "Value": "tag-val-1"}
755750
random_tag_2 = {"Key": "tag-key-2", "Value": "tag-val-2"}
756751

@@ -811,7 +806,6 @@ def test_update_inference_tags_with_jumpstart_training_model_tags():
811806

812807

813808
def test_update_inference_tags_with_jumpstart_training_script_tags_inference():
814-
815809
random_tag_1 = {"Key": "tag-key-1", "Value": "tag-val-1"}
816810
random_tag_2 = {"Key": "tag-key-2", "Value": "tag-val-2"}
817811

@@ -872,7 +866,6 @@ def test_update_inference_tags_with_jumpstart_training_script_tags_inference():
872866

873867

874868
def test_update_inference_tags_with_jumpstart_training_model_tags_inference():
875-
876869
random_tag_1 = {"Key": "tag-key-1", "Value": "tag-val-1"}
877870
random_tag_2 = {"Key": "tag-key-2", "Value": "tag-val-2"}
878871

@@ -975,7 +968,6 @@ def make_vulnerable_inference_spec(*largs, **kwargs):
975968

976969
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
977970
def test_jumpstart_old_model_spec(mock_get_manifest):
978-
979971
mock_get_manifest.return_value = [
980972
JumpStartModelHeader(
981973
{
@@ -1282,7 +1274,6 @@ def test_validate_model_id_and_get_type_false(
12821274
)
12831275

12841276
with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched):
1285-
12861277
self.assertFalse(utils.validate_model_id_and_get_type("dee"))
12871278
self.assertFalse(utils.validate_model_id_and_get_type(""))
12881279
self.assertFalse(utils.validate_model_id_and_get_type(None))
@@ -1478,7 +1469,6 @@ class TestJumpStartLogger(TestCase):
14781469
@patch("logging.StreamHandler.emit")
14791470
@patch("sagemaker.jumpstart.constants.JUMPSTART_LOGGER.propagate", False)
14801471
def test_logger_normal_mode(self, mocked_emit: Mock):
1481-
14821472
JUMPSTART_LOGGER.warning("Self destruct in 3...2...1...")
14831473

14841474
mocked_emit.assert_called_once()
@@ -1487,7 +1477,6 @@ def test_logger_normal_mode(self, mocked_emit: Mock):
14871477
@patch("logging.StreamHandler.emit")
14881478
@patch("sagemaker.jumpstart.constants.JUMPSTART_LOGGER.propagate", False)
14891479
def test_logger_disabled(self, mocked_emit: Mock):
1490-
14911480
JUMPSTART_LOGGER.warning("Self destruct in 3...2...1...")
14921481

14931482
mocked_emit.assert_not_called()

0 commit comments

Comments
 (0)