Skip to content

Commit 15b2890

Browse files
author
Jonathan Makunga
committed
Refactoring
1 parent e08e1c0 commit 15b2890

File tree

4 files changed

+176
-226
lines changed

4 files changed

+176
-226
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 57 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,23 @@
3030
LocalModelOutOfMemoryException,
3131
LocalModelInvocationException,
3232
LocalModelLoadException,
33-
SkipTuningComboException
33+
SkipTuningComboException,
3434
)
3535
from sagemaker.serve.utils.predictors import (
3636
DjlLocalModePredictor,
3737
TgiLocalModePredictor,
3838
)
39-
from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb, _get_available_gpus
39+
from sagemaker.serve.utils.local_hardware import (
40+
_get_nb_instance,
41+
_get_ram_usage_mb,
42+
_get_available_gpus,
43+
)
4044
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
4145
from sagemaker.serve.utils.tuning import (
4246
_serial_benchmark,
4347
_concurrent_benchmark,
4448
_more_performant,
45-
_pretty_print_benchmark_results
49+
_pretty_print_benchmark_results,
4650
)
4751
from sagemaker.serve.utils.types import ModelServer
4852
from sagemaker.base_predictor import PredictorBase
@@ -150,10 +154,7 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
150154
model_data=self.pysdk_model.model_data,
151155
)
152156
elif not hasattr(self, "prepared_for_tgi"):
153-
(
154-
self.js_model_config,
155-
self.prepared_for_tgi
156-
) = prepare_tgi_js_resources(
157+
(self.js_model_config, self.prepared_for_tgi) = prepare_tgi_js_resources(
157158
model_path=self.model_path,
158159
js_id=self.model,
159160
dependencies=self.dependencies,
@@ -241,10 +242,7 @@ def _build_for_tgi_jumpstart(self):
241242
env = {}
242243
if self.mode == Mode.LOCAL_CONTAINER:
243244
if not hasattr(self, "prepared_for_tgi"):
244-
(
245-
self.js_model_config,
246-
self.prepared_for_tgi
247-
) = prepare_tgi_js_resources(
245+
(self.js_model_config, self.prepared_for_tgi) = prepare_tgi_js_resources(
248246
model_path=self.model_path,
249247
js_id=self.model,
250248
dependencies=self.dependencies,
@@ -256,28 +254,15 @@ def _build_for_tgi_jumpstart(self):
256254

257255
self.pysdk_model.env.update(env)
258256

259-
def _logging_debug(self, message):
260-
logging.debug("**************************************")
261-
logging.debug(message)
262-
logging.debug("**************************************")
263-
264257
def _tune_for_js(
265-
self,
266-
num_shard_env_var: str = "SM_NUM_GPUS",
267-
num_model_copies_env_var: str = "SAGEMAKER_MODEL_SERVER_WORKERS",
268-
multiple_model_copies_enabled: bool = False,
269-
max_tuning_duration: int = 1800):
270-
"""Tune for Jumpstart Models
258+
self, multiple_model_copies_enabled: bool = False, max_tuning_duration: int = 1800
259+
):
260+
"""Tune for Jumpstart Models.
271261
272262
Args:
273-
num_shard_env_var (str): The name of the environment variable specifies the
274-
number of GPUs available for training or inference. Default: ``SM_NUM_GPUS``
275-
For example SM_NUM_GPUS, or OPTION_TENSOR_PARALLEL_DEGREE.
276-
num_model_copies_env_var (str): The name of the environment variable that specifies the
277-
number of worker processes used by this model. Default: ``SAGEMAKER_MODEL_SERVER_WORKERS``
278263
multiple_model_copies_enabled (bool): Whether multiple model copies serving is enable by
279-
the Serving container. Defaults to ``False``
280-
max_tuning_duration (int): The maximum duration to deploy this ``Model`` locally.
264+
this ``DLC``. Defaults to ``False``
265+
max_tuning_duration (int): The maximum timeout to deploy this ``Model`` locally.
281266
Default: ``1800``
282267
returns:
283268
Tuned Model.
@@ -288,21 +273,19 @@ def _tune_for_js(
288273
)
289274
return self.pysdk_model
290275

291-
initial_env_vars = copy.deepcopy(self.pysdk_model.env)
292-
admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(self.js_model_config)
293-
294-
self._logging_debug(
295-
f"initial_env_vars: {initial_env_vars},"
296-
f" admissible_tensor_parallel_degrees: {admissible_tensor_parallel_degrees}")
276+
num_shard_env_var_name = "SM_NUM_GPUS"
277+
if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
278+
num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE"
297279

280+
initial_env_vars = copy.deepcopy(self.pysdk_model.env)
281+
admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
282+
self.js_model_config
283+
)
298284
available_gpus = _get_available_gpus() if multiple_model_copies_enabled else None
299-
self._logging_debug(
300-
f"multiple_model_copies_enabled: {multiple_model_copies_enabled}, available_gpus: {available_gpus}")
301285

302286
benchmark_results = {}
303287
best_tuned_combination = None
304288
timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
305-
306289
for tensor_parallel_degree in admissible_tensor_parallel_degrees:
307290
if datetime.now() > timeout:
308291
logger.info("Max tuning duration reached. Tuning stopped.")
@@ -312,25 +295,15 @@ def _tune_for_js(
312295
if multiple_model_copies_enabled:
313296
sagemaker_model_server_workers = int(available_gpus / tensor_parallel_degree)
314297

315-
self.pysdk_model.env.update({
316-
num_shard_env_var: str(tensor_parallel_degree),
317-
num_model_copies_env_var: str(sagemaker_model_server_workers)
318-
})
319-
320-
self._logging_debug(
321-
f"num_shard_env_var: {num_shard_env_var}, tensor_parallel_degree: {tensor_parallel_degree}")
322-
self._logging_debug(f"sagemaker_model_server_workers: {sagemaker_model_server_workers}")
323-
self._logging_debug(
324-
f"num_model_copies_env_var: {num_model_copies_env_var}, "
325-
f"sagemaker_model_server_workers: {sagemaker_model_server_workers}")
326-
self._logging_debug(f"current model.env: {self.pysdk_model.env}")
298+
self.pysdk_model.env.update(
299+
{
300+
num_shard_env_var_name: str(tensor_parallel_degree),
301+
"SAGEMAKER_MODEL_SERVER_WORKERS": str(sagemaker_model_server_workers),
302+
}
303+
)
327304

328305
try:
329-
self._logging_debug(f"Deploying model with env config {self.pysdk_model.env}")
330-
predictor = self.pysdk_model.deploy(
331-
model_data_download_timeout=max_tuning_duration
332-
)
333-
self._logging_debug(f"Deployed model with env config {self.pysdk_model.env}")
306+
predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration)
334307

335308
avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
336309
predictor, self.schema_builder.sample_input
@@ -378,91 +351,77 @@ def _tune_for_js(
378351
best_tuned_combination = tuned_configuration
379352
except LocalDeepPingException as e:
380353
logger.warning(
381-
"Deployment unsuccessful with %s: %s. %s: %s"
354+
"Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s"
382355
"Failed to invoke the model server: %s",
383-
num_shard_env_var,
356+
num_shard_env_var_name,
384357
tensor_parallel_degree,
385-
num_model_copies_env_var,
386358
sagemaker_model_server_workers,
387359
str(e),
388360
)
389361
break
390362
except LocalModelOutOfMemoryException as e:
391363
logger.warning(
392-
f"Deployment unsuccessful with %s: %s. %s: %s. "
364+
"Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
393365
"Out of memory when loading the model: %s",
394-
num_shard_env_var,
366+
num_shard_env_var_name,
395367
tensor_parallel_degree,
396-
num_model_copies_env_var,
397368
sagemaker_model_server_workers,
398369
str(e),
399370
)
400371
break
401372
except LocalModelInvocationException as e:
402373
logger.warning(
403-
f"Deployment unsuccessful with %s: %s. %s: %s. "
374+
"Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
404375
"Failed to invoke the model server: %s"
405376
"Please check that model server configurations are as expected "
406377
"(Ex. serialization, deserialization, content_type, accept).",
407-
num_shard_env_var,
378+
num_shard_env_var_name,
408379
tensor_parallel_degree,
409-
num_model_copies_env_var,
410380
sagemaker_model_server_workers,
411381
str(e),
412382
)
413383
break
414384
except LocalModelLoadException as e:
415385
logger.warning(
416-
f"Deployment unsuccessful with %s: %s. %s: %s. "
386+
"Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
417387
"Failed to load the model: %s.",
418-
num_shard_env_var,
388+
num_shard_env_var_name,
419389
tensor_parallel_degree,
420-
num_model_copies_env_var,
421390
sagemaker_model_server_workers,
422391
str(e),
423392
)
424393
break
425394
except SkipTuningComboException as e:
426395
logger.warning(
427-
f"Deployment with %s: %s. %s: %s. "
396+
"Deployment with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
428397
"was expected to be successful. However failed with: %s. "
429398
"Trying next combination.",
430-
num_shard_env_var,
399+
num_shard_env_var_name,
431400
tensor_parallel_degree,
432-
num_model_copies_env_var,
433401
sagemaker_model_server_workers,
434402
str(e),
435403
)
436404
break
437-
except Exception as e:
438-
logger.exception(e)
405+
except Exception: # pylint: disable=W0703
439406
logger.exception(
440-
f"Deployment unsuccessful with %s: %s. %s: %s. "
407+
"Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
441408
"with uncovered exception",
442-
num_shard_env_var,
409+
num_shard_env_var_name,
443410
tensor_parallel_degree,
444-
num_model_copies_env_var,
445411
sagemaker_model_server_workers,
446412
)
447413
break
448414

449415
if best_tuned_combination:
450-
model_env_vars = [num_shard_env_var]
451-
452-
self.pysdk_model.env.update({
453-
num_shard_env_var: str(best_tuned_combination[1])
454-
})
455-
if multiple_model_copies_enabled:
456-
self.pysdk_model.env.update({
457-
num_model_copies_env_var: str(best_tuned_combination[2])
458-
})
459-
model_env_vars.append(num_model_copies_env_var)
460-
461-
self._logging_debug(f"benchmark_results: {benchmark_results}, model_env_vars: {model_env_vars}")
416+
self.pysdk_model.env.update(
417+
{
418+
num_shard_env_var_name: str(best_tuned_combination[1]),
419+
"SAGEMAKER_MODEL_SERVER_WORKERS": str(best_tuned_combination[2]),
420+
}
421+
)
462422

463423
_pretty_print_benchmark_results(
464-
benchmark_results,
465-
model_env_vars
424+
benchmark_results, [num_shard_env_var_name, "SAGEMAKER_MODEL_SERVER_WORKERS"]
466425
)
467426
logger.info(
468427
"Model Configuration: %s was most performant with avg latency: %s, "
@@ -480,32 +439,26 @@ def _tune_for_js(
480439
logger.debug(
481440
"Failed to gather any tuning results. "
482441
"Please inspect the stack trace emitted from live logging for more details. "
483-
"Falling back to default model environment variable configurations: %s",
442+
"Falling back to default model configurations: %s",
484443
self.pysdk_model.env,
485444
)
486445

487446
return self.pysdk_model
488447

489448
@_capture_telemetry("djl_jumpstart.tune")
490449
def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
491-
"""Tune for Jumpstart Models with DJL serving"""
492-
num_shard_env_var = "SM_NUM_GPUS"
493-
if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
494-
num_shard_env_var = "OPTION_TENSOR_PARALLEL_DEGREE"
495-
450+
"""Tune for Jumpstart Models with DJL DLC"""
496451
return self._tune_for_js(
497-
num_shard_env_var=num_shard_env_var,
498-
multiple_model_copies_enabled=True,
499-
max_tuning_duration=max_tuning_duration
452+
multiple_model_copies_enabled=True, max_tuning_duration=max_tuning_duration
500453
)
501454

502455
@_capture_telemetry("tgi_jumpstart.tune")
503456
def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
504-
"""Tune for Jumpstart Models with TGI serving"""
457+
"""Tune for Jumpstart Models with TGI DLC"""
505458
return self._tune_for_js(
506459
# Currently, TGI does not enable multiple model copies serving.
507460
multiple_model_copies_enabled=False,
508-
max_tuning_duration=max_tuning_duration
461+
max_tuning_duration=max_tuning_duration,
509462
)
510463

511464
def _build_for_jumpstart(self):
@@ -528,6 +481,8 @@ def _build_for_jumpstart(self):
528481
self.image_uri = self.pysdk_model.image_uri
529482

530483
self._build_for_djl_jumpstart()
484+
485+
self.pysdk_model.tune = self.tune_for_djl_jumpstart
531486
elif "tgi-inference" in image_uri:
532487
logger.info("Building for TGI JumpStart Model ID...")
533488
self.model_server = ModelServer.TGI
@@ -536,13 +491,11 @@ def _build_for_jumpstart(self):
536491
self.image_uri = self.pysdk_model.image_uri
537492

538493
self._build_for_tgi_jumpstart()
494+
495+
self.pysdk_model.tune = self.tune_for_tgi_jumpstart
539496
else:
540497
raise ValueError(
541498
"JumpStart Model ID was not packaged with djl-inference or tgi-inference container."
542499
)
543500

544-
if self.model_server == ModelServer.TGI:
545-
self.pysdk_model.tune = self.tune_for_tgi_jumpstart
546-
elif self.model_server == ModelServer.DJL_SERVING:
547-
self.pysdk_model.tune = self.tune_for_djl_jumpstart
548501
return self.pysdk_model

src/sagemaker/serve/utils/tuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _pretty_print_benchmark_results(results: dict, model_env_vars=None):
131131
"AverageTokensPerSecond (Serial)": avg_tokens_per_seconds,
132132
"ThroughputPerSecond (Concurrent)": throughput_per_seconds,
133133
"StandardDeviationResponse (Concurrent)": standard_deviations,
134-
**__env_var_data
134+
**__env_var_data,
135135
}
136136
)
137137
logger.info(

tests/integ/sagemaker/serve/test_serve_tune_js.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,9 @@
1919
from sagemaker.serve.builder.schema_builder import SchemaBuilder
2020
from tests.integ import lock
2121
from tests.integ.sagemaker.serve.constants import (
22-
SERVE_LOCAL_CONTAINER_TUNE_TIMEOUT,
2322
PYTHON_VERSION_IS_NOT_310,
2423
)
2524

26-
from tests.integ.timeout import timeout
27-
from tests.integ.utils import cleanup_model_resources
2825
import logging
2926

3027
logger = logging.getLogger(__name__)
@@ -46,11 +43,11 @@ def model_builder(sagemaker_local_session):
4643
)
4744

4845

49-
# @pytest.mark.skipif(
50-
# PYTHON_VERSION_IS_NOT_310,
51-
# reason="The goal of these tests are to test the serving components of our feature",
52-
# )
53-
# @pytest.mark.local_mode
46+
@pytest.mark.skipif(
47+
PYTHON_VERSION_IS_NOT_310,
48+
reason="The goal of these tests are to test the serving components of our feature",
49+
)
50+
@pytest.mark.local_mode
5451
def test_happy_tgi_sagemaker_endpoint(model_builder, gpu_instance_type):
5552
logger.info("Running in LOCAL_CONTAINER mode...")
5653
caught_ex = None
@@ -59,16 +56,10 @@ def test_happy_tgi_sagemaker_endpoint(model_builder, gpu_instance_type):
5956
# endpoint tests all use the same port, so we use this lock to prevent concurrent execution
6057
with lock.lock():
6158
try:
62-
print("************************************")
63-
logger.info("Deploying and predicting in LOCAL_CONTAINER mode...")
64-
predictor = predictor = model.deploy(instance_type=gpu_instance_type)
65-
logger.info("Endpoint successfully deployed.")
66-
print("************************************END")
59+
predictor = predictor = model.deploy(instance_type="local")
6760
except Exception as e:
6861
caught_ex = e
6962
finally:
7063
predictor.delete_endpoint()
7164
if caught_ex:
7265
raise caught_ex
73-
74-

0 commit comments

Comments
 (0)