Skip to content

Commit f089b5a

Browse files
malav-shastriMalav Shastri
and
Malav Shastri
authored
fix: cross account private hub model fine-tuning (#4843)
Co-authored-by: Malav Shastri <[email protected]>
1 parent 647acba commit f089b5a

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -462,19 +462,24 @@ def _retrieval_function(
462462
HubContentType.MODEL_REFERENCE,
463463
}:
464464

465-
hub_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(id_info)
465+
hub_resource_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(id_info)
466+
hub_arn = hub_utils.construct_hub_arn_from_name(
467+
hub_name=hub_resource_arn_extracted_info.hub_name,
468+
region=hub_resource_arn_extracted_info.region,
469+
account_id=hub_resource_arn_extracted_info.account_id,
470+
)
466471

467472
model_version: str = hub_utils.get_hub_model_version(
468-
hub_model_name=hub_arn_extracted_info.hub_content_name,
473+
hub_model_name=hub_resource_arn_extracted_info.hub_content_name,
469474
hub_model_type=data_type.value,
470-
hub_name=hub_arn_extracted_info.hub_name,
475+
hub_name=hub_arn,
471476
sagemaker_session=self._sagemaker_session,
472-
hub_model_version=hub_arn_extracted_info.hub_content_version,
477+
hub_model_version=hub_resource_arn_extracted_info.hub_content_version,
473478
)
474479

475480
hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
476-
hub_name=hub_arn_extracted_info.hub_name,
477-
hub_content_name=hub_arn_extracted_info.hub_content_name,
481+
hub_name=hub_arn,
482+
hub_content_name=hub_resource_arn_extracted_info.hub_content_name,
478483
hub_content_version=model_version,
479484
hub_content_type=data_type.value,
480485
)

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,11 @@ def construct_hub_arn_from_name(
6767
hub_name: str,
6868
region: Optional[str] = None,
6969
session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
70+
account_id: Optional[str] = None,
7071
) -> str:
7172
"""Constructs a Hub arn from the Hub name using default Session values."""
7273

73-
account_id = session.account_id()
74+
account_id = account_id or session.account_id()
7475
region = region or session.boto_region_name
7576
partition = aws_partition(region)
7677

0 commit comments

Comments
 (0)