Tune (local mode) support for Jumpstart Models #4532

Merged: 36 commits, Mar 28, 2024

Commits (36)
fc17719  JumpStart TGI tune support (Mar 18, 2024)
e4a042a  TGI does not enable multiple model serving internal to the container (Mar 19, 2024)
01eb282  Merge branch 'master' into js_tune (Mar 19, 2024)
a2be460  Tune support for JS model with DJL DLC (Mar 19, 2024)
332b2ec  Pretty Print JS TGI Benchmark results (Mar 20, 2024)
4d06e66  Pretty Print JS DJL Benchmark results (Mar 20, 2024)
1ec29e3  Debugging (Mar 20, 2024)
4be1679  Refactoring (Mar 21, 2024)
62d00b3  Revert "Refactoring" (Mar 21, 2024)
2952dd3  refactoring (Mar 22, 2024)
b0c0d16  Revert "Debugging" (Mar 22, 2024)
94a2bbf  Add Tests (Mar 23, 2024)
6b7f5c1  Refactoring (Mar 23, 2024)
223b73a  Refactoring (Mar 23, 2024)
b96e78f  Debugging (Mar 23, 2024)
e08e1c0  Refactoring and Debugging (Mar 23, 2024)
15b2890  Refactoring (Mar 25, 2024)
c8289ff  Merge branch 'master' into js_tune (Mar 25, 2024)
d0ea680  Merge pull request #2 from makungaj1/js_tune (makungaj1, Mar 25, 2024)
1b4c864  Refactoring (Mar 25, 2024)
4e9c32b  Merge branch 'master' into master (makungaj1, Mar 26, 2024)
1d84fb8  Refactoring (Mar 26, 2024)
ad6cb03  Add unit tests (Mar 26, 2024)
bd80a67  Address PR Review Comments (Mar 26, 2024)
7d40674  Sharded is not supported for all model types (Mar 27, 2024)
60c4311  Merge branch 'master' into master (makungaj1, Mar 27, 2024)
8227a4c  Refactoring (Mar 27, 2024)
aa3a351  Refactoring (Mar 27, 2024)
266fe2b  Refactoring (Mar 27, 2024)
9b18684  Refactoring (Mar 27, 2024)
3fe54c0  Merge branch 'master' into master (makungaj1, Mar 27, 2024)
c615842  Refactoring (Mar 28, 2024)
e63de84  Addressed PR Review comments (Mar 28, 2024)
da2c0a2  Sharding support validation for TGI (Mar 28, 2024)
921f7d4  Fix unit tests (Mar 28, 2024)
d7c85e7  Increase code coverage (Mar 28, 2024)

207 changes: 204 additions & 3 deletions src/sagemaker/serve/builder/jumpstart_builder.py
@@ -13,21 +13,41 @@
"""Placeholder docstring"""
from __future__ import absolute_import

import copy
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Type
import logging

from sagemaker.model import Model
from sagemaker import model_uris
from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.exceptions import (
LocalDeepPingException,
LocalModelOutOfMemoryException,
LocalModelInvocationException,
LocalModelLoadException,
SkipTuningComboException,
)
from sagemaker.serve.utils.predictors import (
DjlLocalModePredictor,
TgiLocalModePredictor,
)
from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb
from sagemaker.serve.utils.local_hardware import (
_get_nb_instance,
_get_ram_usage_mb,
)
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
from sagemaker.serve.utils.tuning import (
_pretty_print_results_jumpstart,
_serial_benchmark,
_concurrent_benchmark,
_more_performant,
_sharded_supported,
)
from sagemaker.serve.utils.types import ModelServer
from sagemaker.base_predictor import PredictorBase
from sagemaker.jumpstart.model import JumpStartModel
@@ -134,7 +154,7 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
model_data=self.pysdk_model.model_data,
)
elif not hasattr(self, "prepared_for_tgi"):
self.prepared_for_tgi = prepare_tgi_js_resources(
self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources(
model_path=self.model_path,
js_id=self.model,
dependencies=self.dependencies,
@@ -222,7 +242,7 @@ def _build_for_tgi_jumpstart(self):
env = {}
if self.mode == Mode.LOCAL_CONTAINER:
if not hasattr(self, "prepared_for_tgi"):
self.prepared_for_tgi = prepare_tgi_js_resources(
self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources(
model_path=self.model_path,
js_id=self.model,
dependencies=self.dependencies,
@@ -234,6 +254,183 @@

self.pysdk_model.env.update(env)

def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800):
"""Tune for Jumpstart Models in Local Mode.

Args:
sharded_supported (bool): Indicates whether sharding is supported by this ``Model``
max_tuning_duration (int): The maximum timeout to deploy this ``Model`` locally.
Default: ``1800``
returns:
Tuned Model.
"""
if self.mode != Mode.LOCAL_CONTAINER:
logger.warning(
"Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
)
return self.pysdk_model

num_shard_env_var_name = "SM_NUM_GPUS"
if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE"

initial_env_vars = copy.deepcopy(self.pysdk_model.env)
admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
self.js_model_config
)

if len(admissible_tensor_parallel_degrees) > 1 and not sharded_supported:
admissible_tensor_parallel_degrees = [1]
logger.warning(
"Sharding across multiple GPUs is not supported for this model. "
"Model can only be sharded across [1] GPU"
)

benchmark_results = {}
best_tuned_combination = None
timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
for tensor_parallel_degree in admissible_tensor_parallel_degrees:
if datetime.now() > timeout:
logger.info("Max tuning duration reached. Tuning stopped.")
break

self.pysdk_model.env.update({num_shard_env_var_name: str(tensor_parallel_degree)})
try:
logger.info("Trying tensor parallel degree: %s", tensor_parallel_degree)

predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration)

avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
predictor, self.schema_builder.sample_input
)
throughput_per_second, standard_deviation = _concurrent_benchmark(
predictor, self.schema_builder.sample_input
)

tested_env = copy.deepcopy(self.pysdk_model.env)
logger.info(
"Average latency: %s, throughput/s: %s for configuration: %s",
avg_latency,
throughput_per_second,
tested_env,
)
benchmark_results[avg_latency] = [
tested_env,
p90,
avg_tokens_per_second,
throughput_per_second,
standard_deviation,
]

if not best_tuned_combination:
best_tuned_combination = [
avg_latency,
tensor_parallel_degree,
None,
p90,
avg_tokens_per_second,
throughput_per_second,
standard_deviation,
]
else:
tuned_configuration = [
avg_latency,
tensor_parallel_degree,
None,
p90,
avg_tokens_per_second,
throughput_per_second,
standard_deviation,
]
if _more_performant(best_tuned_combination, tuned_configuration):
best_tuned_combination = tuned_configuration
except LocalDeepPingException as e:
logger.warning(
"Deployment unsuccessful with %s: %s. " "Failed to invoke the model server: %s",
num_shard_env_var_name,
tensor_parallel_degree,
str(e),
)
except LocalModelOutOfMemoryException as e:
logger.warning(
"Deployment unsuccessful with %s: %s. "
"Out of memory when loading the model: %s",
num_shard_env_var_name,
tensor_parallel_degree,
str(e),
)
except LocalModelInvocationException as e:
logger.warning(
"Deployment unsuccessful with %s: %s. "
"Failed to invoke the model server: %s"
"Please check that model server configurations are as expected "
"(Ex. serialization, deserialization, content_type, accept).",
num_shard_env_var_name,
tensor_parallel_degree,
str(e),
)
except LocalModelLoadException as e:
logger.warning(
"Deployment unsuccessful with %s: %s. " "Failed to load the model: %s.",
num_shard_env_var_name,
tensor_parallel_degree,
str(e),
)
except SkipTuningComboException as e:
logger.warning(
"Deployment with %s: %s"
"was expected to be successful. However failed with: %s. "
"Trying next combination.",
num_shard_env_var_name,
tensor_parallel_degree,
str(e),
)
except Exception: # pylint: disable=W0703
logger.exception(
"Deployment unsuccessful with %s: %s. " "with uncovered exception",
num_shard_env_var_name,
tensor_parallel_degree,
)

if best_tuned_combination:
self.pysdk_model.env.update({num_shard_env_var_name: str(best_tuned_combination[1])})

_pretty_print_results_jumpstart(benchmark_results, [num_shard_env_var_name])
logger.info(
"Model Configuration: %s was most performant with avg latency: %s, "
"p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
"standard deviation of request %s",
self.pysdk_model.env,
best_tuned_combination[0],
best_tuned_combination[3],
best_tuned_combination[4],
best_tuned_combination[5],
best_tuned_combination[6],
)
else:
self.pysdk_model.env.update(initial_env_vars)
logger.debug(
"Failed to gather any tuning results. "
"Please inspect the stack trace emitted from live logging for more details. "
"Falling back to default model configurations: %s",
self.pysdk_model.env,
)

return self.pysdk_model

@_capture_telemetry("djl_jumpstart.tune")
def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
"""Tune for Jumpstart Models with DJL DLC"""
return self._tune_for_js(sharded_supported=True, max_tuning_duration=max_tuning_duration)

@_capture_telemetry("tgi_jumpstart.tune")
def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
"""Tune for Jumpstart Models with TGI DLC"""
sharded_supported = _sharded_supported(self.model, self.js_model_config)
return self._tune_for_js(
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
)

def _build_for_jumpstart(self):
"""Placeholder docstring"""
# we do not pickle for jumpstart. set to none
@@ -254,6 +451,8 @@ def _build_for_jumpstart(self):
self.image_uri = self.pysdk_model.image_uri

self._build_for_djl_jumpstart()

self.pysdk_model.tune = self.tune_for_djl_jumpstart
elif "tgi-inference" in image_uri:
logger.info("Building for TGI JumpStart Model ID...")
self.model_server = ModelServer.TGI
@@ -262,6 +461,8 @@
self.image_uri = self.pysdk_model.image_uri

self._build_for_tgi_jumpstart()

self.pysdk_model.tune = self.tune_for_tgi_jumpstart
else:
raise ValueError(
"JumpStart Model ID was not packaged with djl-inference or tgi-inference container."
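
For context, a minimal usage sketch of the capability this diff adds (not part of the PR itself): _build_for_jumpstart attaches tune() to the built model, and calling it benchmarks the admissible tensor parallel degrees in a local container. The JumpStart model id and sample payloads below are hypothetical.

    from sagemaker.serve.builder.model_builder import ModelBuilder
    from sagemaker.serve.builder.schema_builder import SchemaBuilder
    from sagemaker.serve.mode.function_pointers import Mode

    # Hypothetical sample request/response used to build the schema.
    sample_input = {"inputs": "What is the capital of France?", "parameters": {"max_new_tokens": 64}}
    sample_output = [{"generated_text": "The capital of France is Paris."}]

    builder = ModelBuilder(
        model="huggingface-llm-falcon-7b-bf16",  # hypothetical JumpStart model id
        schema_builder=SchemaBuilder(sample_input, sample_output),
        mode=Mode.LOCAL_CONTAINER,  # tuning is a LOCAL_CONTAINER-only capability
    )

    model = builder.build()
    # tune() maps to tune_for_djl_jumpstart or tune_for_tgi_jumpstart depending on the
    # resolved DLC, and returns the model with the most performant
    # SM_NUM_GPUS / OPTION_TENSOR_PARALLEL_DEGREE value applied to its env.
    tuned_model = model.tune(max_tuning_duration=1800)
    predictor = tuned_model.deploy()
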
14 changes: 11 additions & 3 deletions src/sagemaker/serve/model_server/tgi/prepare.py
@@ -13,6 +13,8 @@
"""Prepare TgiModel for Deployment"""

from __future__ import absolute_import

import json
import tarfile
import logging
from typing import List
@@ -32,7 +34,7 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
custom_extractall_tarfile(resources, code_dir)


def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:
def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> tuple:
"""Copy the associated JumpStart Resource into the code directory"""
logger.info("Downloading JumpStart artifacts from S3...")

@@ -56,7 +58,13 @@ def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bo
else:
raise ValueError("JumpStart model data compression format is unsupported: %s", model_data)

return True
config_json_file = code_dir.joinpath("config.json")
hf_model_config = None
if config_json_file.is_file():
with open(str(config_json_file)) as config_json:
hf_model_config = json.load(config_json)

return (hf_model_config, True)


def _create_dir_structure(model_path: str) -> tuple:
@@ -82,7 +90,7 @@ def prepare_tgi_js_resources(
shared_libs: List[str] = None,
dependencies: str = None,
model_data: str = None,
) -> bool:
) -> tuple:
"""Prepare serving when a JumpStart model id is given
Args:
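
A note on the changed contract (a sketch, not from the diff): prepare_tgi_js_resources and _copy_jumpstart_artifacts now return a tuple, so the builder can capture the model's HuggingFace config.json alongside the "prepared" flag and reuse it for the sharding check during tuning. The values below are hypothetical.

    # Hypothetical call; only keyword arguments that appear in this diff are used.
    js_model_config, prepared_for_tgi = prepare_tgi_js_resources(
        model_path="/tmp/js-model",              # hypothetical local path
        js_id="huggingface-llm-falcon-7b-bf16",  # hypothetical JumpStart model id
    )

    # js_model_config is the parsed config.json, e.g. {"model_type": "falcon", "alibi": False, ...},
    # or None when no config.json is packaged with the artifacts; prepared_for_tgi is True.
    # The builder stores it as self.js_model_config and later feeds it to
    # _get_admissible_tensor_parallel_degrees() and _sharded_supported().
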
68 changes: 68 additions & 0 deletions src/sagemaker/serve/utils/tuning.py
@@ -1,5 +1,6 @@
"""Holds mixin logic to support deployment of Model ID"""
from __future__ import absolute_import

import logging
from time import perf_counter
import collections
@@ -98,6 +99,52 @@ def _pretty_print_results_tgi(results: dict):
)


def _pretty_print_results_jumpstart(results: dict, model_env_vars=None):
"""Pretty prints benchmark results"""
if model_env_vars is None:
model_env_vars = []

__env_var_data = {}
for model_env_var in model_env_vars:
__env_var_data[model_env_var] = []

avg_latencies = []
p90s = []
avg_tokens_per_seconds = []
throughput_per_seconds = []
standard_deviations = []
ordered = collections.OrderedDict(sorted(results.items()))

for key, value in ordered.items():
avg_latencies.append(key)
p90s.append(value[1])
avg_tokens_per_seconds.append(value[2])
throughput_per_seconds.append(value[3])
standard_deviations.append(value[4])

for model_env_var in __env_var_data:
__env_var_data[model_env_var].append(value[0][model_env_var])

df = pd.DataFrame(
{
"AverageLatency (Serial)": avg_latencies,
"P90_Latency (Serial)": p90s,
"AverageTokensPerSecond (Serial)": avg_tokens_per_seconds,
"ThroughputPerSecond (Concurrent)": throughput_per_seconds,
"StandardDeviationResponse (Concurrent)": standard_deviations,
**__env_var_data,
}
)

logger.info(
"\n================================================================== Benchmark "
"Results ==================================================================\n%s"
"\n============================================================================"
"===========================================================================\n",
df.to_string(),
)


def _tokens_per_second(generated_text: str, max_token_length: int, latency: float) -> int:
"""Placeholder docstring"""
est_tokens = (_tokens_from_chars(generated_text) + _tokens_from_words(generated_text)) / 2
@@ -216,3 +263,24 @@ def _more_performant(best_tuned_configuration: list, tuned_configuration: list)
return True
return False
return tuned_avg_latency <= best_avg_latency


def _sharded_supported(model_id: str, config_dict: dict) -> bool:
"""Check if sharded is supported for this ``Model``"""
model_type = config_dict.get("model_type", None)

if model_type is None:
return False

if model_id.startswith("facebook/galactica"):
return True

if model_type in ["bloom", "mpt", "ssm", "gpt_neox", "phi", "phi-msft", "opt", "t5"]:
return True

if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"] and not config_dict.get(
"alibi", False
):
return True

return False
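
To make the gating concrete, a small behavior sketch of the helper above (the model ids and configs are hypothetical examples): when it returns False, _tune_for_js restricts the search to a single tensor parallel degree of 1.

    from sagemaker.serve.utils.tuning import _sharded_supported

    # falcon-family config without alibi: sharding across GPUs is allowed.
    assert _sharded_supported("huggingface-llm-falcon-7b-bf16", {"model_type": "falcon", "alibi": False})

    # RefinedWeb variant with alibi enabled: sharding is not allowed.
    assert not _sharded_supported("tiiuae/falcon-rw-1b", {"model_type": "RefinedWebModel", "alibi": True})

    # No model_type in the config (e.g. config.json missing): stay conservative.
    assert not _sharded_supported("some-model-id", {})
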
1 change: 1 addition & 0 deletions tests/integ/sagemaker/serve/constants.py
@@ -20,6 +20,7 @@
SERVE_IN_PROCESS_TIMEOUT = 5
SERVE_MODEL_PACKAGE_TIMEOUT = 10
SERVE_LOCAL_CONTAINER_TIMEOUT = 10
SERVE_LOCAL_CONTAINER_TUNE_TIMEOUT = 15
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT = 15
SERVE_SAVE_TIMEOUT = 2
