Skip to content

Commit c957e71

Browse files
committed
fix: Adding more alterations
1 parent 395e99c commit c957e71

File tree

5 files changed

+204
-59
lines changed

5 files changed

+204
-59
lines changed

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,19 @@
5252
get_jumpstart_model_and_version,
5353
find_jumpstart_tags_for_hub_content,
5454
get_latest_version_for_model,
55-
summary_list_from_list_api_response
55+
summary_list_from_list_api_response,
56+
generate_unique_hub_content_model_name
5657
)
5758
from sagemaker.jumpstart.curated_hub.types import (
5859
HubContentDocument_v2,
5960
JumpStartModelInfo,
6061
S3ObjectLocation,
6162
CuratedHubTag,
6263
HubContentSummary,
64+
CuratedHubModelInfo,
6365
)
6466

67+
LIST_HUB_CACHE = None
6568

6669
class CuratedHub:
6770
"""Class for creating and managing a curated JumpStart hub"""
@@ -152,16 +155,17 @@ def describe(self) -> DescribeHubResponse:
152155
return hub_description
153156

154157

155-
def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]:
    """Lists the models in this Curated Hub.

    Args:
        clear_cache (bool): When True (the default), reset the module-level
            ``LIST_HUB_CACHE`` before listing.
        **kwargs: Passed to invocation of ``Session:list_hub_contents``.

    Returns:
        Dict[str, Any]: The cached listing when ``LIST_HUB_CACHE`` holds a
        truthy value, otherwise the result of ``self._list_models(**kwargs)``.
    """
    # Without this declaration the assignment below would make the name local
    # to this method: the module-level cache would never be reset, and reading
    # it with ``clear_cache=False`` would raise UnboundLocalError.
    global LIST_HUB_CACHE
    if clear_cache:
        LIST_HUB_CACHE = None
    if LIST_HUB_CACHE:
        return LIST_HUB_CACHE
    # NOTE(review): nothing in this method stores a listing in LIST_HUB_CACHE;
    # confirm that it is populated elsewhere.
    return self._list_models(**kwargs)
163168

