Skip to content

Commit 4252d86

Browse files
committed
fix: linter
1 parent a73b5b2 commit 4252d86

File tree

5 files changed

+83
-82
lines changed

5 files changed

+83
-82
lines changed

src/sagemaker/jumpstart/curated_hub/types.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class HubContentSummary:
4949

5050

5151
def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubContentSummary:
52-
"""Creates a single HubContentSummary from a HubContentSummary from the HubService List APIs."""
52+
"""Creates a single HubContentSummary.
53+
54+
This is based on the ListHubContent or ListHubContentVersions API response."""
5355
return HubContentSummary(
5456
hub_content_arn=hub_content_summary.get("HubContentArn"),
5557
hub_content_name=hub_content_summary.get("HubContentName"),
@@ -67,7 +69,9 @@ def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubCo
6769
def summary_list_from_list_api_response(
6870
list_hub_contents_response: Dict[str, Any]
6971
) -> List[HubContentSummary]:
70-
"""Creates a HubContentSummary list from either the ListHubContent or ListHubContentVersions API response."""
72+
"""Creates a HubContentSummary list.
73+
74+
This is based on the ListHubContent or ListHubContentVersions API response."""
7175
return list(
7276
map(
7377
summary_from_list_api_response,
@@ -78,7 +82,7 @@ def summary_list_from_list_api_response(
7882

7983
@dataclass
8084
class S3ObjectLocation:
81-
"""Helper class for S3 object references"""
85+
"""Helper class for S3 object references."""
8286

8387
bucket: str
8488
key: str

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ def generate_hub_arn_for_init_kwargs(
120120
if match:
121121
hub_arn = hub_name
122122
else:
123-
hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session)
123+
hub_arn = construct_hub_arn_from_name(
124+
hub_name=hub_name,
125+
region=region,
126+
session=session
127+
)
124128
return hub_arn
125129

126130

@@ -184,7 +188,8 @@ def find_deprecated_vulnerable_flags_for_hub_content(
184188
Since tags are the same for all versions of a HubContent,
185189
these tags will map from the key to a list of versions impacted.
186190
For example, if certain public hub model versions are deprecated,
187-
this utility will return a `deprecated` tag mapped to the deprecated versions for the HubContent.
191+
this utility will return a `deprecated` tag
192+
mapped to the deprecated versions for the HubContent.
188193
"""
189194
list_versions_response = session.list_hub_content_versions(
190195
hub_name=hub_name,
@@ -244,7 +249,8 @@ def find_unsupported_flags_for_model_version(
244249
) -> List[CuratedHubUnsupportedFlag]:
245250
"""Finds relevant CuratedHubTags for a version of a JumpStart public hub model.
246251
247-
For example, if the public hub model is deprecated, this utility will return a `deprecated` tag.
252+
For example, if the public hub model is deprecated,
253+
this utility will return a `deprecated` tag.
248254
Since tags are the same for all versions of a HubContent,
249255
these tags will map from the key to a list of versions impacted.
250256
"""

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.jumpstart.curated_hub.types import (
2222
JumpStartModelInfo,
2323
S3ObjectLocation,
24-
HubContentSummary
24+
HubContentSummary,
2525
)
2626
from sagemaker.jumpstart.types import JumpStartModelSpecs
2727
from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC
@@ -209,7 +209,7 @@ def test_sync_filters_models_that_exist_in_hub(
209209
{
210210
"HubContentName": "mock-model-three-nonsense",
211211
"HubContentVersion": "1.0.2",
212-
"HubContentSearchKeywords": []
212+
"HubContentSearchKeywords": [],
213213
},
214214
{
215215
"HubContentName": "mock-model-four-huggingface",
@@ -391,7 +391,7 @@ def test_get_jumpstart_models_in_hub(mock_list_models, sagemaker_session):
391391
hub_content_type=None,
392392
document_schema_version=None,
393393
hub_content_status=None,
394-
creation_time=None
394+
creation_time=None,
395395
),
396396
HubContentSummary(
397397
hub_content_name="mock-model-four-huggingface",
@@ -404,7 +404,7 @@ def test_get_jumpstart_models_in_hub(mock_list_models, sagemaker_session):
404404
hub_content_type=None,
405405
document_schema_version=None,
406406
hub_content_status=None,
407-
creation_time=None
407+
creation_time=None,
408408
),
409409
]
410410

@@ -443,7 +443,7 @@ def test_determine_models_to_sync(sagemaker_session):
443443
hub_content_type=None,
444444
document_schema_version=None,
445445
hub_content_status=None,
446-
creation_time=None
446+
creation_time=None,
447447
),
448448
"mock-model-four-huggingface": HubContentSummary(
449449
hub_content_name="mock-model-four-huggingface",
@@ -456,8 +456,8 @@ def test_determine_models_to_sync(sagemaker_session):
456456
hub_content_type=None,
457457
document_schema_version=None,
458458
hub_content_status=None,
459-
creation_time=None
460-
)
459+
creation_time=None,
460+
),
461461
}
462462
model_one = JumpStartModelInfo("mock-model-one-huggingface", "1.2.3")
463463
model_two = JumpStartModelInfo("mock-model-two-pytorch", "1.0.2")
@@ -477,7 +477,7 @@ def test_determine_models_to_sync(sagemaker_session):
477477
hub_content_type=None,
478478
document_schema_version=None,
479479
hub_content_status=None,
480-
creation_time=None
480+
creation_time=None,
481481
),
482482
"mock-model-four-huggingface": HubContentSummary(
483483
hub_content_name="mock-model-four-huggingface",
@@ -490,8 +490,8 @@ def test_determine_models_to_sync(sagemaker_session):
490490
hub_content_type=None,
491491
document_schema_version=None,
492492
hub_content_status=None,
493-
creation_time=None
494-
)
493+
creation_time=None,
494+
),
495495
}
496496

497497
# No model_one, newer model_two
@@ -510,7 +510,7 @@ def test_determine_models_to_sync(sagemaker_session):
510510
hub_content_type=None,
511511
document_schema_version=None,
512512
hub_content_status=None,
513-
creation_time=None
513+
creation_time=None,
514514
),
515515
"mock-model-two-pytorch": HubContentSummary(
516516
hub_content_name="mock-model-two-pytorch",
@@ -523,8 +523,8 @@ def test_determine_models_to_sync(sagemaker_session):
523523
hub_content_type=None,
524524
document_schema_version=None,
525525
hub_content_status=None,
526-
creation_time=None
527-
)
526+
creation_time=None,
527+
),
528528
}
529529
# Same model_one, same model_two
530530
res = hub._determine_models_to_sync([model_one, model_two], js_model_map)
@@ -542,7 +542,7 @@ def test_determine_models_to_sync(sagemaker_session):
542542
hub_content_type=None,
543543
document_schema_version=None,
544544
hub_content_status=None,
545-
creation_time=None
545+
creation_time=None,
546546
),
547547
"mock-model-two-pytorch": HubContentSummary(
548548
hub_content_name="mock-model-two-pytorch",
@@ -555,8 +555,8 @@ def test_determine_models_to_sync(sagemaker_session):
555555
hub_content_type=None,
556556
document_schema_version=None,
557557
hub_content_status=None,
558-
creation_time=None
559-
)
558+
creation_time=None,
559+
),
560560
}
561561
# Old model_one, same model_two
562562
res = hub._determine_models_to_sync([model_one, model_two], js_model_map)

tests/unit/sagemaker/jumpstart/curated_hub/test_syncrequest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def _helper_generate_fileinfos(
3535
last_updated: Optional[datetime] = None,
3636
dependecy_type: Optional[HubContentDependencyType] = None,
3737
) -> List[FileInfo]:
38-
3938
file_infos = []
4039
for i in range(num_infos):
4140
bucket = bucket or "default-bucket"

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 51 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@
1818
from sagemaker.jumpstart.enums import JumpStartScriptScope
1919
from sagemaker.jumpstart.curated_hub import utils
2020
from unittest.mock import patch
21-
from sagemaker.jumpstart.curated_hub.types import (
22-
CuratedHubUnsupportedFlag,
23-
HubContentSummary
24-
)
21+
from sagemaker.jumpstart.curated_hub.types import CuratedHubUnsupportedFlag, HubContentSummary
2522
from sagemaker.jumpstart.types import HubContentType
2623

2724

@@ -211,10 +208,7 @@ def test_find_tags_for_jumpstart_model_version(mock_spec_util):
211208
mock_spec_util.return_value = mock_specs
212209

213210
tags = utils.find_unsupported_flags_for_model_version(
214-
model_id="test",
215-
version="test",
216-
region="test",
217-
session=mock_sagemaker_session
211+
model_id="test", version="test", region="test", session=mock_sagemaker_session
218212
)
219213

220214
mock_spec_util.assert_called_once_with(
@@ -230,7 +224,7 @@ def test_find_tags_for_jumpstart_model_version(mock_spec_util):
230224
assert tags == [
231225
CuratedHubUnsupportedFlag.DEPRECATED_VERSIONS,
232226
CuratedHubUnsupportedFlag.INFERENCE_VULNERABLE_VERSIONS,
233-
CuratedHubUnsupportedFlag.TRAINING_VULNERABLE_VERSIONS
227+
CuratedHubUnsupportedFlag.TRAINING_VULNERABLE_VERSIONS,
234228
]
235229

236230

@@ -244,10 +238,7 @@ def test_find_tags_for_jumpstart_model_version_some_false(mock_spec_util):
244238
mock_spec_util.return_value = mock_specs
245239

246240
tags = utils.find_unsupported_flags_for_model_version(
247-
model_id="test",
248-
version="test",
249-
region="test",
250-
session=mock_sagemaker_session
241+
model_id="test", version="test", region="test", session=mock_sagemaker_session
251242
)
252243

253244
mock_spec_util.assert_called_once_with(
@@ -273,10 +264,7 @@ def test_find_tags_for_jumpstart_model_version_all_false(mock_spec_util):
273264
mock_spec_util.return_value = mock_specs
274265

275266
tags = utils.find_unsupported_flags_for_model_version(
276-
model_id="test",
277-
version="test",
278-
region="test",
279-
session=mock_sagemaker_session
267+
model_id="test", version="test", region="test", session=mock_sagemaker_session
280268
)
281269

282270
mock_spec_util.assert_called_once_with(
@@ -302,19 +290,16 @@ def test_find_all_tags_for_jumpstart_model_filters_non_jumpstart_models(mock_spe
302290
"HubContentSearchKeywords": [
303291
"@jumpstart-model-id:model-one-pytorch",
304292
"@jumpstart-model-version:1.0.3",
305-
]
293+
],
306294
},
307295
{
308296
"HubContentVersion": "2.0.0",
309297
"HubContentSearchKeywords": [
310298
"@jumpstart-model-id:model-four-huggingface",
311299
"@jumpstart-model-version:2.0.2",
312-
]
300+
],
313301
},
314-
{
315-
"HubContentVersion": "3.0.0",
316-
"HubContentSearchKeywords": []
317-
}
302+
{"HubContentVersion": "3.0.0", "HubContentSearchKeywords": []},
318303
]
319304
}
320305

@@ -325,22 +310,28 @@ def test_find_all_tags_for_jumpstart_model_filters_non_jumpstart_models(mock_spe
325310
mock_spec_util.return_value = mock_specs
326311

327312
tags = utils.find_deprecated_vulnerable_flags_for_hub_content(
328-
hub_name="test",
329-
hub_content_name="test",
330-
region="test",
331-
session=mock_sagemaker_session
313+
hub_name="test", hub_content_name="test", region="test", session=mock_sagemaker_session
332314
)
333315

334316
mock_sagemaker_session.list_hub_content_versions.assert_called_once_with(
335317
hub_name="test",
336-
hub_content_type='Model',
318+
hub_content_type="Model",
337319
hub_content_name="test",
338320
)
339321

340322
assert tags == [
341-
{"Key": CuratedHubUnsupportedFlag.DEPRECATED_VERSIONS.value, "Value": str(["1.0.0", "2.0.0"])},
342-
{"Key": CuratedHubUnsupportedFlag.INFERENCE_VULNERABLE_VERSIONS.value, "Value": str(["1.0.0", "2.0.0"])},
343-
{"Key": CuratedHubUnsupportedFlag.TRAINING_VULNERABLE_VERSIONS.value, "Value": str(["1.0.0", "2.0.0"])}
323+
{
324+
"Key": CuratedHubUnsupportedFlag.DEPRECATED_VERSIONS.value,
325+
"Value": str(["1.0.0", "2.0.0"]),
326+
},
327+
{
328+
"Key": CuratedHubUnsupportedFlag.INFERENCE_VULNERABLE_VERSIONS.value,
329+
"Value": str(["1.0.0", "2.0.0"]),
330+
},
331+
{
332+
"Key": CuratedHubUnsupportedFlag.TRAINING_VULNERABLE_VERSIONS.value,
333+
"Value": str(["1.0.0", "2.0.0"]),
334+
},
344335
]
345336

346337

@@ -356,7 +347,7 @@ def test_summary_from_list_api_response(mock_spec_util):
356347
"HubContentStatus": "test",
357348
"HubContentDescription": "test_description",
358349
"HubContentSearchKeywords": ["test"],
359-
"CreationTime": "test_creation"
350+
"CreationTime": "test_creation",
360351
}
361352
)
362353

@@ -375,32 +366,33 @@ def test_summary_from_list_api_response(mock_spec_util):
375366

376367
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
377368
def test_summaries_from_list_api_response(mock_spec_util):
378-
test = utils.summary_list_from_list_api_response({
379-
"HubContentSummaries": [
380-
{
381-
"HubContentArn": "test",
382-
"HubContentName": "test",
383-
"HubContentVersion": "test",
384-
"HubContentType": "Model",
385-
"DocumentSchemaVersion": "test",
386-
"HubContentStatus": "test",
387-
"HubContentDescription": "test",
388-
"HubContentSearchKeywords": ["test", "test_2"],
389-
"CreationTime": "test"
390-
},
391-
{
392-
"HubContentArn": "test_2",
393-
"HubContentName": "test_2",
394-
"HubContentVersion": "test_2",
395-
"HubContentType": "Model",
396-
"DocumentSchemaVersion": "test_2",
397-
"HubContentStatus": "test_2",
398-
"HubContentDescription": "test_2",
399-
"HubContentSearchKeywords": ["test_2", "test_2_2"],
400-
"CreationTime": "test_2"
401-
}
402-
]
403-
}
369+
test = utils.summary_list_from_list_api_response(
370+
{
371+
"HubContentSummaries": [
372+
{
373+
"HubContentArn": "test",
374+
"HubContentName": "test",
375+
"HubContentVersion": "test",
376+
"HubContentType": "Model",
377+
"DocumentSchemaVersion": "test",
378+
"HubContentStatus": "test",
379+
"HubContentDescription": "test",
380+
"HubContentSearchKeywords": ["test", "test_2"],
381+
"CreationTime": "test",
382+
},
383+
{
384+
"HubContentArn": "test_2",
385+
"HubContentName": "test_2",
386+
"HubContentVersion": "test_2",
387+
"HubContentType": "Model",
388+
"DocumentSchemaVersion": "test_2",
389+
"HubContentStatus": "test_2",
390+
"HubContentDescription": "test_2",
391+
"HubContentSearchKeywords": ["test_2", "test_2_2"],
392+
"CreationTime": "test_2",
393+
},
394+
]
395+
}
404396
)
405397

406398
assert test == [
@@ -425,5 +417,5 @@ def test_summaries_from_list_api_response(mock_spec_util):
425417
hub_content_status="test_2",
426418
creation_time="test_2",
427419
hub_content_search_keywords=["test_2", "test_2_2"],
428-
)
420+
),
429421
]

0 commit comments

Comments
 (0)