Skip to content

Commit 7cda304

Browse files
committed
fix: Add tests
1 parent 9e703a9 commit 7cda304

File tree

4 files changed

+228
-249
lines changed

4 files changed

+228
-249
lines changed

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,14 @@
5151
tag_hub_content,
5252
get_jumpstart_model_and_version,
5353
find_jumpstart_tags_for_hub_content,
54-
get_latest_version_for_model,
5554
summary_list_from_list_api_response,
56-
generate_unique_hub_content_model_name
5755
)
5856
from sagemaker.jumpstart.curated_hub.types import (
5957
HubContentDocument_v2,
6058
JumpStartModelInfo,
6159
S3ObjectLocation,
6260
CuratedHubTag,
6361
HubContentSummary,
64-
CuratedHubModelInfo,
6562
)
6663

6764
LIST_HUB_CACHE = None
@@ -215,13 +212,23 @@ def _is_invalid_model_list_input(self, model_list: List[Dict[str, str]]) -> bool
215212
return True
216213
return False
217214

215+
def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str]:
216+
"""Populates the latest version of a model from specs no matter what is passed.
217+
218+
Returns model ({ model_id: str, version: str })
219+
"""
220+
model_specs = utils.verify_model_region_and_return_specs(
221+
model["model_id"], "*", JumpStartScriptScope.INFERENCE, self.region
222+
)
223+
return {"model_id": model["model_id"], "version": model_specs.version}
224+
218225
def _get_jumpstart_models_in_hub(self) -> List[HubContentSummary]:
219226
hub_models = summary_list_from_list_api_response(self.list_models())
220227
return [model for model in hub_models if get_jumpstart_model_and_version(model) is not None]
221228

222229
def _determine_models_to_sync(
223-
self, model_to_sync: List[JumpStartModelInfo], models_in_hub: List[CuratedHubModelInfo]
224-
) -> List[CuratedHubModelInfo]:
230+
self, model_list: List[JumpStartModelInfo], models_in_hub: Dict[str, HubContentSummary]
231+
) -> List[JumpStartModelInfo]:
225232
"""Determines which models from `sync` params to sync into the CuratedHub.
226233
227234
Algorithm:
@@ -232,22 +239,11 @@ def _determine_models_to_sync(
232239
in Hub, don't sync. If newer version in Hub, don't sync. If older version in Hub,
233240
sync that model.
234241
"""
235-
hub_content_jumpstart_model_id_map = {model.jumpstart_model_info.model_id: model for model in models_in_hub}
236-
models_to_sync: List[CuratedHubModelInfo] = []
237-
for public_hub_model in model_to_sync:
238-
matched_model = hub_content_jumpstart_model_id_map.get(public_hub_model.model_id)
239-
if not matched_model:
240-
models_to_sync.append(CuratedHubModelInfo(
241-
jumpstart_model_info=public_hub_model,
242-
hub_content_model_id=generate_unique_hub_content_model_name(public_hub_model.model_id),
243-
hub_content_version="*"
244-
))
245-
elif Version(matched_model.jumpstart_model_info.version) < Version(public_hub_model.version):
246-
models_to_sync.append(CuratedHubModelInfo(
247-
jumpstart_model_info=public_hub_model,
248-
hub_content_model_id=matched_model.hub_content_model_id,
249-
hub_content_version="*"
250-
))
242+
models_to_sync = []
243+
for model in model_list:
244+
matched_model = models_in_hub.get(model.model_id)
245+
if not matched_model or Version(matched_model.hub_content_version) < Version(model.version):
246+
models_to_sync.append(model)
251247

252248
return models_to_sync
253249

@@ -265,30 +261,24 @@ def sync(self, model_list: List[Dict[str, str]]):
265261
)
266262

267263
# Retrieve latest version of unspecified JumpStart model versions
268-
model_version_list: List[JumpStartModelInfo] = []
264+
model_version_list = []
269265
for model in model_list:
270-
model_id = model.get("model_id")
271266
version = model.get("version", "*")
272267
if version == "*":
273-
version = get_latest_version_for_model(model_id=model_id, region=self.region)
268+
model = self._populate_latest_model_version(model)
274269
JUMPSTART_LOGGER.warning(
275270
"No version specified for model %s. Using version %s",
276-
model_id,
277-
version,
271+
model["model_id"],
272+
model["version"],
278273
)
279-
model_version_list.append(JumpStartModelInfo(model_id, version))
280-
281-
jumpstart_models_in_hub = self._get_jumpstart_models_in_hub()
282-
curated_models = [
283-
CuratedHubModelInfo(
284-
jumpstart_model_info=get_jumpstart_model_and_version(model),
285-
hub_content_model_id=model.hub_content_name,
286-
hub_content_version=model.hub_content_version
287-
) for model in jumpstart_models_in_hub
288-
]
289-
models_to_sync = self._determine_models_to_sync(model_version_list, curated_models)
274+
model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"]))
275+
276+
js_models_in_hub = self._get_jumpstart_models_in_hub()
277+
mapped_models_in_hub = {model.hub_content_name: model for model in js_models_in_hub}
278+
279+
models_to_sync = self._determine_models_to_sync(model_version_list, mapped_models_in_hub)
290280
JUMPSTART_LOGGER.warning(
291-
"Syncing the following models into Hub %s: %s", self.hub_name, [model.jumpstart_model_info for model in models_to_sync]
281+
"Syncing the following models into Hub %s: %s", self.hub_name, models_to_sync
292282
)
293283

