Skip to content

Commit d8fa214

Browse files
committed
Add deploy method with validations on configs and instance types
1 parent e2a9b70 commit d8fa214

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

src/sagemaker/djl_inference/defaults.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,10 @@
3636
"roberta",
3737
"bert",
3838
}
39+
40+
# GPU instance-type families accepted for DJL large model deployment.
# DJLLargeModel.deploy() rejects any instance type whose family is not in
# this set (these models require GPU-backed hosting).
ALLOWED_INSTANCE_FAMILIES = {
    "ml.g4",
    "ml.g5",
    "ml.p3",
    "ml.p4",
}

src/sagemaker/djl_inference/model.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def __init__(
9393
predictor_cls: callable = DJLLargeModelPredictor,
9494
**kwargs,
9595
):
96+
if kwargs.get("model_data") is not None:
97+
raise ValueError("DJLLargeModels do not support the model_data parameter. Please use"
98+
"uncompressed_model_data and ensure the s3 uri points to a folder containing"
99+
"all model artifacts, not a tar.gz file")
96100
super(DJLLargeModel, self).__init__(
97101
None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
98102
)
@@ -140,6 +144,58 @@ def compile(
140144
):
141145
raise NotImplementedError("DJLLargeModels do not currently support compilation with SageMaker Neo")
142146

147+
def deploy(
    self,
    initial_instance_count=None,
    instance_type=None,
    serializer=None,
    deserializer=None,
    accelerator_type=None,
    endpoint_name=None,
    tags=None,
    kms_key=None,
    wait=True,
    data_capture_config=None,
    async_inference_config=None,
    serverless_inference_config=None,
    volume_size=None,
    model_data_download_timeout=None,
    container_startup_health_check_timeout=None,
    **kwargs,
):
    """Deploy this DJLLargeModel to a SageMaker endpoint.

    Validates that the requested deployment configuration is supported by DJL
    large model hosting, then delegates to the base ``Model.deploy``.

    Args:
        instance_type: EC2 instance type to host the endpoint. Must belong to
            one of ``defaults.ALLOWED_INSTANCE_FAMILIES``. May be omitted only
            when inference recommender results are available on the model.
        accelerator_type: Not supported; any value raises.
        serverless_inference_config: Not supported; any value raises.
        (Remaining parameters are forwarded unchanged to ``Model.deploy``.)

    Returns:
        Whatever the base ``deploy`` returns — a predictor built from
        ``predictor_cls`` when one is set, else ``None``.

    Raises:
        ValueError: If an Elastic Inference accelerator or serverless config is
            supplied, if no instance type can be determined, or if the instance
            type is not a supported GPU family.
    """
    if accelerator_type:
        raise ValueError("DJLLargeModels do not support Elastic Inference Accelerators")
    if serverless_inference_config:
        raise ValueError("DJLLargeModels do not support Serverless Deployment")
    if instance_type is None and not self.inference_recommender_job_results:
        # No explicit instance type and no right_size() result to fall back on.
        # NOTE: adjacent string literals need trailing spaces, otherwise the
        # message runs words together ("right_size()must be run").
        raise ValueError(
            "instance_type must be specified, or inference recommendation from "
            "right_size() must be run to deploy the model. Supported instance "
            f"type families are: {defaults.ALLOWED_INSTANCE_FAMILIES}"
        )
    if instance_type:
        # Prefix match rather than an exact rsplit on the family token: real
        # instance types such as "ml.g4dn.xlarge" split to "ml.g4dn", which
        # would never equal the allowed family "ml.g4".
        if not instance_type.startswith(tuple(defaults.ALLOWED_INSTANCE_FAMILIES)):
            raise ValueError(
                "Invalid instance type. DJLLargeModels only support deployment "
                "to instances with GPUs. Supported instance families are "
                f"{defaults.ALLOWED_INSTANCE_FAMILIES}"
            )

    # Return the base-class result so callers receive the Predictor that
    # predictor_cls produces (the original dropped it).
    return super(DJLLargeModel, self).deploy(
        initial_instance_count=initial_instance_count,
        instance_type=instance_type,
        serializer=serializer,
        deserializer=deserializer,
        accelerator_type=accelerator_type,
        endpoint_name=endpoint_name,
        tags=tags,
        kms_key=kms_key,
        wait=wait,
        data_capture_config=data_capture_config,
        async_inference_config=async_inference_config,
        serverless_inference_config=serverless_inference_config,
        volume_size=volume_size,
        model_data_download_timeout=model_data_download_timeout,
        container_startup_health_check_timeout=container_startup_health_check_timeout,
        **kwargs,
    )
198+
143199
def prepare_container_def(
144200
self,
145201
instance_type=None,

0 commit comments

Comments
 (0)