18
18
from enum import Enum
19
19
from typing import Optional , Union , Dict
20
20
21
- import defaults
21
+ from sagemaker . djl_inference import defaults
22
22
import sagemaker
23
23
from sagemaker import s3 , Predictor , image_uris , fw_utils
24
24
from sagemaker .deserializers import JSONDeserializer
@@ -94,9 +94,11 @@ def __init__(
94
94
** kwargs ,
95
95
):
96
96
if kwargs .get ("model_data" ) is not None :
97
- raise ValueError ("DJLLargeModels do not support the model_data parameter. Please use"
98
- "uncompressed_model_data and ensure the s3 uri points to a folder containing"
99
- "all model artifacts, not a tar.gz file" )
97
+ raise ValueError (
98
+ "DJLLargeModels do not support the model_data parameter. Please use"
99
+ "uncompressed_model_data and ensure the s3 uri points to a folder containing"
100
+ "all model artifacts, not a tar.gz file"
101
+ )
100
102
super (DJLLargeModel , self ).__init__ (
101
103
None , image_uri , role , entry_point , predictor_cls = predictor_cls , ** kwargs
102
104
)
@@ -126,6 +128,7 @@ def package_for_edge(
126
128
tags = None ,
127
129
):
128
130
raise NotImplementedError ("DJLLargeModels do not support Sagemaker Edge" )
131
+
129
132
def compile (
130
133
self ,
131
134
target_instance_family ,
@@ -142,7 +145,9 @@ def compile(
142
145
target_platform_accelerator = None ,
143
146
compiler_options = None ,
144
147
):
145
- raise NotImplementedError ("DJLLargeModels do not currently support compilation with SageMaker Neo" )
148
+ raise NotImplementedError (
149
+ "DJLLargeModels do not currently support compilation with SageMaker Neo"
150
+ )
146
151
147
152
def deploy (
148
153
self ,
@@ -168,14 +173,18 @@ def deploy(
168
173
if serverless_inference_config :
169
174
raise ValueError ("DJLLargeModels do not support Serverless Deployment" )
170
175
if instance_type is None and not self .inference_recommender_job_results :
171
- raise ValueError (f"instance_type must be specified, or inference recommendation from right_size()"
172
- "must be run to deploy the model. Supported instance type families are :"
173
- f"{ defaults .ALLOWED_INSTANCE_FAMILIES } " )
176
+ raise ValueError (
177
+ f"instance_type must be specified, or inference recommendation from right_size()"
178
+ "must be run to deploy the model. Supported instance type families are :"
179
+ f"{ defaults .ALLOWED_INSTANCE_FAMILIES } "
180
+ )
174
181
if instance_type :
175
- instance_family = instance_type .rsplit ('.' , 1 )[0 ]
182
+ instance_family = instance_type .rsplit ("." , 1 )[0 ]
176
183
if not instance_family in defaults .ALLOWED_INSTANCE_FAMILIES :
177
- raise ValueError (f"Invalid instance type. DJLLargeModels only support deployment to instances"
178
- f"with GPUs. Supported instance families are { defaults .ALLOWED_INSTANCE_FAMILIES } " )
184
+ raise ValueError (
185
+ f"Invalid instance type. DJLLargeModels only support deployment to instances"
186
+ f"with GPUs. Supported instance families are { defaults .ALLOWED_INSTANCE_FAMILIES } "
187
+ )
179
188
180
189
super (DJLLargeModel , self ).deploy (
181
190
initial_instance_count = initial_instance_count ,
@@ -253,8 +262,8 @@ def prepare_container_def(
253
262
)
254
263
255
264
def generate_serving_properties (self , serving_properties = {}) -> Dict [str , str ]:
256
- serving_properties ["engine" ] = self .engine [0 ]
257
- serving_properties ["option.entryPoint" ] = self .engine [1 ]
265
+ serving_properties ["engine" ] = self .engine . value [0 ]
266
+ serving_properties ["option.entryPoint" ] = self .engine . value [1 ]
258
267
serving_properties ["option.s3url" ] = self .uncompressed_model_data
259
268
if self .tensor_parallel_degree :
260
269
serving_properties ["option.tensor_parallel_degree" ] = self .tensor_parallel_degree
0 commit comments