Skip to content

Commit 01ff3e5

Browse files
committed
chore: address PR comments, fix failing tests
1 parent 5af2b17 commit 01ff3e5

File tree

6 files changed

+21
-18
lines changed

6 files changed

+21
-18
lines changed

src/sagemaker/base_predictor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,15 @@ def predict(
157157
custom_attributes (str): Provides additional information about a request for an
158158
inference submitted to a model hosted at an Amazon SageMaker endpoint.
159159
The information is an opaque value that is forwarded verbatim. You could use this
160-
value, for example, to provide an ID that you can use to track a request or to provide
161-
other metadata that a service endpoint was programmed to process. The value must
162-
consist of no more than 1024 visible US-ASCII characters.
163-
164-
The code in your model is responsible for setting or updating any custom attributes in
165-
the response. If your code does not set this value in the response, an empty value is
166-
returned. For example, if a custom attribute represents the trace ID, your model can
167-
prepend the custom attribute with Trace ID: in your post-processing function
168-
(Default: None).
160+
value, for example, to provide an ID that you can use to track a request or to
161+
provide other metadata that a service endpoint was programmed to process. The value
162+
must consist of no more than 1024 visible US-ASCII characters.
163+
164+
The code in your model is responsible for setting or updating any custom attributes
165+
in the response. If your code does not set this value in the response, an empty
166+
value is returned. For example, if a custom attribute represents the trace ID, your
167+
model can prepend the custom attribute with Trace ID: in your post-processing
168+
function (Default: None).
169169
170170
Returns:
171171
object: Inference for the given input. If a deserializer was specified when creating

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616
from sagemaker.jumpstart.constants import (
1717
JUMPSTART_DEFAULT_REGION_NAME,
1818
)
19-
from sagemaker.jumpstart.enums import (
20-
JumpStartScriptScope,
21-
)
2219
from sagemaker.jumpstart.utils import (
2320
verify_model_region_and_return_specs,
2421
)
@@ -28,6 +25,7 @@ def _retrieve_model_package_arn(
2825
model_id: str,
2926
model_version: str,
3027
region: Optional[str],
28+
scope: Optional[str] = None,
3129
tolerate_vulnerable_model: bool = False,
3230
tolerate_deprecated_model: bool = False,
3331
) -> Optional[str]:
@@ -39,6 +37,7 @@ def _retrieve_model_package_arn(
3937
model_version (str): Version of the JumpStart model for which to retrieve the
4038
model package arn.
4139
region (Optional[str]): Region for which to retrieve the model package arn.
40+
scope (Optional[str]): Scope for which to retrieve the model package arn.
4241
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4342
specifications should be tolerated (exception not raised). If False, raises an
4443
exception if the script used by this version of the model has dependencies with known
@@ -57,7 +56,7 @@ def _retrieve_model_package_arn(
5756
model_specs = verify_model_region_and_return_specs(
5857
model_id=model_id,
5958
version=model_version,
60-
scope=JumpStartScriptScope.TRAINING,
59+
scope=scope,
6160
region=region,
6261
tolerate_vulnerable_model=tolerate_vulnerable_model,
6362
tolerate_deprecated_model=tolerate_deprecated_model,

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt
296296
model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn(
297297
model_id=kwargs.model_id,
298298
model_version=kwargs.model_version,
299+
scope=JumpStartScriptScope.INFERENCE,
299300
region=kwargs.region,
300301
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
301302
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
351351
"inference_enable_network_isolation",
352352
"training_enable_network_isolation",
353353
"resource_name_base",
354-
"eula_model",
354+
"eula_key",
355355
"model_package_arn",
356356
]
357357

@@ -421,7 +421,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
421421
)
422422
self.resource_name_base: bool = json_obj.get("resource_name_base")
423423

424-
self.eula_model: bool = json_obj.get("eula_model", False)
424+
self.eula_key: Optional[str] = json_obj.get("eula_key")
425425

426426
self.model_package_arn: Optional[str] = json_obj.get("model_package_arn")
427427

src/sagemaker/jumpstart/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,10 +402,13 @@ def verify_model_region_and_return_specs(
402402
f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
403403
)
404404

405-
if model_specs.eula_model:
405+
if model_specs.eula_key:
406406
LOGGER.info(
407407
"Using model with end-user license agreement (EULA). "
408-
"Deploying this model requires accepting EULA terms."
408+
"See https://%s.s3.%s.amazonaws.com/%s for terms of use.",
409+
get_jumpstart_content_bucket(region=region),
410+
region,
411+
model_specs.eula_key,
409412
)
410413

411414
if model_specs.deprecated:

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2300,7 +2300,7 @@
23002300
"training_prepacked_script_key": None,
23012301
"hosting_prepacked_artifact_key": None,
23022302
"model_package_arn": None,
2303-
"eula_model": False,
2303+
"eula_key": False,
23042304
"hyperparameters": [
23052305
{
23062306
"name": "epochs",

0 commit comments

Comments
 (0)