Skip to content

Commit 882773e

Browse files
committed
chore: modify metadata keys, add unit test for custom attributes
1 parent 01ff3e5 commit 882773e

File tree

5 files changed

+34
-9
lines changed

5 files changed

+34
-9
lines changed

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from sagemaker.jumpstart.utils import (
2020
verify_model_region_and_return_specs,
2121
)
22+
from sagemaker.jumpstart.enums import (
23+
JumpStartScriptScope,
24+
)
2225

2326

2427
def _retrieve_model_package_arn(
@@ -62,4 +65,7 @@ def _retrieve_model_package_arn(
6265
tolerate_deprecated_model=tolerate_deprecated_model,
6366
)
6467

65-
return model_specs.model_package_arn
68+
if scope == JumpStartScriptScope.INFERENCE:
69+
return model_specs.hosting_model_package_arn
70+
71+
raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'")

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,8 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
351351
"inference_enable_network_isolation",
352352
"training_enable_network_isolation",
353353
"resource_name_base",
354-
"eula_key",
355-
"model_package_arn",
354+
"hosting_eula_key",
355+
"hosting_model_package_arn",
356356
]
357357

358358
def __init__(self, spec: Dict[str, Any]):
@@ -421,9 +421,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
421421
)
422422
self.resource_name_base: bool = json_obj.get("resource_name_base")
423423

424-
self.eula_key: Optional[str] = json_obj.get("eula_key")
424+
self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key")
425425

426-
self.model_package_arn: Optional[str] = json_obj.get("model_package_arn")
426+
self.hosting_model_package_arn: Optional[str] = json_obj.get("hosting_model_package_arn")
427427

428428
if self.training_supported:
429429
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(

src/sagemaker/jumpstart/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,13 +402,13 @@ def verify_model_region_and_return_specs(
402402
f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
403403
)
404404

405-
if model_specs.eula_key:
405+
if model_specs.hosting_eula_key and scope == constants.JumpStartScriptScope.INFERENCE.value:
406406
LOGGER.info(
407407
"Using model with end-user license agreement (EULA). "
408408
"See https://%s.s3.%s.amazonaws.com/%s for terms of use.",
409409
get_jumpstart_content_bucket(region=region),
410410
region,
411-
model_specs.eula_key,
411+
model_specs.hosting_eula_key,
412412
)
413413

414414
if model_specs.deprecated:

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2299,8 +2299,8 @@
22992299
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
23002300
"training_prepacked_script_key": None,
23012301
"hosting_prepacked_artifact_key": None,
2302-
"model_package_arn": None,
2303-
"eula_key": False,
2302+
"hosting_model_package_arn": None,
2303+
"hosting_eula_key": False,
23042304
"hyperparameters": [
23052305
{
23062306
"name": "epochs",

tests/unit/test_predictor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,3 +614,22 @@ def test_setting_serializer_deserializer_atts_changes_content_accept_types():
614614
predictor.deserializer = PandasDeserializer()
615615
assert predictor.accept == ("text/csv", "application/json")
616616
assert predictor.content_type == "text/csv"
617+
618+
619+
def test_custom_attributes():
620+
sagemaker_session = empty_sagemaker_session()
621+
predictor = Predictor(ENDPOINT, sagemaker_session=sagemaker_session)
622+
623+
sagemaker_session.sagemaker_runtime_client.invoke_endpoint = Mock(
624+
return_value={"Body": io.StringIO("response")}
625+
)
626+
627+
predictor.predict("payload", custom_attributes="custom-attribute")
628+
629+
sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(
630+
EndpointName=ENDPOINT,
631+
ContentType="application/octet-stream",
632+
Accept="*/*",
633+
CustomAttributes="custom-attribute",
634+
Body="payload",
635+
)

0 commit comments

Comments
 (0)