Skip to content

Commit 09c05f4

Browse files
committed
chore: use hosting_model_package_arns
1 parent 2d5f5ec commit 09c05f4

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _retrieve_model_package_arn(
5050
an exception if the version of the model is deprecated. (Default: False).
5151
5252
Returns:
53-
list: the model package arn to use for the model or None.
53+
str: the model package arn to use for the model or None.
5454
"""
5555

5656
if region is None:
@@ -66,6 +66,12 @@ def _retrieve_model_package_arn(
6666
)
6767

6868
if scope == JumpStartScriptScope.INFERENCE:
69-
return model_specs.hosting_model_package_arn
69+
70+
if model_specs.hosting_model_package_arns is None:
71+
return None
72+
73+
regional_arn = model_specs.hosting_model_package_arns.get(region)
74+
75+
return regional_arn
7076

7177
raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'")

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
352352
"training_enable_network_isolation",
353353
"resource_name_base",
354354
"hosting_eula_key",
355-
"hosting_model_package_arn",
355+
"hosting_model_package_arns",
356356
]
357357

358358
def __init__(self, spec: Dict[str, Any]):
@@ -423,7 +423,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
423423

424424
self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key")
425425

426-
self.hosting_model_package_arn: Optional[str] = json_obj.get("hosting_model_package_arn")
426+
self.hosting_model_package_arns: Optional[Dict] = json_obj.get("hosting_model_package_arns")
427427

428428
if self.training_supported:
429429
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2299,7 +2299,7 @@
22992299
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
23002300
"training_prepacked_script_key": None,
23012301
"hosting_prepacked_artifact_key": None,
2302-
"hosting_model_package_arn": None,
2302+
"hosting_model_package_arns": None,
23032303
"hosting_eula_key": False,
23042304
"hyperparameters": [
23052305
{

0 commit comments

Comments
 (0)