Skip to content

Commit d0ea680

Browse files
authored
Merge pull request #2 from makungaj1/js_tune
Tune Support for JumpStart Models.
2 parents ca9ae23 + c8289ff commit d0ea680

File tree

6 files changed

+783
-6
lines changed

6 files changed

+783
-6
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 234 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,41 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
import copy
1617
from abc import ABC, abstractmethod
18+
from datetime import datetime, timedelta
1719
from typing import Type
1820
import logging
1921

2022
from sagemaker.model import Model
2123
from sagemaker import model_uris
2224
from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
25+
from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
2326
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
2427
from sagemaker.serve.mode.function_pointers import Mode
28+
from sagemaker.serve.utils.exceptions import (
29+
LocalDeepPingException,
30+
LocalModelOutOfMemoryException,
31+
LocalModelInvocationException,
32+
LocalModelLoadException,
33+
SkipTuningComboException,
34+
)
2535
from sagemaker.serve.utils.predictors import (
2636
DjlLocalModePredictor,
2737
TgiLocalModePredictor,
2838
)
29-
from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb
39+
from sagemaker.serve.utils.local_hardware import (
40+
_get_nb_instance,
41+
_get_ram_usage_mb,
42+
_get_available_gpus,
43+
)
3044
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
45+
from sagemaker.serve.utils.tuning import (
46+
_serial_benchmark,
47+
_concurrent_benchmark,
48+
_more_performant,
49+
_pretty_print_benchmark_results,
50+
)
3151
from sagemaker.serve.utils.types import ModelServer
3252
from sagemaker.base_predictor import PredictorBase
3353
from sagemaker.jumpstart.model import JumpStartModel
@@ -134,7 +154,7 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
134154
model_data=self.pysdk_model.model_data,
135155
)
136156
elif not hasattr(self, "prepared_for_tgi"):
137-
self.prepared_for_tgi = prepare_tgi_js_resources(
157+
(self.js_model_config, self.prepared_for_tgi) = prepare_tgi_js_resources(
138158
model_path=self.model_path,
139159
js_id=self.model,
140160
dependencies=self.dependencies,
@@ -222,7 +242,7 @@ def _build_for_tgi_jumpstart(self):
222242
env = {}
223243
if self.mode == Mode.LOCAL_CONTAINER:
224244
if not hasattr(self, "prepared_for_tgi"):
225-
self.prepared_for_tgi = prepare_tgi_js_resources(
245+
(self.js_model_config, self.prepared_for_tgi) = prepare_tgi_js_resources(
226246
model_path=self.model_path,
227247
js_id=self.model,
228248
dependencies=self.dependencies,
@@ -234,6 +254,213 @@ def _build_for_tgi_jumpstart(self):
234254

235255
self.pysdk_model.env.update(env)
236256

257+
def _tune_for_js(
    self, multiple_model_copies_enabled: bool = False, max_tuning_duration: int = 1800
):
    """Tune for JumpStart Models.

    Locally benchmarks every admissible tensor parallel degree (and, when the
    serving container supports it, multiple model copies per host) and keeps
    the most performant configuration found before ``max_tuning_duration``
    expires. Only meaningful in ``LOCAL_CONTAINER`` mode.

    Args:
        multiple_model_copies_enabled (bool): Whether multiple model copies serving
            is enabled by this ``DLC``. Defaults to ``False``.
        max_tuning_duration (int): The maximum timeout to deploy this ``Model``
            locally. Default: ``1800``.

    Returns:
        Tuned Model.
    """
    if self.mode != Mode.LOCAL_CONTAINER:
        logger.warning(
            "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
        )
        return self.pysdk_model

    # DJL configures sharding via OPTION_TENSOR_PARALLEL_DEGREE; TGI via SM_NUM_GPUS.
    num_shard_env_var_name = "SM_NUM_GPUS"
    if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
        num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE"

    # Snapshot the starting env so it can be restored if tuning yields no result.
    initial_env_vars = copy.deepcopy(self.pysdk_model.env)
    admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
        self.js_model_config
    )
    available_gpus = _get_available_gpus() if multiple_model_copies_enabled else None

    benchmark_results = {}
    best_tuned_combination = None
    timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
    for tensor_parallel_degree in admissible_tensor_parallel_degrees:
        if datetime.now() > timeout:
            logger.info("Max tuning duration reached. Tuning stopped.")
            break

        sagemaker_model_server_workers = 1
        if multiple_model_copies_enabled:
            # One model copy per group of tensor_parallel_degree GPUs.
            sagemaker_model_server_workers = available_gpus // tensor_parallel_degree

        self.pysdk_model.env.update(
            {
                num_shard_env_var_name: str(tensor_parallel_degree),
                "SAGEMAKER_MODEL_SERVER_WORKERS": str(sagemaker_model_server_workers),
            }
        )

        try:
            predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration)

            avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
                predictor, self.schema_builder.sample_input
            )
            throughput_per_second, standard_deviation = _concurrent_benchmark(
                predictor, self.schema_builder.sample_input
            )

            tested_env = self.pysdk_model.env.copy()
            logger.info(
                "Average latency: %s, throughput/s: %s for configuration: %s",
                avg_latency,
                throughput_per_second,
                tested_env,
            )
            benchmark_results[avg_latency] = [
                tested_env,
                p90,
                avg_tokens_per_second,
                throughput_per_second,
                standard_deviation,
            ]

            tuned_configuration = [
                avg_latency,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                p90,
                avg_tokens_per_second,
                throughput_per_second,
                standard_deviation,
            ]
            # First successful combination seeds the best; afterwards keep the
            # more performant of the two (was previously duplicated list-building).
            if best_tuned_combination is None or _more_performant(
                best_tuned_combination, tuned_configuration
            ):
                best_tuned_combination = tuned_configuration
        except LocalDeepPingException as e:
            logger.warning(
                # NOTE: added the missing ". " separator so the concatenated
                # message does not render as "...%sFailed to invoke...".
                "Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "Failed to invoke the model server: %s",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                str(e),
            )
            break
        except LocalModelOutOfMemoryException as e:
            logger.warning(
                "Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "Out of memory when loading the model: %s",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                str(e),
            )
            break
        except LocalModelInvocationException as e:
            logger.warning(
                # NOTE: added the missing ". " separator before "Please check".
                "Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "Failed to invoke the model server: %s. "
                "Please check that model server configurations are as expected "
                "(Ex. serialization, deserialization, content_type, accept).",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                str(e),
            )
            break
        except LocalModelLoadException as e:
            logger.warning(
                "Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "Failed to load the model: %s.",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                str(e),
            )
            break
        except SkipTuningComboException as e:
            logger.warning(
                "Deployment with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "was expected to be successful. However failed with: %s. "
                "Trying next combination.",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                str(e),
            )
            # BUGFIX: the message promises "Trying next combination" — continue
            # with the next admissible degree instead of aborting the whole loop.
            continue
        except Exception:  # pylint: disable=W0703
            logger.exception(
                "Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "with uncovered exception",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
            )
            break

    if best_tuned_combination:
        # Pin the winning configuration on the model's env.
        self.pysdk_model.env.update(
            {
                num_shard_env_var_name: str(best_tuned_combination[1]),
                "SAGEMAKER_MODEL_SERVER_WORKERS": str(best_tuned_combination[2]),
            }
        )

        _pretty_print_benchmark_results(
            benchmark_results, [num_shard_env_var_name, "SAGEMAKER_MODEL_SERVER_WORKERS"]
        )
        logger.info(
            "Model Configuration: %s was most performant with avg latency: %s, "
            "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
            "standard deviation of request %s",
            self.pysdk_model.env,
            best_tuned_combination[0],
            best_tuned_combination[3],
            best_tuned_combination[4],
            best_tuned_combination[5],
            best_tuned_combination[6],
        )
    else:
        # Nothing benchmarked successfully — fall back to the original env.
        self.pysdk_model.env.update(initial_env_vars)
        logger.debug(
            "Failed to gather any tuning results. "
            "Please inspect the stack trace emitted from live logging for more details. "
            "Falling back to default model configurations: %s",
            self.pysdk_model.env,
        )

    return self.pysdk_model
447+
448+
@_capture_telemetry("djl_jumpstart.tune")
def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
    """Tune for JumpStart Models with DJL DLC."""
    # DJL DLCs support serving multiple model copies on one host.
    return self._tune_for_js(
        multiple_model_copies_enabled=True,
        max_tuning_duration=max_tuning_duration,
    )
454+
455+
@_capture_telemetry("tgi_jumpstart.tune")
def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
    """Tune for JumpStart Models with TGI DLC."""
    # Currently, TGI does not enable multiple model copies serving.
    return self._tune_for_js(
        multiple_model_copies_enabled=False,
        max_tuning_duration=max_tuning_duration,
    )
463+
237464
def _build_for_jumpstart(self):
238465
"""Placeholder docstring"""
239466
# we do not pickle for jumpstart. set to none
@@ -254,6 +481,8 @@ def _build_for_jumpstart(self):
254481
self.image_uri = self.pysdk_model.image_uri
255482

256483
self._build_for_djl_jumpstart()
484+
485+
self.pysdk_model.tune = self.tune_for_djl_jumpstart
257486
elif "tgi-inference" in image_uri:
258487
logger.info("Building for TGI JumpStart Model ID...")
259488
self.model_server = ModelServer.TGI
@@ -262,6 +491,8 @@ def _build_for_jumpstart(self):
262491
self.image_uri = self.pysdk_model.image_uri
263492

264493
self._build_for_tgi_jumpstart()
494+
495+
self.pysdk_model.tune = self.tune_for_tgi_jumpstart
265496
else:
266497
raise ValueError(
267498
"JumpStart Model ID was not packaged with djl-inference or tgi-inference container."

src/sagemaker/serve/model_server/tgi/prepare.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Prepare TgiModel for Deployment"""
1414

1515
from __future__ import absolute_import
16+
17+
import json
1618
import tarfile
1719
import logging
1820
from typing import List
@@ -32,7 +34,7 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
3234
custom_extractall_tarfile(resources, code_dir)
3335

3436

35-
def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:
37+
def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> tuple:
3638
"""Copy the associated JumpStart Resource into the code directory"""
3739
logger.info("Downloading JumpStart artifacts from S3...")
3840

@@ -56,7 +58,13 @@ def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bo
5658
else:
5759
raise ValueError("JumpStart model data compression format is unsupported: %s", model_data)
5860

59-
return True
61+
config_json_file = code_dir.joinpath("config.json")
62+
hf_model_config = None
63+
if config_json_file.is_file():
64+
with open(str(config_json_file)) as config_json:
65+
hf_model_config = json.load(config_json)
66+
67+
return (hf_model_config, True)
6068

6169

6270
def _create_dir_structure(model_path: str) -> tuple:
@@ -82,7 +90,7 @@ def prepare_tgi_js_resources(
8290
shared_libs: List[str] = None,
8391
dependencies: str = None,
8492
model_data: str = None,
85-
) -> bool:
93+
) -> tuple:
8694
"""Prepare serving when a JumpStart model id is given
8795
8896
Args:

src/sagemaker/serve/utils/tuning.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,51 @@ def _pretty_print_results_tgi(results: dict):
9898
)
9999

100100

101+
def _pretty_print_benchmark_results(results: dict, model_env_vars=None):
    """Log tuning benchmark results as a table, sorted by average latency.

    ``results`` maps average latency -> [tested_env, p90, avg_tokens_per_second,
    throughput_per_second, standard_deviation]; ``model_env_vars`` names the env
    vars to surface as extra table columns.
    """
    if model_env_vars is None:
        model_env_vars = []

    env_var_columns = {name: [] for name in model_env_vars}

    avg_latencies = []
    p90s = []
    avg_tokens_per_seconds = []
    throughput_per_seconds = []
    standard_deviations = []

    # Iterate rows fastest-first (keys are average latencies).
    for avg_latency, stats in collections.OrderedDict(sorted(results.items())).items():
        avg_latencies.append(avg_latency)
        p90s.append(stats[1])
        avg_tokens_per_seconds.append(stats[2])
        throughput_per_seconds.append(stats[3])
        standard_deviations.append(stats[4])
        for name in env_var_columns:
            env_var_columns[name].append(stats[0][name])

    df = pd.DataFrame(
        {
            "AverageLatency (Serial)": avg_latencies,
            "P90_Latency (Serial)": p90s,
            "AverageTokensPerSecond (Serial)": avg_tokens_per_seconds,
            "ThroughputPerSecond (Concurrent)": throughput_per_seconds,
            "StandardDeviationResponse (Concurrent)": standard_deviations,
            **env_var_columns,
        }
    )
    logger.info(
        "\n================================================================== Benchmark "
        "Results ==================================================================\n%s"
        "\n============================================================================"
        "===========================================================================\n",
        df.to_string(),
    )
144+
145+
101146
def _tokens_per_second(generated_text: str, max_token_length: int, latency: float) -> int:
102147
"""Placeholder docstring"""
103148
est_tokens = (_tokens_from_chars(generated_text) + _tokens_from_words(generated_text)) / 2

tests/integ/sagemaker/serve/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
SERVE_IN_PROCESS_TIMEOUT = 5
2121
SERVE_MODEL_PACKAGE_TIMEOUT = 10
2222
SERVE_LOCAL_CONTAINER_TIMEOUT = 10
23+
SERVE_LOCAL_CONTAINER_TUNE_TIMEOUT = 15
2324
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT = 15
2425
SERVE_SAVE_TIMEOUT = 2
2526

0 commit comments

Comments
 (0)