294284
# Delete old models?
@@ -322,11 +312,11 @@ def sync(self, model_list: List[Dict[str, str]]):
322312
f"Failures when importing models to curated hub in parallel: {failed_imports}"
323313
)
324314

325-
def _sync_public_model_to_hub(self, model: CuratedHubModelInfo, thread_num: int):
315+
def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
326316
"""Syncs a public JumpStart model version to the Hub. Runs in parallel."""
327317
model_specs = utils.verify_model_region_and_return_specs(
328-
model_id=model.jumpstart_model_info.model_id,
329-
version=model.jumpstart_model_info.version,
318+
model_id=model.model_id,
319+
version=model.version,
330320
region=self.region,
331321
scope=JumpStartScriptScope.INFERENCE,
332322
sagemaker_session=self._sagemaker_session,
@@ -335,7 +325,7 @@ def _sync_public_model_to_hub(self, model: CuratedHubModelInfo, thread_num: int)
335325

336326
dest_location = S3ObjectLocation(
337327
bucket=self.hub_storage_location.bucket,
338-
key=f"{self.hub_storage_location.key}/curated_models/{model.jumpstart_model_info.model_id}/{model.jumpstart_model_info.version}",
328+
key=f"{self.hub_storage_location.key}/curated_models/{model.model_id}/{model.version}",
339329
)
340330
src_files = file_generator.generate_file_infos_from_model_specs(
341331
model_specs, studio_specs, self.region, self._s3_client
@@ -357,16 +347,16 @@ def _sync_public_model_to_hub(self, model: CuratedHubModelInfo, thread_num: int)
357347
label=dest_location.key,
358348
).execute()
359349
else:
360-
JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.jumpstart_model_info.model_id, model.jumpstart_model_info.version)
350+
JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.model_id, model.version)
361351

362352
# TODO: Tag model if specs say it is deprecated or training/inference
363353
# vulnerable. Update tag of HubContent ARN without version.
364354
# Versioned ARNs are not onboarded to Tagris.
365355
tags = []
366356

367357
search_keywords = [
368-
f"{JUMPSTART_HUB_MODEL_ID_TAG_PREFIX}:{model.jumpstart_model_info.model_id}",
369-
f"{JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX}:{model.jumpstart_model_info.version}",
358+
f"{JUMPSTART_HUB_MODEL_ID_TAG_PREFIX}:{model.model_id}",
359+
f"{JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX}:{model.version}",
370360
f"{FRAMEWORK_TAG_PREFIX}:{model_specs.get_framework()}",
371361
f"{TASK_TAG_PREFIX}:TODO: pull from specs",
372362
]
@@ -375,8 +365,8 @@ def _sync_public_model_to_hub(self, model: CuratedHubModelInfo, thread_num: int)
375365

376366
self._sagemaker_session.import_hub_content(
377367
document_schema_version=HubContentDocument_v2.SCHEMA_VERSION,
378-
hub_content_name=model.hub_content_model_id,
379-
hub_content_version=model.hub_content_version,
368+
hub_content_name=model.model_id,
369+
hub_content_version=model.version,
380370
hub_name=self.hub_name,
381371
hub_content_document=hub_content_document,
382372
hub_content_type=HubContentType.MODEL,

src/sagemaker/jumpstart/curated_hub/types.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,6 @@ class JumpStartModelInfo:
7272
model_id: str
7373
version: str
7474

75-
@dataclass
76-
class CuratedHubModelInfo:
77-
"""Helper class to store Curated Hub model info."""
78-
jumpstart_model_info: JumpStartModelInfo
79-
hub_content_model_id: str
80-
hub_content_version: str
81-
8275

8376
class HubContentDependencyType(str, Enum):
8477
"""Enum class for HubContent dependency names"""

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,14 +274,6 @@ def get_jumpstart_model_and_version(hub_content_summary: HubContentSummary) -> O
274274
jumpstart_model_id = jumpstart_model_id_tag[len(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX):] # Need to remove the tag_prefix and ":"
275275
jumpstart_model_version = jumpstart_model_version_tag[len(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX):]
276276
return JumpStartModelInfo(model_id=jumpstart_model_id, version=jumpstart_model_version)
277-
278-
def get_latest_version_for_model(model_id: str, region: str) -> str:
279-
"""Returns the latest version of a model from specs."""
280-
model_specs = utils.verify_model_region_and_return_specs(
281-
model_id, "*", JumpStartScriptScope.INFERENCE, region
282-
)
283-
return model_specs.version
284-
285277

286278
def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubContentSummary:
287279
return HubContentSummary(

0 commit comments

Comments
 (0)