@@ -86,9 +86,7 @@ def __init__(
86
86
"""
87
87
removed_kwargs ("content_type" , kwargs )
88
88
removed_kwargs ("accept" , kwargs )
89
- endpoint_name = renamed_kwargs (
90
- "endpoint" , "endpoint_name" , endpoint_name , kwargs
91
- )
89
+ endpoint_name = renamed_kwargs ("endpoint" , "endpoint_name" , endpoint_name , kwargs )
92
90
self .endpoint_name = endpoint_name
93
91
self .sagemaker_session = sagemaker_session or Session ()
94
92
self .serializer = serializer
@@ -135,9 +133,7 @@ def predict(
135
133
request_args = self ._create_request_args (
136
134
data , initial_args , target_model , target_variant , inference_id
137
135
)
138
- response = self .sagemaker_session .sagemaker_runtime_client .invoke_endpoint (
139
- ** request_args
140
- )
136
+ response = self .sagemaker_session .sagemaker_runtime_client .invoke_endpoint (** request_args )
141
137
return self ._handle_response (response )
142
138
143
139
def _handle_response (self , response ):
@@ -397,11 +393,7 @@ def list_monitors(self):
397
393
endpoint_name = self .endpoint_name
398
394
)
399
395
if len (monitoring_schedules_dict ["MonitoringScheduleSummaries" ]) == 0 :
400
- print (
401
- "No monitors found for endpoint. endpoint: {}" .format (
402
- self .endpoint_name
403
- )
404
- )
396
+ print ("No monitors found for endpoint. endpoint: {}" .format (self .endpoint_name ))
405
397
return []
406
398
407
399
monitors = []
@@ -443,9 +435,7 @@ def _get_model_monitor_class(self, schedule_name, monitoring_type):
443
435
"MonitoringJobDefinition"
444
436
)
445
437
if embedded_job_definition is not None : # legacy v1 schedule
446
- image_uri = embedded_job_definition ["MonitoringAppSpecification" ][
447
- "ImageUri"
448
- ]
438
+ image_uri = embedded_job_definition ["MonitoringAppSpecification" ]["ImageUri" ]
449
439
if image_uri .endswith (DEFAULT_REPOSITORY_NAME ):
450
440
clazz = DefaultModelMonitor
451
441
else :
@@ -483,9 +473,7 @@ def endpoint_context(self):
483
473
484
474
# list context by source uri using arn
485
475
contexts = list (
486
- EndpointContext .list (
487
- sagemaker_session = self .sagemaker_session , source_uri = endpoint_arn
488
- )
476
+ EndpointContext .list (sagemaker_session = self .sagemaker_session , source_uri = endpoint_arn )
489
477
)
490
478
491
479
if len (contexts ) != 0 :
@@ -512,10 +500,8 @@ def _get_model_names(self):
512
500
if self ._model_names is not None :
513
501
return self ._model_names
514
502
current_endpoint_config_name = self ._get_endpoint_config_name ()
515
- endpoint_config = (
516
- self .sagemaker_session .sagemaker_client .describe_endpoint_config (
517
- EndpointConfigName = current_endpoint_config_name
518
- )
503
+ endpoint_config = self .sagemaker_session .sagemaker_client .describe_endpoint_config (
504
+ EndpointConfigName = current_endpoint_config_name
519
505
)
520
506
production_variants = endpoint_config ["ProductionVariants" ]
521
507
self ._model_names = [d ["ModelName" ] for d in production_variants ]
0 commit comments