
Js tune #2


Merged
merged 18 commits on Mar 25, 2024
237 changes: 234 additions & 3 deletions src/sagemaker/serve/builder/jumpstart_builder.py
@@ -13,21 +13,41 @@
"""Placeholder docstring"""
from __future__ import absolute_import

import copy
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Type
import logging

from sagemaker.model import Model
from sagemaker import model_uris
from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.exceptions import (
LocalDeepPingException,
LocalModelOutOfMemoryException,
LocalModelInvocationException,
LocalModelLoadException,
SkipTuningComboException,
)
from sagemaker.serve.utils.predictors import (
DjlLocalModePredictor,
TgiLocalModePredictor,
)
from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb
from sagemaker.serve.utils.local_hardware import (
_get_nb_instance,
_get_ram_usage_mb,
_get_available_gpus,
)
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
from sagemaker.serve.utils.tuning import (
_serial_benchmark,
_concurrent_benchmark,
_more_performant,
_pretty_print_benchmark_results,
)
from sagemaker.serve.utils.types import ModelServer
from sagemaker.base_predictor import PredictorBase
from sagemaker.jumpstart.model import JumpStartModel
@@ -134,7 +154,7 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
model_data=self.pysdk_model.model_data,
)
elif not hasattr(self, "prepared_for_tgi"):
self.prepared_for_tgi = prepare_tgi_js_resources(
(self.js_model_config, self.prepared_for_tgi) = prepare_tgi_js_resources(
model_path=self.model_path,
js_id=self.model,
dependencies=self.dependencies,
@@ -222,7 +242,7 @@ def _build_for_tgi_jumpstart(self):
env = {}
if self.mode == Mode.LOCAL_CONTAINER:
if not hasattr(self, "prepared_for_tgi"):
self.prepared_for_tgi = prepare_tgi_js_resources(
(self.js_model_config, self.prepared_for_tgi) = prepare_tgi_js_resources(
model_path=self.model_path,
js_id=self.model,
dependencies=self.dependencies,
@@ -234,6 +254,213 @@ def _build_for_tgi_jumpstart(self):

self.pysdk_model.env.update(env)

def _tune_for_js(
self, multiple_model_copies_enabled: bool = False, max_tuning_duration: int = 1800
):
"""Tune for Jumpstart Models.

Args:
            multiple_model_copies_enabled (bool): Whether serving multiple copies of the model is
                enabled by this ``DLC``. Defaults to ``False``.
            max_tuning_duration (int): The maximum timeout to deploy this ``Model`` locally.
                Defaults to ``1800``.
        Returns:
            Tuned Model.
"""
if self.mode != Mode.LOCAL_CONTAINER:
logger.warning(
"Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
)
return self.pysdk_model

num_shard_env_var_name = "SM_NUM_GPUS"
if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE"

initial_env_vars = copy.deepcopy(self.pysdk_model.env)
admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
self.js_model_config
)
available_gpus = _get_available_gpus() if multiple_model_copies_enabled else None

benchmark_results = {}
best_tuned_combination = None
timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
for tensor_parallel_degree in admissible_tensor_parallel_degrees:
if datetime.now() > timeout:
logger.info("Max tuning duration reached. Tuning stopped.")
break

sagemaker_model_server_workers = 1
if multiple_model_copies_enabled:
sagemaker_model_server_workers = int(available_gpus / tensor_parallel_degree)

self.pysdk_model.env.update(
{
num_shard_env_var_name: str(tensor_parallel_degree),
"SAGEMAKER_MODEL_SERVER_WORKERS": str(sagemaker_model_server_workers),
}
)

try:
predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration)

avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
predictor, self.schema_builder.sample_input
)
throughput_per_second, standard_deviation = _concurrent_benchmark(
predictor, self.schema_builder.sample_input
)

tested_env = self.pysdk_model.env.copy()
logger.info(
"Average latency: %s, throughput/s: %s for configuration: %s",
avg_latency,
throughput_per_second,
tested_env,
)
benchmark_results[avg_latency] = [
tested_env,
p90,
avg_tokens_per_second,
throughput_per_second,
standard_deviation,
]

if not best_tuned_combination:
best_tuned_combination = [
avg_latency,
tensor_parallel_degree,
sagemaker_model_server_workers,
p90,
avg_tokens_per_second,
throughput_per_second,
standard_deviation,
]
else:
tuned_configuration = [
avg_latency,
tensor_parallel_degree,
sagemaker_model_server_workers,
p90,
avg_tokens_per_second,
throughput_per_second,
standard_deviation,
]
if _more_performant(best_tuned_combination, tuned_configuration):
best_tuned_combination = tuned_configuration
except LocalDeepPingException as e:
logger.warning(
"Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s"
"Failed to invoke the model server: %s",
num_shard_env_var_name,
tensor_parallel_degree,
sagemaker_model_server_workers,
str(e),
)
break
except LocalModelOutOfMemoryException as e:
logger.warning(
"Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
"Out of memory when loading the model: %s",
num_shard_env_var_name,
tensor_parallel_degree,
sagemaker_model_server_workers,
str(e),
)
break
except LocalModelInvocationException as e:
logger.warning(
"Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
"Failed to invoke the model server: %s"
"Please check that model server configurations are as expected "
"(Ex. serialization, deserialization, content_type, accept).",
num_shard_env_var_name,
tensor_parallel_degree,
sagemaker_model_server_workers,
str(e),
)
break
except LocalModelLoadException as e:
logger.warning(
"Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
"Failed to load the model: %s.",
num_shard_env_var_name,
tensor_parallel_degree,
sagemaker_model_server_workers,
str(e),
)
break
except SkipTuningComboException as e:
logger.warning(
"Deployment with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
"was expected to be successful. However failed with: %s. "
"Trying next combination.",
num_shard_env_var_name,
tensor_parallel_degree,
sagemaker_model_server_workers,
str(e),
)
                continue
except Exception: # pylint: disable=W0703
logger.exception(
"Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
"with uncovered exception",
num_shard_env_var_name,
tensor_parallel_degree,
sagemaker_model_server_workers,
)
break

if best_tuned_combination:
self.pysdk_model.env.update(
{
num_shard_env_var_name: str(best_tuned_combination[1]),
"SAGEMAKER_MODEL_SERVER_WORKERS": str(best_tuned_combination[2]),
}
)

_pretty_print_benchmark_results(
benchmark_results, [num_shard_env_var_name, "SAGEMAKER_MODEL_SERVER_WORKERS"]
)
logger.info(
"Model Configuration: %s was most performant with avg latency: %s, "
"p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
"standard deviation of request %s",
self.pysdk_model.env,
best_tuned_combination[0],
best_tuned_combination[3],
best_tuned_combination[4],
best_tuned_combination[5],
best_tuned_combination[6],
)
else:
self.pysdk_model.env.update(initial_env_vars)
logger.debug(
"Failed to gather any tuning results. "
"Please inspect the stack trace emitted from live logging for more details. "
"Falling back to default model configurations: %s",
self.pysdk_model.env,
)

return self.pysdk_model

@_capture_telemetry("djl_jumpstart.tune")
def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
"""Tune for Jumpstart Models with DJL DLC"""
return self._tune_for_js(
multiple_model_copies_enabled=True, max_tuning_duration=max_tuning_duration
)

@_capture_telemetry("tgi_jumpstart.tune")
def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
"""Tune for Jumpstart Models with TGI DLC"""
return self._tune_for_js(
# Currently, TGI does not enable multiple model copies serving.
multiple_model_copies_enabled=False,
max_tuning_duration=max_tuning_duration,
)

def _build_for_jumpstart(self):
"""Placeholder docstring"""
# we do not pickle for jumpstart. set to none
@@ -254,6 +481,8 @@ def _build_for_jumpstart(self):
self.image_uri = self.pysdk_model.image_uri

self._build_for_djl_jumpstart()

self.pysdk_model.tune = self.tune_for_djl_jumpstart
elif "tgi-inference" in image_uri:
logger.info("Building for TGI JumpStart Model ID...")
self.model_server = ModelServer.TGI
@@ -262,6 +491,8 @@
self.image_uri = self.pysdk_model.image_uri

self._build_for_tgi_jumpstart()

self.pysdk_model.tune = self.tune_for_tgi_jumpstart
else:
raise ValueError(
"JumpStart Model ID was not packaged with djl-inference or tgi-inference container."
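For review context, here is a minimal usage sketch of the `tune` hook wired up in `_build_for_jumpstart` above. It is illustrative only: the JumpStart model id, the sample payloads, and any extra `ModelBuilder` arguments (for example `model_path`) are assumptions, not values taken from this change.

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode

# Hypothetical JumpStart LLM id and sample payloads -- adjust for the model under test.
schema_builder = SchemaBuilder(
    sample_input={
        "inputs": "What is the capital of France?",
        "parameters": {"max_new_tokens": 32},
    },
    sample_output=[{"generated_text": "Paris."}],
)

builder = ModelBuilder(
    model="huggingface-llm-falcon-7b-bf16",
    schema_builder=schema_builder,
    mode=Mode.LOCAL_CONTAINER,  # tuning is a LOCAL_CONTAINER-only capability
)

model = builder.build()
# build() attaches tune_for_djl_jumpstart or tune_for_tgi_jumpstart as model.tune,
# depending on whether the resolved image is djl-inference or tgi-inference.
tuned_model = model.tune(max_tuning_duration=1800)

Once tuning completes, `tuned_model.env` carries the most performant tensor parallel degree and `SAGEMAKER_MODEL_SERVER_WORKERS` combination found within the timeout, or the original configuration if no combination succeeded.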
14 changes: 11 additions & 3 deletions src/sagemaker/serve/model_server/tgi/prepare.py
@@ -13,6 +13,8 @@
"""Prepare TgiModel for Deployment"""

from __future__ import absolute_import

import json
import tarfile
import logging
from typing import List
@@ -32,7 +34,7 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
custom_extractall_tarfile(resources, code_dir)


def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:
def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> tuple:
"""Copy the associated JumpStart Resource into the code directory"""
logger.info("Downloading JumpStart artifacts from S3...")

@@ -56,7 +58,13 @@ def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:
else:
        raise ValueError(
            f"JumpStart model data compression format is unsupported: {model_data}"
        )

return True
config_json_file = code_dir.joinpath("config.json")
hf_model_config = None
if config_json_file.is_file():
with open(str(config_json_file)) as config_json:
hf_model_config = json.load(config_json)

return (hf_model_config, True)


def _create_dir_structure(model_path: str) -> tuple:
@@ -82,7 +90,7 @@ def prepare_tgi_js_resources(
shared_libs: List[str] = None,
dependencies: str = None,
model_data: str = None,
) -> bool:
) -> tuple:
"""Prepare serving when a JumpStart model id is given

Args:
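The Hugging Face `config.json` now returned alongside the prepared resources is what the builder hands to `_get_admissible_tensor_parallel_degrees` during tuning. As a rough, hedged sketch of that relationship (the real heuristic lives in `djl_serving/utils` and may differ), candidate degrees are divisors of the model's attention head count that fit on the host:

import json
from pathlib import Path
from typing import List

def sketch_admissible_tp_degrees(code_dir: Path, available_gpus: int) -> List[int]:
    """Illustrative only: derive candidate tensor parallel degrees from the
    Hugging Face config.json shipped with the JumpStart artifacts."""
    config = json.loads(code_dir.joinpath("config.json").read_text())
    num_heads = config.get("num_attention_heads", 1)
    # A degree is admissible if it divides the attention head count and does not
    # exceed the number of GPUs available on the instance.
    return [d for d in range(available_gpus, 0, -1) if num_heads % d == 0]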
45 changes: 45 additions & 0 deletions src/sagemaker/serve/utils/tuning.py
@@ -98,6 +98,51 @@ def _pretty_print_results_tgi(results: dict):
)


def _pretty_print_benchmark_results(results: dict, model_env_vars=None):
"""Placeholder docstring"""
if model_env_vars is None:
model_env_vars = []

__env_var_data = {}
for e in model_env_vars:
__env_var_data[e] = []

avg_latencies = []
p90s = []
avg_tokens_per_seconds = []
throughput_per_seconds = []
standard_deviations = []
ordered = collections.OrderedDict(sorted(results.items()))

for key, value in ordered.items():
avg_latencies.append(key)
p90s.append(value[1])
avg_tokens_per_seconds.append(value[2])
throughput_per_seconds.append(value[3])
standard_deviations.append(value[4])

for k in __env_var_data:
__env_var_data[k].append(value[0][k])

df = pd.DataFrame(
{
"AverageLatency (Serial)": avg_latencies,
"P90_Latency (Serial)": p90s,
"AverageTokensPerSecond (Serial)": avg_tokens_per_seconds,
"ThroughputPerSecond (Concurrent)": throughput_per_seconds,
"StandardDeviationResponse (Concurrent)": standard_deviations,
**__env_var_data,
}
)
logger.info(
"\n================================================================== Benchmark "
"Results ==================================================================\n%s"
"\n============================================================================"
"===========================================================================\n",
df.to_string(),
)


def _tokens_per_second(generated_text: str, max_token_length: int, latency: float) -> int:
"""Placeholder docstring"""
est_tokens = (_tokens_from_chars(generated_text) + _tokens_from_words(generated_text)) / 2
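To make the expected input concrete, a short example of calling the new pretty-printer. The latency and throughput numbers are invented, but the value layout matches how `_tune_for_js` populates `benchmark_results`:

from sagemaker.serve.utils.tuning import _pretty_print_benchmark_results

# Keys are average serial latencies; values are
# [tested_env, p90_latency, avg_tokens_per_second, throughput_per_second, standard_deviation].
benchmark_results = {
    0.42: [{"SM_NUM_GPUS": "1", "SAGEMAKER_MODEL_SERVER_WORKERS": "4"}, 0.55, 38.0, 9.1, 0.03],
    0.31: [{"SM_NUM_GPUS": "2", "SAGEMAKER_MODEL_SERVER_WORKERS": "2"}, 0.40, 52.0, 11.7, 0.02],
}

_pretty_print_benchmark_results(
    benchmark_results, ["SM_NUM_GPUS", "SAGEMAKER_MODEL_SERVER_WORKERS"]
)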
1 change: 1 addition & 0 deletions tests/integ/sagemaker/serve/constants.py
@@ -20,6 +20,7 @@
SERVE_IN_PROCESS_TIMEOUT = 5
SERVE_MODEL_PACKAGE_TIMEOUT = 10
SERVE_LOCAL_CONTAINER_TIMEOUT = 10
SERVE_LOCAL_CONTAINER_TUNE_TIMEOUT = 15
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT = 15
SERVE_SAVE_TIMEOUT = 2
