Skip to content

Commit a787ecb

Browse files
committed
Reformat python code
1 parent d8fa214 commit a787ecb

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

src/sagemaker/djl_inference/defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@
4242
"ml.g5",
4343
"ml.p3",
4444
"ml.p4",
45-
}
45+
}

src/sagemaker/djl_inference/model.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from enum import Enum
1919
from typing import Optional, Union, Dict
2020

21-
import defaults
21+
from sagemaker.djl_inference import defaults
2222
import sagemaker
2323
from sagemaker import s3, Predictor, image_uris, fw_utils
2424
from sagemaker.deserializers import JSONDeserializer
@@ -94,9 +94,11 @@ def __init__(
9494
**kwargs,
9595
):
9696
if kwargs.get("model_data") is not None:
97-
raise ValueError("DJLLargeModels do not support the model_data parameter. Please use"
98-
"uncompressed_model_data and ensure the s3 uri points to a folder containing"
99-
"all model artifacts, not a tar.gz file")
97+
raise ValueError(
98+
"DJLLargeModels do not support the model_data parameter. Please use"
99+
"uncompressed_model_data and ensure the s3 uri points to a folder containing"
100+
"all model artifacts, not a tar.gz file"
101+
)
100102
super(DJLLargeModel, self).__init__(
101103
None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
102104
)
@@ -126,6 +128,7 @@ def package_for_edge(
126128
tags=None,
127129
):
128130
raise NotImplementedError("DJLLargeModels do not support Sagemaker Edge")
131+
129132
def compile(
130133
self,
131134
target_instance_family,
@@ -142,7 +145,9 @@ def compile(
142145
target_platform_accelerator=None,
143146
compiler_options=None,
144147
):
145-
raise NotImplementedError("DJLLargeModels do not currently support compilation with SageMaker Neo")
148+
raise NotImplementedError(
149+
"DJLLargeModels do not currently support compilation with SageMaker Neo"
150+
)
146151

147152
def deploy(
148153
self,
@@ -168,14 +173,18 @@ def deploy(
168173
if serverless_inference_config:
169174
raise ValueError("DJLLargeModels do not support Serverless Deployment")
170175
if instance_type is None and not self.inference_recommender_job_results:
171-
raise ValueError(f"instance_type must be specified, or inference recommendation from right_size()"
172-
"must be run to deploy the model. Supported instance type families are :"
173-
f"{defaults.ALLOWED_INSTANCE_FAMILIES}")
176+
raise ValueError(
177+
f"instance_type must be specified, or inference recommendation from right_size()"
178+
"must be run to deploy the model. Supported instance type families are :"
179+
f"{defaults.ALLOWED_INSTANCE_FAMILIES}"
180+
)
174181
if instance_type:
175-
instance_family = instance_type.rsplit('.', 1)[0]
182+
instance_family = instance_type.rsplit(".", 1)[0]
176183
if not instance_family in defaults.ALLOWED_INSTANCE_FAMILIES:
177-
raise ValueError(f"Invalid instance type. DJLLargeModels only support deployment to instances"
178-
f"with GPUs. Supported instance families are {defaults.ALLOWED_INSTANCE_FAMILIES}")
184+
raise ValueError(
185+
f"Invalid instance type. DJLLargeModels only support deployment to instances"
186+
f"with GPUs. Supported instance families are {defaults.ALLOWED_INSTANCE_FAMILIES}"
187+
)
179188

180189
super(DJLLargeModel, self).deploy(
181190
initial_instance_count=initial_instance_count,
@@ -253,8 +262,8 @@ def prepare_container_def(
253262
)
254263

255264
def generate_serving_properties(self, serving_properties={}) -> Dict[str, str]:
256-
serving_properties["engine"] = self.engine[0]
257-
serving_properties["option.entryPoint"] = self.engine[1]
265+
serving_properties["engine"] = self.engine.value[0]
266+
serving_properties["option.entryPoint"] = self.engine.value[1]
258267
serving_properties["option.s3url"] = self.uncompressed_model_data
259268
if self.tensor_parallel_degree:
260269
serving_properties["option.tensor_parallel_degree"] = self.tensor_parallel_degree

0 commit comments

Comments
 (0)