Skip to content

Commit be5c379

Browse files
committed
fix: format to pass black-check
1 parent 3b1df31 commit be5c379

File tree

1 file changed

+7
-21
lines changed

1 file changed

+7
-21
lines changed

src/sagemaker/predictor.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,7 @@ def __init__(
8686
"""
8787
removed_kwargs("content_type", kwargs)
8888
removed_kwargs("accept", kwargs)
89-
endpoint_name = renamed_kwargs(
90-
"endpoint", "endpoint_name", endpoint_name, kwargs
91-
)
89+
endpoint_name = renamed_kwargs("endpoint", "endpoint_name", endpoint_name, kwargs)
9290
self.endpoint_name = endpoint_name
9391
self.sagemaker_session = sagemaker_session or Session()
9492
self.serializer = serializer
@@ -135,9 +133,7 @@ def predict(
135133
request_args = self._create_request_args(
136134
data, initial_args, target_model, target_variant, inference_id
137135
)
138-
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
139-
**request_args
140-
)
136+
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
141137
return self._handle_response(response)
142138

143139
def _handle_response(self, response):
@@ -397,11 +393,7 @@ def list_monitors(self):
397393
endpoint_name=self.endpoint_name
398394
)
399395
if len(monitoring_schedules_dict["MonitoringScheduleSummaries"]) == 0:
400-
print(
401-
"No monitors found for endpoint. endpoint: {}".format(
402-
self.endpoint_name
403-
)
404-
)
396+
print("No monitors found for endpoint. endpoint: {}".format(self.endpoint_name))
405397
return []
406398

407399
monitors = []
@@ -443,9 +435,7 @@ def _get_model_monitor_class(self, schedule_name, monitoring_type):
443435
"MonitoringJobDefinition"
444436
)
445437
if embedded_job_definition is not None: # legacy v1 schedule
446-
image_uri = embedded_job_definition["MonitoringAppSpecification"][
447-
"ImageUri"
448-
]
438+
image_uri = embedded_job_definition["MonitoringAppSpecification"]["ImageUri"]
449439
if image_uri.endswith(DEFAULT_REPOSITORY_NAME):
450440
clazz = DefaultModelMonitor
451441
else:
@@ -483,9 +473,7 @@ def endpoint_context(self):
483473

484474
# list context by source uri using arn
485475
contexts = list(
486-
EndpointContext.list(
487-
sagemaker_session=self.sagemaker_session, source_uri=endpoint_arn
488-
)
476+
EndpointContext.list(sagemaker_session=self.sagemaker_session, source_uri=endpoint_arn)
489477
)
490478

491479
if len(contexts) != 0:
@@ -512,10 +500,8 @@ def _get_model_names(self):
512500
if self._model_names is not None:
513501
return self._model_names
514502
current_endpoint_config_name = self._get_endpoint_config_name()
515-
endpoint_config = (
516-
self.sagemaker_session.sagemaker_client.describe_endpoint_config(
517-
EndpointConfigName=current_endpoint_config_name
518-
)
503+
endpoint_config = self.sagemaker_session.sagemaker_client.describe_endpoint_config(
504+
EndpointConfigName=current_endpoint_config_name
519505
)
520506
production_variants = endpoint_config["ProductionVariants"]
521507
self._model_names = [d["ModelName"] for d in production_variants]

0 commit comments

Comments
 (0)