@@ -93,6 +93,10 @@ def __init__(
93
93
predictor_cls : callable = DJLLargeModelPredictor ,
94
94
** kwargs ,
95
95
):
96
+ if kwargs .get ("model_data" ) is not None :
97
+ raise ValueError ("DJLLargeModels do not support the model_data parameter. Please use"
98
+ "uncompressed_model_data and ensure the s3 uri points to a folder containing"
99
+ "all model artifacts, not a tar.gz file" )
96
100
super (DJLLargeModel , self ).__init__ (
97
101
None , image_uri , role , entry_point , predictor_cls = predictor_cls , ** kwargs
98
102
)
@@ -140,6 +144,58 @@ def compile(
140
144
):
141
145
raise NotImplementedError ("DJLLargeModels do not currently support compilation with SageMaker Neo" )
142
146
147
def deploy(
    self,
    initial_instance_count=None,
    instance_type=None,
    serializer=None,
    deserializer=None,
    accelerator_type=None,
    endpoint_name=None,
    tags=None,
    kms_key=None,
    wait=True,
    data_capture_config=None,
    async_inference_config=None,
    serverless_inference_config=None,
    volume_size=None,
    model_data_download_timeout=None,
    container_startup_health_check_timeout=None,
    **kwargs,
):
    """Validate DJL-specific deployment constraints, then delegate to the parent ``deploy``.

    Every argument is forwarded unchanged to the parent class's ``deploy``.
    This override only adds up-front checks so unsupported configurations
    fail before the parent is called.

    Args:
        initial_instance_count: Forwarded to the parent ``deploy``.
        instance_type: Instance type to deploy to. When given, everything
            before its last ``.`` must be a member of
            ``defaults.ALLOWED_INSTANCE_FAMILIES``. May be omitted only if
            ``self.inference_recommender_job_results`` is truthy.
        serializer: Forwarded to the parent ``deploy``.
        deserializer: Forwarded to the parent ``deploy``.
        accelerator_type: Must be falsy; any truthy value is rejected.
        endpoint_name: Forwarded to the parent ``deploy``.
        tags: Forwarded to the parent ``deploy``.
        kms_key: Forwarded to the parent ``deploy``.
        wait: Forwarded to the parent ``deploy``.
        data_capture_config: Forwarded to the parent ``deploy``.
        async_inference_config: Forwarded to the parent ``deploy``.
        serverless_inference_config: Must be falsy; any truthy value is
            rejected.
        volume_size: Forwarded to the parent ``deploy``.
        model_data_download_timeout: Forwarded to the parent ``deploy``.
        container_startup_health_check_timeout: Forwarded to the parent
            ``deploy``.
        **kwargs: Extra keyword arguments forwarded to the parent ``deploy``.

    Returns:
        Whatever the parent class's ``deploy`` returns (presumably the
        predictor for the new endpoint -- confirm against the parent class).

    Raises:
        ValueError: If ``accelerator_type`` or ``serverless_inference_config``
            is set, if neither ``instance_type`` nor inference recommender
            results are available, or if the family of ``instance_type`` is
            not in ``defaults.ALLOWED_INSTANCE_FAMILIES``.
    """
    if accelerator_type:
        raise ValueError("DJLLargeModels do not support Elastic Inference Accelerators")
    if serverless_inference_config:
        raise ValueError("DJLLargeModels do not support Serverless Deployment")
    if instance_type is None and not self.inference_recommender_job_results:
        # Adjacent literals are concatenated with no separator, so each piece
        # carries its own trailing space.
        raise ValueError(
            "instance_type must be specified, or inference recommendation from right_size() "
            "must be run to deploy the model. Supported instance type families are: "
            f"{defaults.ALLOWED_INSTANCE_FAMILIES}"
        )
    if instance_type:
        # Drop the size suffix after the last dot, e.g. "ml.g5.12xlarge" -> "ml.g5".
        instance_family = instance_type.rsplit(".", 1)[0]
        if instance_family not in defaults.ALLOWED_INSTANCE_FAMILIES:
            raise ValueError(
                "Invalid instance type. DJLLargeModels only support deployment to instances "
                f"with GPUs. Supported instance families are {defaults.ALLOWED_INSTANCE_FAMILIES}"
            )

    # Return the parent's result; dropping it would make this override always
    # return None regardless of what the parent produced.
    return super(DJLLargeModel, self).deploy(
        initial_instance_count=initial_instance_count,
        instance_type=instance_type,
        serializer=serializer,
        deserializer=deserializer,
        accelerator_type=accelerator_type,
        endpoint_name=endpoint_name,
        tags=tags,
        kms_key=kms_key,
        wait=wait,
        data_capture_config=data_capture_config,
        async_inference_config=async_inference_config,
        serverless_inference_config=serverless_inference_config,
        volume_size=volume_size,
        model_data_download_timeout=model_data_download_timeout,
        container_startup_health_check_timeout=container_startup_health_check_timeout,
        **kwargs,
    )
198
+
143
199
def prepare_container_def (
144
200
self ,
145
201
instance_type = None ,
0 commit comments