Skip to content

Commit 5af2b17

Browse files
committed
chore: add support for custom attributes to predictor class
1 parent 2367d16 commit 5af2b17

File tree

4 files changed

+26
-5
lines changed

4 files changed

+26
-5
lines changed

src/sagemaker/base_predictor.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def predict(
133133
target_model=None,
134134
target_variant=None,
135135
inference_id=None,
136+
custom_attributes=None,
136137
):
137138
"""Return the inference from the specified endpoint.
138139
@@ -153,6 +154,18 @@ def predict(
153154
model you want to host and the resources you want to deploy for hosting it.
154155
inference_id (str): If you provide a value, it is added to the captured data
155156
when you enable data capture on the endpoint (Default: None).
157+
custom_attributes (str): Provides additional information about a request for an
158+
inference submitted to a model hosted at an Amazon SageMaker endpoint.
159+
The information is an opaque value that is forwarded verbatim. You could use this
160+
value, for example, to provide an ID that you can use to track a request or to provide
161+
other metadata that a service endpoint was programmed to process. The value must
162+
consist of no more than 1024 visible US-ASCII characters.
163+
164+
The code in your model is responsible for setting or updating any custom attributes in
165+
the response. If your code does not set this value in the response, an empty value is
166+
returned. For example, if a custom attribute represents the trace ID, your model can
167+
prepend the custom attribute with Trace ID: in your post-processing function
168+
(Default: None).
156169
157170
Returns:
158171
object: Inference for the given input. If a deserializer was specified when creating
@@ -162,7 +175,12 @@ def predict(
162175
"""
163176

164177
request_args = self._create_request_args(
165-
data, initial_args, target_model, target_variant, inference_id
178+
data,
179+
initial_args,
180+
target_model,
181+
target_variant,
182+
inference_id,
183+
custom_attributes,
166184
)
167185
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
168186
return self._handle_response(response)
@@ -180,6 +198,7 @@ def _create_request_args(
180198
target_model=None,
181199
target_variant=None,
182200
inference_id=None,
201+
custom_attributes=None,
183202
):
184203
"""Placeholder docstring"""
185204
args = dict(initial_args) if initial_args else {}
@@ -206,6 +225,9 @@ def _create_request_args(
206225
if inference_id:
207226
args["InferenceId"] = inference_id
208227

228+
if custom_attributes:
229+
args["CustomAttributes"] = custom_attributes
230+
209231
data = self.serializer.serialize(data)
210232

211233
args["Body"] = data

src/sagemaker/jumpstart/artifacts/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@
5252
_retrieve_supported_accept_types,
5353
_retrieve_supported_content_types,
5454
)
55-
from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn
55+
from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn # noqa: F401

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart model packages."""
1414
from __future__ import absolute_import
15-
from copy import deepcopy
16-
from typing import Dict, List, Optional
15+
from typing import Optional
1716
from sagemaker.jumpstart.constants import (
1817
JUMPSTART_DEFAULT_REGION_NAME,
1918
)

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
)
3232
from sagemaker.jumpstart.utils import is_valid_model_id
3333
from sagemaker.utils import stringify_object
34-
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model, ModelPackage
34+
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
3535
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
3636
from sagemaker.predictor import PredictorBase
3737
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig

0 commit comments

Comments
 (0)