Commit 951d2c4

Tune (local mode) support for Jumpstart Models (#4532)

Authored by Jonathan Makunga (makungaj1)

* JumpStart TGI tune support
* TGI does not enable multiple model serving internal to the container
* Tune support for JS model with DJL DLC
* Pretty Print JS TGI Benchmark results
* Pretty Print JS DJL Benchmark results
* Debugging
* Refactoring
* Revert "Refactoring" (reverts commit 4be1679)
* Refactoring
* Revert "Debugging" (reverts commit 1ec29e3)
* Add tests
* Refactoring
* Refactoring
* Debugging
* Refactoring and debugging
* Refactoring
* Refactoring
* Refactoring
* Add unit tests
* Address PR review comments
* Sharded is not supported for all model types
* Refactoring
* Refactoring
* Refactoring
* Refactoring
* Refactoring
* Addressed PR review comments
* Sharding support validation for TGI
* Fix unit tests
* Increase code coverage

---------

Co-authored-by: Jonathan Makunga <[email protected]>
Parent: 12fe775 · Commit: 951d2c4

File tree: 7 files changed (+859, -8 lines)

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 204 additions & 3 deletions
@@ -13,21 +13,41 @@
 """Placeholder docstring"""
 from __future__ import absolute_import
 
+import copy
 from abc import ABC, abstractmethod
+from datetime import datetime, timedelta
 from typing import Type
 import logging
 
 from sagemaker.model import Model
 from sagemaker import model_uris
 from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
+from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
 from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
 from sagemaker.serve.mode.function_pointers import Mode
+from sagemaker.serve.utils.exceptions import (
+    LocalDeepPingException,
+    LocalModelOutOfMemoryException,
+    LocalModelInvocationException,
+    LocalModelLoadException,
+    SkipTuningComboException,
+)
 from sagemaker.serve.utils.predictors import (
     DjlLocalModePredictor,
     TgiLocalModePredictor,
 )
-from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb
+from sagemaker.serve.utils.local_hardware import (
+    _get_nb_instance,
+    _get_ram_usage_mb,
+)
 from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
+from sagemaker.serve.utils.tuning import (
+    _pretty_print_results_jumpstart,
+    _serial_benchmark,
+    _concurrent_benchmark,
+    _more_performant,
+    _sharded_supported,
+)
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.base_predictor import PredictorBase
 from sagemaker.jumpstart.model import JumpStartModel
@@ -134,7 +154,7 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
                 model_data=self.pysdk_model.model_data,
             )
         elif not hasattr(self, "prepared_for_tgi"):
-            self.prepared_for_tgi = prepare_tgi_js_resources(
+            self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources(
                 model_path=self.model_path,
                 js_id=self.model,
                 dependencies=self.dependencies,
@@ -222,7 +242,7 @@ def _build_for_tgi_jumpstart(self):
         env = {}
         if self.mode == Mode.LOCAL_CONTAINER:
             if not hasattr(self, "prepared_for_tgi"):
-                self.prepared_for_tgi = prepare_tgi_js_resources(
+                self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources(
                     model_path=self.model_path,
                     js_id=self.model,
                     dependencies=self.dependencies,
@@ -234,6 +254,183 @@ def _build_for_tgi_jumpstart(self):
 
         self.pysdk_model.env.update(env)
 
+    def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800):
+        """Tune for Jumpstart Models in Local Mode.
+
+        Args:
+            sharded_supported (bool): Indicates whether sharding is supported by this ``Model``
+            max_tuning_duration (int): The maximum timeout to deploy this ``Model`` locally.
+                Default: ``1800``
+        Returns:
+            Tuned Model.
+        """
+        if self.mode != Mode.LOCAL_CONTAINER:
+            logger.warning(
+                "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
+            )
+            return self.pysdk_model
+
+        num_shard_env_var_name = "SM_NUM_GPUS"
+        if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
+            num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE"
+
+        initial_env_vars = copy.deepcopy(self.pysdk_model.env)
+        admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
+            self.js_model_config
+        )
+
+        if len(admissible_tensor_parallel_degrees) > 1 and not sharded_supported:
+            admissible_tensor_parallel_degrees = [1]
+            logger.warning(
+                "Sharding across multiple GPUs is not supported for this model. "
+                "Model can only be sharded across [1] GPU"
+            )
+
+        benchmark_results = {}
+        best_tuned_combination = None
+        timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
+        for tensor_parallel_degree in admissible_tensor_parallel_degrees:
+            if datetime.now() > timeout:
+                logger.info("Max tuning duration reached. Tuning stopped.")
+                break
+
+            self.pysdk_model.env.update({num_shard_env_var_name: str(tensor_parallel_degree)})
+            try:
+                logger.info("Trying tensor parallel degree: %s", tensor_parallel_degree)
+
+                predictor = self.pysdk_model.deploy(
+                    model_data_download_timeout=max_tuning_duration
+                )
+
+                avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
+                    predictor, self.schema_builder.sample_input
+                )
+                throughput_per_second, standard_deviation = _concurrent_benchmark(
+                    predictor, self.schema_builder.sample_input
+                )
+
+                tested_env = copy.deepcopy(self.pysdk_model.env)
+                logger.info(
+                    "Average latency: %s, throughput/s: %s for configuration: %s",
+                    avg_latency,
+                    throughput_per_second,
+                    tested_env,
+                )
+                benchmark_results[avg_latency] = [
+                    tested_env,
+                    p90,
+                    avg_tokens_per_second,
+                    throughput_per_second,
+                    standard_deviation,
+                ]
+
+                if not best_tuned_combination:
+                    best_tuned_combination = [
+                        avg_latency,
+                        tensor_parallel_degree,
+                        None,
+                        p90,
+                        avg_tokens_per_second,
+                        throughput_per_second,
+                        standard_deviation,
+                    ]
+                else:
+                    tuned_configuration = [
+                        avg_latency,
+                        tensor_parallel_degree,
+                        None,
+                        p90,
+                        avg_tokens_per_second,
+                        throughput_per_second,
+                        standard_deviation,
+                    ]
+                    if _more_performant(best_tuned_combination, tuned_configuration):
+                        best_tuned_combination = tuned_configuration
+            except LocalDeepPingException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. Failed to invoke the model server: %s",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except LocalModelOutOfMemoryException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. "
+                    "Out of memory when loading the model: %s",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except LocalModelInvocationException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. "
+                    "Failed to invoke the model server: %s. "
+                    "Please check that model server configurations are as expected "
+                    "(Ex. serialization, deserialization, content_type, accept).",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except LocalModelLoadException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. Failed to load the model: %s.",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except SkipTuningComboException as e:
+                logger.warning(
+                    "Deployment with %s: %s "
+                    "was expected to be successful. However, it failed with: %s. "
+                    "Trying next combination.",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except Exception:  # pylint: disable=W0703
+                logger.exception(
+                    "Deployment unsuccessful with %s: %s with uncovered exception",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                )
+
+        if best_tuned_combination:
+            self.pysdk_model.env.update({num_shard_env_var_name: str(best_tuned_combination[1])})
+
+            _pretty_print_results_jumpstart(benchmark_results, [num_shard_env_var_name])
+            logger.info(
+                "Model Configuration: %s was most performant with avg latency: %s, "
+                "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
+                "standard deviation of request %s",
+                self.pysdk_model.env,
+                best_tuned_combination[0],
+                best_tuned_combination[3],
+                best_tuned_combination[4],
+                best_tuned_combination[5],
+                best_tuned_combination[6],
+            )
+        else:
+            self.pysdk_model.env.update(initial_env_vars)
+            logger.debug(
+                "Failed to gather any tuning results. "
+                "Please inspect the stack trace emitted from live logging for more details. "
+                "Falling back to default model configurations: %s",
+                self.pysdk_model.env,
+            )
+
+        return self.pysdk_model
+
+    @_capture_telemetry("djl_jumpstart.tune")
+    def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
+        """Tune for Jumpstart Models with DJL DLC"""
+        return self._tune_for_js(sharded_supported=True, max_tuning_duration=max_tuning_duration)
+
+    @_capture_telemetry("tgi_jumpstart.tune")
+    def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
+        """Tune for Jumpstart Models with TGI DLC"""
+        sharded_supported = _sharded_supported(self.model, self.js_model_config)
+        return self._tune_for_js(
+            sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
+        )
+
     def _build_for_jumpstart(self):
         """Placeholder docstring"""
         # we do not pickle for jumpstart. set to none
@@ -254,6 +451,8 @@ def _build_for_jumpstart(self):
             self.image_uri = self.pysdk_model.image_uri
 
             self._build_for_djl_jumpstart()
+
+            self.pysdk_model.tune = self.tune_for_djl_jumpstart
         elif "tgi-inference" in image_uri:
             logger.info("Building for TGI JumpStart Model ID...")
             self.model_server = ModelServer.TGI
@@ -262,6 +461,8 @@ def _build_for_jumpstart(self):
             self.image_uri = self.pysdk_model.image_uri
 
             self._build_for_tgi_jumpstart()
+
+            self.pysdk_model.tune = self.tune_for_tgi_jumpstart
         else:
             raise ValueError(
                 "JumpStart Model ID was not packaged with djl-inference or tgi-inference container."

src/sagemaker/serve/model_server/tgi/prepare.py

Lines changed: 11 additions & 3 deletions
@@ -13,6 +13,8 @@
 """Prepare TgiModel for Deployment"""
 
 from __future__ import absolute_import
+
+import json
 import tarfile
 import logging
 from typing import List
@@ -32,7 +34,7 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
     custom_extractall_tarfile(resources, code_dir)
 
 
-def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:
+def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> tuple:
     """Copy the associated JumpStart Resource into the code directory"""
     logger.info("Downloading JumpStart artifacts from S3...")
 
@@ -56,7 +58,13 @@ def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:
     else:
         raise ValueError("JumpStart model data compression format is unsupported: %s", model_data)
 
-    return True
+    config_json_file = code_dir.joinpath("config.json")
+    hf_model_config = None
+    if config_json_file.is_file():
+        with open(str(config_json_file)) as config_json:
+            hf_model_config = json.load(config_json)
+
+    return (hf_model_config, True)
 
 
 def _create_dir_structure(model_path: str) -> tuple:
@@ -82,7 +90,7 @@ def prepare_tgi_js_resources(
     shared_libs: List[str] = None,
     dependencies: str = None,
     model_data: str = None,
-) -> bool:
+) -> tuple:
     """Prepare serving when a JumpStart model id is given
 
     Args:
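
Callers of prepare_tgi_js_resources now unpack a pair instead of a bare boolean. A hedged illustration of the new shape follows; the path and model id are invented for the example.

# The helper now returns (hf_model_config, prepared) rather than just True.
js_model_config, prepared = prepare_tgi_js_resources(
    model_path="/tmp/js-model",              # hypothetical local path
    js_id="huggingface-llm-falcon-7b-bf16",  # hypothetical JumpStart id
)
# js_model_config is the parsed config.json shipped with the artifacts, or
# None when no config.json is present, e.g.:
# {"model_type": "falcon", "alibi": False, ...}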

src/sagemaker/serve/utils/tuning.py

Lines changed: 68 additions & 0 deletions
@@ -1,5 +1,6 @@
 """Holds mixin logic to support deployment of Model ID"""
 from __future__ import absolute_import
+
 import logging
 from time import perf_counter
 import collections
@@ -98,6 +99,52 @@ def _pretty_print_results_tgi(results: dict):
     )
 
 
+def _pretty_print_results_jumpstart(results: dict, model_env_vars=None):
+    """Pretty prints benchmark results"""
+    if model_env_vars is None:
+        model_env_vars = []
+
+    __env_var_data = {}
+    for model_env_var in model_env_vars:
+        __env_var_data[model_env_var] = []
+
+    avg_latencies = []
+    p90s = []
+    avg_tokens_per_seconds = []
+    throughput_per_seconds = []
+    standard_deviations = []
+    ordered = collections.OrderedDict(sorted(results.items()))
+
+    for key, value in ordered.items():
+        avg_latencies.append(key)
+        p90s.append(value[1])
+        avg_tokens_per_seconds.append(value[2])
+        throughput_per_seconds.append(value[3])
+        standard_deviations.append(value[4])
+
+        for model_env_var in __env_var_data:
+            __env_var_data[model_env_var].append(value[0][model_env_var])
+
+    df = pd.DataFrame(
+        {
+            "AverageLatency (Serial)": avg_latencies,
+            "P90_Latency (Serial)": p90s,
+            "AverageTokensPerSecond (Serial)": avg_tokens_per_seconds,
+            "ThroughputPerSecond (Concurrent)": throughput_per_seconds,
+            "StandardDeviationResponse (Concurrent)": standard_deviations,
+            **__env_var_data,
+        }
+    )
+
+    logger.info(
+        "\n================================================================== Benchmark "
+        "Results ==================================================================\n%s"
+        "\n============================================================================"
+        "===========================================================================\n",
+        df.to_string(),
+    )
+
+
 def _tokens_per_second(generated_text: str, max_token_length: int, latency: float) -> int:
     """Placeholder docstring"""
     est_tokens = (_tokens_from_chars(generated_text) + _tokens_from_words(generated_text)) / 2
@@ -216,3 +263,24 @@ def _more_performant(best_tuned_configuration: list, tuned_configuration: list)
             return True
         return False
     return tuned_avg_latency <= best_avg_latency
+
+
+def _sharded_supported(model_id: str, config_dict: dict) -> bool:
+    """Check if sharding is supported for this ``Model``"""
+    model_type = config_dict.get("model_type", None)
+
+    if model_type is None:
+        return False
+
+    if model_id.startswith("facebook/galactica"):
+        return True
+
+    if model_type in ["bloom", "mpt", "ssm", "gpt_neox", "phi", "phi-msft", "opt", "t5"]:
+        return True
+
+    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"] and not config_dict.get(
+        "alibi", False
+    ):
+        return True
+
+    return False
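
To make the two new helpers concrete, here is a hedged sketch of how they behave. The model ids and benchmark numbers are made up for illustration.

from sagemaker.serve.utils.tuning import (
    _pretty_print_results_jumpstart,
    _sharded_supported,
)

# Falcon-family models shard across GPUs only when ALiBi is disabled.
assert _sharded_supported("tiiuae/falcon-7b", {"model_type": "falcon", "alibi": False})
assert not _sharded_supported("tiiuae/falcon-7b", {"model_type": "falcon", "alibi": True})
assert not _sharded_supported("some/model", {})  # missing model_type -> not shardable

# _pretty_print_results_jumpstart consumes the dict built in _tune_for_js:
# {avg_latency: [tested_env, p90, tokens/s, throughput/s, stddev], ...}
results = {
    0.42: [{"SM_NUM_GPUS": "1"}, 0.61, 38.2, 2.1, 0.05],
    0.31: [{"SM_NUM_GPUS": "2"}, 0.44, 52.7, 3.4, 0.04],
}
_pretty_print_results_jumpstart(results, ["SM_NUM_GPUS"])  # logs a DataFrame table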

tests/integ/sagemaker/serve/constants.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 SERVE_IN_PROCESS_TIMEOUT = 5
 SERVE_MODEL_PACKAGE_TIMEOUT = 10
 SERVE_LOCAL_CONTAINER_TIMEOUT = 10
+SERVE_LOCAL_CONTAINER_TUNE_TIMEOUT = 15
 SERVE_SAGEMAKER_ENDPOINT_TIMEOUT = 15
 SERVE_SAVE_TIMEOUT = 2
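
The new constant presumably bounds the local tuning integ tests. A hypothetical usage follows, assuming the repo's existing tests.integ.timeout helper; the test name and model_builder fixture are assumptions, not part of this commit.

from tests.integ.sagemaker.serve.constants import SERVE_LOCAL_CONTAINER_TUNE_TIMEOUT
from tests.integ.timeout import timeout  # existing integ-test context manager

def test_tune_jumpstart_local_container(model_builder):  # hypothetical test
    with timeout(minutes=SERVE_LOCAL_CONTAINER_TUNE_TIMEOUT):
        model = model_builder.build()
        tuned = model.tune()
        assert tuned.env  # tuned env vars were applied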
