Skip to content

Commit cb66608

Browse files
author
Malav Shastri
committed
fix: Address nits and improvements
1 parent 6b9f390 commit cb66608

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""This module contains accessors related to SageMaker JumpStart."""
1515
from __future__ import absolute_import
1616
import functools
17+
import logging
1718
from typing import Any, Dict, List, Optional
1819
import boto3
1920

@@ -289,15 +290,6 @@ def get_model_specs(
289290

290291
if hub_arn:
291292
try:
292-
hub_model_arn = construct_hub_model_arn_from_inputs(
293-
hub_arn=hub_arn, model_name=model_id, version=version
294-
)
295-
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
296-
hub_model_arn=hub_model_arn
297-
)
298-
model_specs.set_hub_content_type(HubContentType.MODEL)
299-
return model_specs
300-
except: # noqa: E722
301293
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
302294
hub_arn=hub_arn, model_name=model_id, version=version
303295
)
@@ -307,6 +299,17 @@ def get_model_specs(
307299
model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE)
308300
return model_specs
309301

302+
except Exception as ex:
303+
logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex))
304+
hub_model_arn = construct_hub_model_arn_from_inputs(
305+
hub_arn=hub_arn, model_name=model_id, version=version
306+
)
307+
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
308+
hub_model_arn=hub_model_arn
309+
)
310+
model_specs.set_hub_content_type(HubContentType.MODEL)
311+
return model_specs
312+
310313
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
311314
model_id=model_id, version_str=version, model_type=model_type
312315
)

src/sagemaker/jumpstart/factory/model.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
JUMPSTART_LOGGER,
3535
)
3636
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
37-
from sagemaker.jumpstart.hub.utils import construct_hub_model_reference_arn_from_inputs
37+
from sagemaker.jumpstart.hub.utils import construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs
3838
from sagemaker.model_metrics import ModelMetrics
3939
from sagemaker.metadata_properties import MetadataProperties
4040
from sagemaker.drift_check_baselines import DriftCheckBaselines
@@ -550,7 +550,19 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
550550
)
551551

552552
if kwargs.hub_arn:
553-
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn)
553+
if kwargs.model_reference_arn:
554+
hub_content_arn = construct_hub_model_reference_arn_from_inputs(
555+
kwargs.hub_arn,
556+
kwargs.model_id,
557+
kwargs.model_version
558+
)
559+
else:
560+
hub_content_arn = construct_hub_model_arn_from_inputs(
561+
kwargs.hub_arn,
562+
kwargs.model_id,
563+
kwargs.model_version
564+
)
565+
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
554566

555567
return kwargs
556568

src/sagemaker/jumpstart/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,12 @@ def add_jumpstart_model_id_version_tags(
384384

385385
def add_hub_content_arn_tags(
386386
tags: Optional[List[TagsDict]],
387-
hub_arn: str,
387+
hub_content_arn: str,
388388
) -> Optional[List[TagsDict]]:
389389
"""Adds custom Hub arn tag to JumpStart related resources."""
390390

391391
tags = add_single_jumpstart_tag(
392-
hub_arn,
392+
hub_content_arn,
393393
enums.JumpStartTag.HUB_CONTENT_ARN,
394394
tags,
395395
is_uri=False,

0 commit comments

Comments
 (0)