|
42 | 42 | )
|
43 | 43 | from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
|
44 | 44 | from sagemaker.serve.utils.tuning import (
|
45 |
| - _more_performant_benchmark, |
46 | 45 | _pretty_print_benchmark_results,
|
47 |
| - _run_serial_and_concurrent_benchmarks, |
48 | 46 | sharded_supported,
|
| 47 | + _serial_benchmark, |
| 48 | + _concurrent_benchmark, |
| 49 | + _more_performant, |
49 | 50 | )
|
50 | 51 | from sagemaker.serve.utils.types import ModelServer
|
51 | 52 | from sagemaker.base_predictor import PredictorBase
|
@@ -298,19 +299,52 @@ def _tune_for_js(self, max_tuning_duration: int = 1800):
|
298 | 299 | try:
|
299 | 300 | logger.info("Trying tensor parallel degree: %s", tensor_parallel_degree)
|
300 | 301 |
|
301 |
| - result = _run_serial_and_concurrent_benchmarks( |
302 |
| - self.pysdk_model, self.schema_builder.sample_input, max_tuning_duration |
| 302 | + predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration) |
| 303 | + |
| 304 | + avg_latency, p90, avg_tokens_per_second = _serial_benchmark( |
| 305 | + predictor, self.schema_builder.sample_input |
| 306 | + ) |
| 307 | + throughput_per_second, standard_deviation = _concurrent_benchmark( |
| 308 | + predictor, self.schema_builder.sample_input |
303 | 309 | )
|
304 |
| - benchmark_results[result["AVG_LATENCY"]] = { |
305 |
| - "TESTED_ENV": result["TESTED_ENV"], |
306 |
| - "P90": result["P90"], |
307 |
| - "AVG_TOKENS_PER_SECOND": result["AVG_TOKENS_PER_SECOND"], |
308 |
| - "THROUGHPUT_PER_SECOND": result["THROUGHPUT_PER_SECOND"], |
309 |
| - "STD_DEVIATION": result["STD_DEVIATION"], |
310 |
| - } |
311 |
| - |
312 |
| - result[num_shard_env_var_name] = tensor_parallel_degree |
313 |
| - best_tuned_combination = _more_performant_benchmark(best_tuned_combination, result) |
| 310 | + |
| 311 | + tested_env = copy.deepcopy(self.pysdk_model.env) |
| 312 | + logger.info( |
| 313 | + "Average latency: %s, throughput/s: %s for configuration: %s", |
| 314 | + avg_latency, |
| 315 | + throughput_per_second, |
| 316 | + tested_env, |
| 317 | + ) |
| 318 | + benchmark_results[avg_latency] = [ |
| 319 | + tested_env, |
| 320 | + p90, |
| 321 | + avg_tokens_per_second, |
| 322 | + throughput_per_second, |
| 323 | + standard_deviation, |
| 324 | + ] |
| 325 | + |
| 326 | + if not best_tuned_combination: |
| 327 | + best_tuned_combination = [ |
| 328 | + avg_latency, |
| 329 | + tensor_parallel_degree, |
| 330 | + None, |
| 331 | + p90, |
| 332 | + avg_tokens_per_second, |
| 333 | + throughput_per_second, |
| 334 | + standard_deviation, |
| 335 | + ] |
| 336 | + else: |
| 337 | + tuned_configuration = [ |
| 338 | + avg_latency, |
| 339 | + tensor_parallel_degree, |
| 340 | + None, |
| 341 | + p90, |
| 342 | + avg_tokens_per_second, |
| 343 | + throughput_per_second, |
| 344 | + standard_deviation, |
| 345 | + ] |
| 346 | + if _more_performant(best_tuned_combination, tuned_configuration): |
| 347 | + best_tuned_combination = tuned_configuration |
314 | 348 | except LocalDeepPingException as e:
|
315 | 349 | logger.warning(
|
316 | 350 | "Deployment unsuccessful with %s: %s. " "Failed to invoke the model server: %s",
|
@@ -360,21 +394,19 @@ def _tune_for_js(self, max_tuning_duration: int = 1800):
|
360 | 394 | )
|
361 | 395 |
|
362 | 396 | if best_tuned_combination:
|
363 |
| - self.pysdk_model.env.update( |
364 |
| - {num_shard_env_var_name: str(best_tuned_combination[num_shard_env_var_name])} |
365 |
| - ) |
| 397 | + self.pysdk_model.env.update({num_shard_env_var_name: str(best_tuned_combination[1])}) |
366 | 398 |
|
367 | 399 | _pretty_print_benchmark_results(benchmark_results, [num_shard_env_var_name])
|
368 | 400 | logger.info(
|
369 | 401 | "Model Configuration: %s was most performant with avg latency: %s, "
|
370 | 402 | "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
|
371 | 403 | "standard deviation of request %s",
|
372 | 404 | self.pysdk_model.env,
|
373 |
| - best_tuned_combination["AVG_LATENCY"], |
374 |
| - best_tuned_combination["P90"], |
375 |
| - best_tuned_combination["AVG_TOKENS_PER_SECOND"], |
376 |
| - best_tuned_combination["THROUGHPUT_PER_SECOND"], |
377 |
| - best_tuned_combination["STD_DEVIATION"], |
| 405 | + best_tuned_combination[0], |
| 406 | + best_tuned_combination[3], |
| 407 | + best_tuned_combination[4], |
| 408 | + best_tuned_combination[5], |
| 409 | + best_tuned_combination[6], |
378 | 410 | )
|
379 | 411 | else:
|
380 | 412 | self.pysdk_model.env.update(initial_env_vars)
|
|
0 commit comments