@@ -133,6 +133,7 @@ def predict(
133
133
target_model = None ,
134
134
target_variant = None ,
135
135
inference_id = None ,
136
+ custom_attributes = None ,
136
137
):
137
138
"""Return the inference from the specified endpoint.
138
139
@@ -153,6 +154,18 @@ def predict(
153
154
model you want to host and the resources you want to deploy for hosting it.
154
155
inference_id (str): If you provide a value, it is added to the captured data
155
156
when you enable data capture on the endpoint (Default: None).
157
+ custom_attributes (str): Provides additional information about a request for an
158
+ inference submitted to a model hosted at an Amazon SageMaker endpoint.
159
+ The information is an opaque value that is forwarded verbatim. You could use this
160
+ value, for example, to provide an ID that you can use to track a request or to provide
161
+ other metadata that a service endpoint was programmed to process. The value must
162
+ consist of no more than 1024 visible US-ASCII characters.
163
+
164
+ The code in your model is responsible for setting or updating any custom attributes in
165
+ the response. If your code does not set this value in the response, an empty value is
166
+ returned. For example, if a custom attribute represents the trace ID, your model can
167
+ prepend the custom attribute with 'Trace ID:' in your post-processing function
168
+ (Default: None).
156
169
157
170
Returns:
158
171
object: Inference for the given input. If a deserializer was specified when creating
@@ -162,7 +175,12 @@ def predict(
162
175
"""
163
176
164
177
request_args = self ._create_request_args (
165
- data , initial_args , target_model , target_variant , inference_id
178
+ data ,
179
+ initial_args ,
180
+ target_model ,
181
+ target_variant ,
182
+ inference_id ,
183
+ custom_attributes ,
166
184
)
167
185
response = self .sagemaker_session .sagemaker_runtime_client .invoke_endpoint (** request_args )
168
186
return self ._handle_response (response )
@@ -180,6 +198,7 @@ def _create_request_args(
180
198
target_model = None ,
181
199
target_variant = None ,
182
200
inference_id = None ,
201
+ custom_attributes = None ,
183
202
):
184
203
"""Placeholder docstring"""
185
204
args = dict (initial_args ) if initial_args else {}
@@ -206,6 +225,9 @@ def _create_request_args(
206
225
if inference_id :
207
226
args ["InferenceId" ] = inference_id
208
227
228
+ if custom_attributes :
229
+ args ["CustomAttributes" ] = custom_attributes
230
+
209
231
data = self .serializer .serialize (data )
210
232
211
233
args ["Body" ] = data
0 commit comments