
Commit aa3a351: Refactoring

Author: Jonathan Makunga
Parent: 8227a4c

File tree: 2 files changed (+33 lines, -7 lines)


src/sagemaker/serve/builder/jumpstart_builder.py
Lines changed: 9 additions & 4 deletions
@@ -268,10 +268,6 @@ def _tune_for_js(self, max_tuning_duration: int = 1800):
             )
             return self.pysdk_model
 
-        if not sharded_supported(self.model, self.js_model_config):
-            logger.warning("Sharded is not supported for this model. Returning original model.")
-            return self.pysdk_model
-
         num_shard_env_var_name = "SM_NUM_GPUS"
         if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
             num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE"
@@ -281,6 +277,15 @@ def _tune_for_js(self, max_tuning_duration: int = 1800):
             self.js_model_config
         )
 
+        if len(admissible_tensor_parallel_degrees) > 1 and not sharded_supported(
+            self.model, self.js_model_config
+        ):
+            admissible_tensor_parallel_degrees = [1]
+            logger.warning(
+                "Sharded across multiple GPUs is not supported for this model."
+                "\nModel can only be sharded across [1] GPU"
+            )
+
         benchmark_results = {}
         best_tuned_combination = None
         timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
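In short, the check moved: instead of returning the unmodified pysdk_model as soon as sharded_supported fails, tuning now proceeds with the admissible tensor-parallel degrees capped at [1]. A minimal self-contained sketch of that gating logic (the function name cap_tensor_parallel_degrees is hypothetical, and sharded_supported is passed in as a plain callable rather than imported from the SDK):

import logging
from typing import Callable, List

logger = logging.getLogger(__name__)


def cap_tensor_parallel_degrees(
    admissible_degrees: List[int],
    model: object,
    js_model_config: dict,
    sharded_supported: Callable[[object, dict], bool],
) -> List[int]:
    """Cap tuning to a single GPU when the model cannot be sharded,
    instead of aborting the tuning run altogether (sketch, not SDK code)."""
    if len(admissible_degrees) > 1 and not sharded_supported(model, js_model_config):
        logger.warning(
            "Sharded across multiple GPUs is not supported for this model."
            "\nModel can only be sharded across [1] GPU"
        )
        return [1]
    return admissible_degrees


# e.g. with a model whose config rules out sharding:
# cap_tensor_parallel_degrees([4, 2, 1], model, cfg, lambda m, c: False)  -> [1]
# cap_tensor_parallel_degrees([4, 2, 1], model, cfg, lambda m, c: True)   -> [4, 2, 1]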

tests/unit/sagemaker/serve/builder/test_js_builder.py
Lines changed: 24 additions & 3 deletions
@@ -484,16 +484,31 @@ def test_tune_for_djl_js_local_container_invoke_ex(
         "sagemaker.serve.builder.jumpstart_builder.prepare_djl_js_resources",
         return_value=(
             mock_set_serving_properties,
-            {"model_type": "sharded_not_supported", "n_head": 71},
+            {"model_type": "sharded_not_enabled", "n_head": 71},
             True,
         ),
     )
     @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
     @patch(
         "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
     )
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees",
+        return_value=[4, 2, 1],
+    )
+    @patch(
+        "sagemaker.serve.utils.tuning._serial_benchmark",
+        side_effect=[(5, 5, 25), (5.4, 5.4, 20), (5.2, 5.2, 15)],
+    )
+    @patch(
+        "sagemaker.serve.utils.tuning._concurrent_benchmark",
+        side_effect=[(0.9, 1), (0.10, 4), (0.13, 2)],
+    )
     def test_tune_for_djl_js_local_container_sharded_not_enabled(
         self,
+        mock_concurrent_benchmarks,
+        mock_serial_benchmarks,
+        mock_admissible_tensor_parallel_degrees,
         mock_get_nb_instance,
         mock_get_ram_usage_mb,
         mock_prepare_for_tgi,
@@ -502,7 +517,9 @@ def test_tune_for_djl_js_local_container_sharded_not_enabled(
         mock_telemetry,
     ):
         builder = ModelBuilder(
-            model=mock_model_id, schema_builder=mock_schema_builder, mode=Mode.LOCAL_CONTAINER
+            model=mock_model_id,
+            schema_builder=mock_schema_builder,
+            mode=Mode.LOCAL_CONTAINER,
         )
 
         mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri
@@ -513,4 +530,8 @@ def test_tune_for_djl_js_local_container_sharded_not_enabled(
         mock_pre_trained_model.return_value.env = mock_djl_model_serving_properties
 
         tuned_model = model.tune()
-        assert tuned_model.env == mock_djl_model_serving_properties
+        assert tuned_model.env == {
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+            "OPTION_TENSOR_PARALLEL_DEGREE": "1",
+        }
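A note on the mock wiring: with unittest.mock.patch, a list passed as side_effect yields one return value per call, in order, so the three _serial_benchmark/_concurrent_benchmark tuples pair up with the three patched degrees [4, 2, 1] (entries are consumed only as calls occur, so extra entries are harmless if the capped degree list shortens the loop). A minimal illustration of that mechanism, using the test's fixture values (the tuple contents are just those fixtures, not documented SDK return shapes):

from unittest.mock import MagicMock

# side_effect as a list: each successive call returns the next element.
serial_benchmark = MagicMock(side_effect=[(5, 5, 25), (5.4, 5.4, 20), (5.2, 5.2, 15)])

for degree in [4, 2, 1]:  # the patched admissible tensor-parallel degrees
    print(degree, serial_benchmark())
# 4 (5, 5, 25)
# 2 (5.4, 5.4, 20)
# 1 (5.2, 5.2, 15)

The final assertion then pins OPTION_TENSOR_PARALLEL_DEGREE to "1", consistent with the builder change above capping the admissible degrees at [1] when sharding is not supported.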
