23
23
from sagemaker import model_uris
24
24
from sagemaker .serve .model_server .djl_serving .prepare import prepare_djl_js_resources
25
25
from sagemaker .serve .model_server .djl_serving .utils import _get_admissible_tensor_parallel_degrees
26
+ from sagemaker .serve .model_server .multi_model_server .prepare import prepare_mms_js_resources
26
27
from sagemaker .serve .model_server .tgi .prepare import prepare_tgi_js_resources , _create_dir_structure
27
28
from sagemaker .serve .mode .function_pointers import Mode
28
29
from sagemaker .serve .utils .exceptions import (
35
36
from sagemaker .serve .utils .predictors import (
36
37
DjlLocalModePredictor ,
37
38
TgiLocalModePredictor ,
39
+ TransformersLocalModePredictor ,
38
40
)
39
41
from sagemaker .serve .utils .local_hardware import (
40
42
_get_nb_instance ,
@@ -90,6 +92,7 @@ def __init__(self):
90
92
self .existing_properties = None
91
93
self .prepared_for_tgi = None
92
94
self .prepared_for_djl = None
95
+ self .prepared_for_mms = None
93
96
self .schema_builder = None
94
97
self .nb_instance_type = None
95
98
self .ram_usage_model_load = None
@@ -137,7 +140,11 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
137
140
138
141
if overwrite_mode == Mode .SAGEMAKER_ENDPOINT :
139
142
self .mode = self .pysdk_model .mode = Mode .SAGEMAKER_ENDPOINT
140
- if not hasattr (self , "prepared_for_djl" ) or not hasattr (self , "prepared_for_tgi" ):
143
+ if (
144
+ not hasattr (self , "prepared_for_djl" )
145
+ or not hasattr (self , "prepared_for_tgi" )
146
+ or not hasattr (self , "prepared_for_mms" )
147
+ ):
141
148
self .pysdk_model .model_data , env = self ._prepare_for_mode ()
142
149
elif overwrite_mode == Mode .LOCAL_CONTAINER :
143
150
self .mode = self .pysdk_model .mode = Mode .LOCAL_CONTAINER
@@ -160,6 +167,13 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
160
167
dependencies = self .dependencies ,
161
168
model_data = self .pysdk_model .model_data ,
162
169
)
170
+ elif not hasattr (self , "prepared_for_mms" ):
171
+ self .js_model_config , self .prepared_for_mms = prepare_mms_js_resources (
172
+ model_path = self .model_path ,
173
+ js_id = self .model ,
174
+ dependencies = self .dependencies ,
175
+ model_data = self .pysdk_model .model_data ,
176
+ )
163
177
164
178
self ._prepare_for_mode ()
165
179
env = {}
@@ -179,6 +193,10 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
179
193
predictor = TgiLocalModePredictor (
180
194
self .modes [str (Mode .LOCAL_CONTAINER )], serializer , deserializer
181
195
)
196
+ elif self .model_server == ModelServer .MMS :
197
+ predictor = TransformersLocalModePredictor (
198
+ self .modes [str (Mode .LOCAL_CONTAINER )], serializer , deserializer
199
+ )
182
200
183
201
ram_usage_before = _get_ram_usage_mb ()
184
202
self .modes [str (Mode .LOCAL_CONTAINER )].create_server (
@@ -254,6 +272,24 @@ def _build_for_tgi_jumpstart(self):
254
272
255
273
self .pysdk_model .env .update (env )
256
274
275
def _build_for_mms_jumpstart(self):
    """Prepare a JumpStart model for serving with the Multi Model Server (MMS).

    In ``LOCAL_CONTAINER`` mode this stages the JumpStart MMS resources on
    disk (at most once per builder — guarded by ``prepared_for_mms``) and then
    runs the generic mode preparation. In ``SAGEMAKER_ENDPOINT`` mode, when the
    resources were already prepared locally, it re-runs mode preparation to
    obtain the uploaded model data location and the endpoint environment
    variables. Any environment variables produced are merged into
    ``self.pysdk_model.env``.
    """
    env = {}
    if self.mode == Mode.LOCAL_CONTAINER:
        # Stage MMS JumpStart resources only once; prepare_mms_js_resources
        # returns the model config plus a flag recorded for later re-deploys.
        if not hasattr(self, "prepared_for_mms"):
            self.js_model_config, self.prepared_for_mms = prepare_mms_js_resources(
                model_path=self.model_path,
                js_id=self.model,
                dependencies=self.dependencies,
                model_data=self.pysdk_model.model_data,
            )
        self._prepare_for_mode()
    elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_mms"):
        # Locally-prepared artifacts exist: _prepare_for_mode yields the
        # (uploaded) model data reference and the endpoint env vars.
        self.pysdk_model.model_data, env = self._prepare_for_mode()

    # No-op when env is empty (e.g. fresh SAGEMAKER_ENDPOINT deployment).
    self.pysdk_model.env.update(env)
257
293
def _tune_for_js (self , sharded_supported : bool , max_tuning_duration : int = 1800 ):
258
294
"""Tune for Jumpstart Models in Local Mode.
259
295
@@ -264,7 +300,7 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800)
264
300
returns:
265
301
Tuned Model.
266
302
"""
267
- if self .mode != Mode .LOCAL_CONTAINER :
303
+ if self .mode == Mode .SAGEMAKER_ENDPOINT :
268
304
logger .warning (
269
305
"Tuning is only a %s capability. Returning original model." , Mode .LOCAL_CONTAINER
270
306
)
@@ -438,7 +474,6 @@ def _build_for_jumpstart(self):
438
474
self .jumpstart = True
439
475
440
476
pysdk_model = self ._create_pre_trained_js_model ()
441
-
442
477
image_uri = pysdk_model .image_uri
443
478
444
479
logger .info ("JumpStart ID %s is packaged with Image URI: %s" , self .model , image_uri )
@@ -451,7 +486,6 @@ def _build_for_jumpstart(self):
451
486
if "djl-inference" in image_uri :
452
487
logger .info ("Building for DJL JumpStart Model ID..." )
453
488
self .model_server = ModelServer .DJL_SERVING
454
-
455
489
self .pysdk_model = pysdk_model
456
490
self .image_uri = self .pysdk_model .image_uri
457
491
@@ -461,16 +495,23 @@ def _build_for_jumpstart(self):
461
495
elif "tgi-inference" in image_uri :
462
496
logger .info ("Building for TGI JumpStart Model ID..." )
463
497
self .model_server = ModelServer .TGI
464
-
465
498
self .pysdk_model = pysdk_model
466
499
self .image_uri = self .pysdk_model .image_uri
467
500
468
501
self ._build_for_tgi_jumpstart ()
469
502
470
503
self .pysdk_model .tune = self .tune_for_tgi_jumpstart
471
- else :
504
+ elif "huggingface-pytorch-inference:" in image_uri :
505
+ logger .info ("Building for MMS JumpStart Model ID..." )
506
+ self .model_server = ModelServer .MMS
507
+ self .pysdk_model = pysdk_model
508
+ self .image_uri = self .pysdk_model .image_uri
509
+
510
+ self ._build_for_mms_jumpstart ()
511
+ elif self .mode != Mode .SAGEMAKER_ENDPOINT :
472
512
raise ValueError (
473
- "JumpStart Model ID was not packaged with djl-inference or tgi-inference container."
513
+ "JumpStart Model ID was not packaged "
514
+ "with djl-inference, tgi-inference, or mms-inference container."
474
515
)
475
516
476
517
return self .pysdk_model
0 commit comments