164-
@lru_cache(maxsize=5)
165169
def _list_models(self, **kwargs) -> Dict[str, Any]:
166170
"""Lists the models in this Curated Hub
167171
@@ -212,13 +216,12 @@ def _is_invalid_model_list_input(self, model_list: List[Dict[str, str]]) -> bool
212216
return False
213217

214218
def _get_jumpstart_models_in_hub(self) -> List[HubContentSummary]:
    """Returns the listed hub content summaries that resolve to a JumpStart model.

    The hub listing from ``self.list_models()`` is converted with
    ``summary_list_from_list_api_response``; a summary is kept only when
    ``get_jumpstart_model_and_version`` returns something other than None for it.
    """
    listed_summaries = summary_list_from_list_api_response(self.list_models())
    jumpstart_summaries = []
    for summary in listed_summaries:
        if get_jumpstart_model_and_version(summary) is not None:
            jumpstart_summaries.append(summary)
    return jumpstart_summaries
218221

219222
def _determine_models_to_sync(
220-
self, model_to_sync: List[JumpStartModelInfo], models_in_hub: List[JumpStartModelInfo]
221-
) -> List[JumpStartModelInfo]:
223+
self, model_to_sync: List[JumpStartModelInfo], models_in_hub: List[CuratedHubModelInfo]
224+
) -> List[CuratedHubModelInfo]:
222225
"""Determines which models from `sync` params to sync into the CuratedHub.
223226
224227
Algorithm:
@@ -229,12 +232,22 @@ def _determine_models_to_sync(
229232
in Hub, don't sync. If newer version in Hub, don't sync. If older version in Hub,
230233
sync that model.
231234
"""
232-
jumpstart_model_id_to_model_info_map = {model.model_id: model for model in models_in_hub}
233-
models_to_sync: List[JumpStartModelInfo] = []
234-
for model in model_to_sync:
235-
matched_model = jumpstart_model_id_to_model_info_map.get(model.model_id)
236-
if not matched_model or Version(matched_model.version) < Version(model.version):
237-
models_to_sync.append(model)
235+
hub_content_jumpstart_model_id_map = {model.jumpstart_model_info.model_id: model for model in models_in_hub}
236+
models_to_sync: List[CuratedHubModelInfo] = []
237+
for public_hub_model in model_to_sync:
238+
matched_model = hub_content_jumpstart_model_id_map.get(public_hub_model.model_id)
239+
if not matched_model:
240+
models_to_sync.append(CuratedHubModelInfo(
241+
jumpstart_model_info=public_hub_model,
242+
hub_content_model_id=generate_unique_hub_content_model_name(public_hub_model.model_id),
243+
hub_content_version="*"
244+
))
245+
elif Version(matched_model.jumpstart_model_info.version) < Version(public_hub_model.version):
246+
models_to_sync.append(CuratedHubModelInfo(
247+
jumpstart_model_info=public_hub_model,
248+
hub_content_model_id=matched_model.hub_content_model_id,
249+
hub_content_version="*"
250+
))
238251

239252
return models_to_sync
240253

@@ -265,12 +278,17 @@ def sync(self, model_list: List[Dict[str, str]]):
265278
)
266279
model_version_list.append(JumpStartModelInfo(model_id, version))
267280

268-
js_models_in_hub = self._get_jumpstart_models_in_hub()
269-
js_models_in_hub = [get_jumpstart_model_and_version(model) for model in js_models_in_hub]
270-
271-
models_to_sync = self._determine_models_to_sync(model_version_list, js_models_in_hub)
281+
jumpstart_models_in_hub = self._get_jumpstart_models_in_hub()
282+
curated_models = [
283+
CuratedHubModelInfo(
284+
jumpstart_model_id=get_jumpstart_model_and_version(model),
285+
hub_content_model_id=model.hub_content_name,
286+
hub_content_version=model.hub_content_version
287+
) for model in jumpstart_models_in_hub
288+
]
289+
models_to_sync = self._determine_models_to_sync(model_version_list, curated_models)
272290
JUMPSTART_LOGGER.warning(
273-
"Syncing the following models into Hub %s: %s", self.hub_name, models_to_sync
291+
"Syncing the following models into Hub %s: %s", self.hub_name, [model.jumpstart_model_info for model in models_to_sync]
274292
)
275293

276294
# Delete old models?
@@ -304,11 +322,11 @@ def sync(self, model_list: List[Dict[str, str]]):
304322
f"Failures when importing models to curated hub in parallel: {failed_imports}"
305323
)
306324

307-
def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
325+
def _sync_public_model_to_hub(self, model: CuratedHubModelInfo, thread_num: int):
308326
"""Syncs a public JumpStart model version to the Hub. Runs in parallel."""
309327
model_specs = utils.verify_model_region_and_return_specs(
310-
model_id=model.model_id,
311-
version=model.version,
328+
model_id=model.jumpstart_model_info.model_id,
329+
version=model.jumpstart_model_info.version,
312330
region=self.region,
313331
scope=JumpStartScriptScope.INFERENCE,
314332
sagemaker_session=self._sagemaker_session,
@@ -317,7 +335,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
317335

318336
dest_location = S3ObjectLocation(
319337
bucket=self.hub_storage_location.bucket,
320-
key=f"{self.hub_storage_location.key}/curated_models/{model.model_id}/{model.version}",
338+
key=f"{self.hub_storage_location.key}/curated_models/{model.jumpstart_model_info.model_id}/{model.jumpstart_model_info.version}",
321339
)
322340
src_files = file_generator.generate_file_infos_from_model_specs(
323341
model_specs, studio_specs, self.region, self._s3_client
@@ -339,16 +357,16 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
339357
label=dest_location.key,
340358
).execute()
341359
else:
342-
JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.model_id, model.version)
360+
JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.jumpstart_model_info.model_id, model.jumpstart_model_info.version)
343361

344362
# TODO: Tag model if specs say it is deprecated or training/inference
345363
# vulnerable. Update tag of HubContent ARN without version.
346364
# Versioned ARNs are not onboarded to Tagris.
347365
tags = []
348366

349367
search_keywords = [
350-
f"{JUMPSTART_HUB_MODEL_ID_TAG_PREFIX}:{model.model_id}",
351-
f"{JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX}:{model.version}",
368+
f"{JUMPSTART_HUB_MODEL_ID_TAG_PREFIX}:{model.jumpstart_model_info.model_id}",
369+
f"{JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX}:{model.jumpstart_model_info.version}",
352370
f"{FRAMEWORK_TAG_PREFIX}:{model_specs.get_framework()}",
353371
f"{TASK_TAG_PREFIX}:TODO: pull from specs",
354372
]
@@ -357,8 +375,8 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
357375

358376
self._sagemaker_session.import_hub_content(
359377
document_schema_version=HubContentDocument_v2.SCHEMA_VERSION,
360-
hub_content_name=model.model_id,
361-
hub_content_version=model.version,
378+
hub_content_name=model.hub_content_model_id,
379+
hub_content_version=model.hub_content_version,
362380
hub_name=self.hub_name,
363381
hub_content_document=hub_content_document,
364382
hub_content_type=HubContentType.MODEL,
@@ -398,9 +416,9 @@ def scan_and_tag_models(self) -> None:
398416
JUMPSTART_LOGGER.info(
399417
"Tagging models in hub: %s", self.hub_name
400418
)
401-
models_in_hub: List[HubContentSummary] = self._get_jumpstart_models_in_hub()
419+
js_models_in_hub = [model for model in self.list_models() if get_jumpstart_model_and_version(model) is not None]
402420
tags_added: Dict[str, List[CuratedHubTag]] = {}
403-
for model in models_in_hub:
421+
for model in js_models_in_hub:
404422
tags_to_add: List[CuratedHubTag] = find_jumpstart_tags_for_hub_content(
405423
hub_name=self.hub_name,
406424
hub_content_name=model.hub_content_name,

src/sagemaker/jumpstart/curated_hub/types.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,29 @@
2020
from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs, HubContentType
2121

2222
class CuratedHubTagName(str, Enum):
    """String-valued enum of the Curated Hub tag names."""

    # Each member compares equal to its plain string value because the enum
    # also derives from ``str``.
    DEPRECATED_VERSIONS = "deprecated_versions"
    TRAINING_VULNERABLE_VERSIONS = "training_vulnerable_versions"
    INFERENCE_VULNERABLE_VERSIONS = "inference_vulnerable_versions"
2727

2828
@dataclass
class HubContentSummary:
    """Dataclass holding one HubContentSummary entry from the List APIs."""

    # Fields without a default; they come first so positional construction
    # keeps the same argument order.
    hub_content_arn: str
    hub_content_name: str
    hub_content_version: str
    hub_content_type: HubContentType
    document_schema_version: str
    hub_content_status: str
    creation_time: str
    # These three fall back to None when they are not supplied.
    hub_content_display_name: str = None
    hub_content_description: str = None
    hub_content_search_keywords: List[str] = None
41+
4042

4143
@dataclass
class CuratedHubTag:
    """Dataclass for a single Curated Hub-specific tag (key/value pair)."""

    # ``key`` is one of the CuratedHubTagName members; ``value`` is the tag text.
    key: CuratedHubTagName
    value: str
4548

@@ -62,14 +65,20 @@ def get_uri(self) -> str:
6265
"""Returns the s3 URI"""
6366
return f"s3://{self.bucket}/{self.key}"
6467

65-
6668
@dataclass
class JumpStartModelInfo:
    """Helper dataclass pairing a JumpStart model id with a version string."""

    model_id: str
    version: str
7274

75+
@dataclass
class CuratedHubModelInfo:
    """Helper dataclass tying a JumpStart model to its Curated Hub content entry."""

    # The JumpStart model id/version this hub content is associated with.
    jumpstart_model_info: JumpStartModelInfo
    # Name and version of the hub content itself.
    hub_content_model_id: str
    hub_content_version: str
81+
7382

7483
class HubContentDependencyType(str, Enum):
7584
"""Enum class for HubContent dependency names"""

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
TASK_TAG_PREFIX,
4141
FRAMEWORK_TAG_PREFIX,
4242
)
43+
from uuid import uuid4
4344

4445

4546
def get_info_from_hub_resource_arn(
@@ -253,16 +254,16 @@ def find_jumpstart_tags_for_model_version(model_id: str, version: str, region: s
253254
def get_jumpstart_model_and_version(hub_content_summary: HubContentSummary) -> Optional[JumpStartModelInfo]:
254255
jumpstart_model_id = next(
255256
(
256-
tag
257-
for tag in hub_content_summary.search_keywords
257+
tag[len(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX)+1:] # Need to remove the tag_prefix and ":"
258+
for tag in hub_content_summary.hub_content_search_keywords
258259
if tag.startswith(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX)
259260
),
260261
None,
261262
)
262263
jumpstart_model_version = next(
263264
(
264-
tag
265-
for tag in hub_content_summary.search_keywords
265+
tag[len(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX)+1:]
266+
for tag in hub_content_summary.hub_content_search_keywords
266267
if tag.startswith(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX)
267268
),
268269
None,
@@ -282,7 +283,7 @@ def get_latest_version_for_model(model_id: str, region: str) -> str:
282283

283284
def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubContentSummary:
284285
return HubContentSummary(
285-
hub_content_arn=hub_content_summary.get("HubContentSummary"),
286+
hub_content_arn=hub_content_summary.get("HubContentArn"),
286287
hub_content_name=hub_content_summary.get("HubContentName"),
287288
hub_content_version=hub_content_summary.get("HubContentVersion"),
288289
hub_content_type=hub_content_summary.get("HubContentType"),
@@ -301,4 +302,7 @@ def tag_to_add_tags_api_call(tag: CuratedHubTag) -> Dict[str, str]:
301302
return {
302303
'Key': tag.key,
303304
'Value': tag.value
304-
}
305+
}
306+
307+
def generate_unique_hub_content_model_name(model_id: str) -> str:
    """Returns ``model_id`` followed by ``-`` and a freshly generated UUID4."""
    unique_suffix = uuid4()
    return f"{model_id}-{unique_suffix}"

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -187,48 +187,48 @@ def test_sync_kicks_off_parallel_syncs(
187187

188188
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@patch(f"{MODULE_PATH}._sync_public_model_to_hub")
@patch(f"{MODULE_PATH}._list_models")
def test_sync_filters_models_that_exist_in_hub(
    mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session
):
    """Only the requested model that is not already in the hub is synced."""
    mock_get_model_specs.side_effect = get_spec_from_base_spec
    mock_list_models.return_value = {
        "HubContentSummaries": [
            {
                "HubContentName": "mock-model-two-pytorch",
                "HubContentVersion": "1.0.2",
                "HubContentSearchKeywords": [
                    "@jumpstart-model-id:mock-pytorch-model-already-exists-in-hub",
                    "@jumpstart-model-version:1.0.2",
                ],
            },
            {
                "HubContentName": "mock-model-three-nonsense",
                "HubContentVersion": "1.0.2",
                "HubContentSearchKeywords": [],
            },
            {
                "HubContentName": "mock-model-four-huggingface",
                "HubContentVersion": "2.0.2",
                "HubContentSearchKeywords": [
                    "@jumpstart-model-id:model-four-huggingface",
                    "@jumpstart-model-version:2.0.2",
                ],
            },
        ]
    }
    hub_name = "mock_hub_name"
    model_one = {"model_id": "mock-pytorch-model-does-not-exist"}
    model_two = {"model_id": "mock-pytorch-model-already-exists-in-hub", "version": "1.0.2"}
    mock_sync_public_models.return_value = ""
    hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)

    hub.sync([model_one, model_two])

    # The synced object carries a hub content name with a random UUID suffix,
    # so compare the wrapped JumpStart model info instead of the whole object.
    mock_sync_public_models.assert_called_once()
    synced_model, thread_num = mock_sync_public_models.call_args[0]
    # The expected id must match ``model_one`` exactly.
    assert synced_model.jumpstart_model_info == JumpStartModelInfo(
        "mock-pytorch-model-does-not-exist", "*"
    )
    assert thread_num == 0
227227

228228

229229
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
230230
@patch(f"{MODULE_PATH}._sync_public_model_to_hub")
231-
@patch(f"{MODULE_PATH}.list_models")
231+
@patch(f"{MODULE_PATH}._list_models")
232232
def test_sync_updates_old_models_in_hub(
233233
mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session
234234
):

0 commit comments

Comments
 (0)