@@ -201,6 +201,7 @@ def deploy(
201
201
vpc_config = None ,
202
202
enable_network_isolation = False ,
203
203
model_kms_key = None ,
204
+ predictor_cls = None ,
204
205
):
205
206
"""Deploy a candidate to a SageMaker Inference Pipeline and return a Predictor
206
207
@@ -237,10 +238,15 @@ def deploy(
237
238
training cluster for distributed training. Default: False
238
239
model_kms_key (str): KMS key ARN used to encrypt the repacked
239
240
model archive file if the model is repacked
241
+ predictor_cls (callable[string, sagemaker.session.Session]): A
242
+ function to call to create a predictor (default: None). If
243
+ specified, ``deploy()`` returns the result of invoking this
244
+ function on the created endpoint name.
240
245
241
246
Returns:
242
- callable[string, sagemaker.session.Session]: Invocation of
243
- ``self.predictor_cls`` on the created endpoint name.
247
+ callable[string, sagemaker.session.Session] or ``None``:
248
+ If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on
249
+ the created endpoint name. Otherwise, ``None``.
244
250
"""
245
251
if candidate is None :
246
252
candidate_dict = self .best_candidate ()
@@ -264,6 +270,7 @@ def deploy(
264
270
vpc_config = vpc_config ,
265
271
enable_network_isolation = enable_network_isolation ,
266
272
model_kms_key = model_kms_key ,
273
+ predictor_cls = predictor_cls ,
267
274
)
268
275
269
276
def _check_problem_type_and_job_objective (self , problem_type , job_objective ):
@@ -299,6 +306,7 @@ def _deploy_inference_pipeline(
299
306
vpc_config = None ,
300
307
enable_network_isolation = False ,
301
308
model_kms_key = None ,
309
+ predictor_cls = None ,
302
310
):
303
311
"""Deploy a SageMaker Inference Pipeline.
304
312
@@ -329,6 +337,10 @@ def _deploy_inference_pipeline(
329
337
contains "SecurityGroupIds", "Subnets"
330
338
model_kms_key (str): KMS key ARN used to encrypt the repacked
331
339
model archive file if the model is repacked
340
+ predictor_cls (callable[string, sagemaker.session.Session]): A
341
+ function to call to create a predictor (default: None). If
342
+ specified, ``deploy()`` returns the result of invoking this
343
+ function on the created endpoint name.
332
344
"""
333
345
# construct Model objects
334
346
models = []
@@ -352,6 +364,7 @@ def _deploy_inference_pipeline(
352
364
pipeline = PipelineModel (
353
365
models = models ,
354
366
role = self .role ,
367
+ predictor_cls = predictor_cls ,
355
368
name = name ,
356
369
vpc_config = vpc_config ,
357
370
sagemaker_session = sagemaker_session or self .sagemaker_session ,
0 commit comments