Skip to content

Commit c615842

Browse files
author
Jonathan Makunga
committed
Refactoring
1 parent 3fe54c0 commit c615842

File tree

3 files changed

+0
-92
lines changed

3 files changed

+0
-92
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
 from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
 from sagemaker.serve.utils.tuning import (
     _pretty_print_benchmark_results,
-    sharded_supported,
     _serial_benchmark,
     _concurrent_benchmark,
     _more_performant,
@@ -278,15 +277,6 @@ def _tune_for_js(self, max_tuning_duration: int = 1800):
             self.js_model_config
         )
 
-        if len(admissible_tensor_parallel_degrees) > 1 and not sharded_supported(
-            self.model, self.js_model_config
-        ):
-            admissible_tensor_parallel_degrees = [1]
-            logger.warning(
-                "Sharded across multiple GPUs is not supported for this model. "
-                "Model can only be sharded across [1] GPU"
-            )
-
         benchmark_results = {}
         best_tuned_combination = None
         timeout = datetime.now() + timedelta(seconds=max_tuning_duration)

src/sagemaker/serve/utils/tuning.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -263,24 +263,3 @@ def _more_performant(best_tuned_configuration: list, tuned_configuration: list)
             return True
         return False
     return tuned_avg_latency <= best_avg_latency
-
-
-def sharded_supported(model_id: str, config_dict: dict) -> bool:
-    """Check if sharded is supported for this ``Model``"""
-    model_type = config_dict.get("model_type", None)
-
-    if model_type is None:
-        return False
-
-    if model_id.startswith("facebook/galactica"):
-        return True
-
-    if model_type in ["bloom", "mpt", "ssm", "gpt_neox", "phi", "phi-msft", "opt", "t5"]:
-        return True
-
-    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"] and not config_dict.get(
-        "alibi", False
-    ):
-        return True
-
-    return False

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -470,64 +470,3 @@ def test_tune_for_djl_js_local_container_invoke_ex(
 
         tuned_model = model.tune()
         assert tuned_model.env == mock_djl_model_serving_properties
-
-    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
-    @patch(
-        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
-        return_value=True,
-    )
-    @patch(
-        "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
-        return_value=MagicMock(),
-    )
-    @patch(
-        "sagemaker.serve.builder.jumpstart_builder.prepare_djl_js_resources",
-        return_value=(
-            mock_set_serving_properties,
-            {"model_type": "sharded_not_enabled", "n_head": 71},
-            True,
-        ),
-    )
-    @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
-    @patch(
-        "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
-    )
-    @patch(
-        "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees",
-        return_value=[4, 2, 1],
-    )
-    @patch(
-        "sagemaker.serve.utils.tuning._serial_benchmark",
-        side_effect=[(5, 5, 25), (5.4, 5.4, 20), (5.2, 5.2, 15)],
-    )
-    @patch(
-        "sagemaker.serve.utils.tuning._concurrent_benchmark",
-        side_effect=[(0.9, 1), (0.10, 4), (0.13, 2)],
-    )
-    def test_tune_for_djl_js_local_container_sharded_not_enabled(
-        self,
-        mock_concurrent_benchmarks,
-        mock_serial_benchmarks,
-        mock_admissible_tensor_parallel_degrees,
-        mock_get_nb_instance,
-        mock_get_ram_usage_mb,
-        mock_prepare_for_tgi,
-        mock_pre_trained_model,
-        mock_is_jumpstart_model,
-        mock_telemetry,
-    ):
-        builder = ModelBuilder(
-            model=mock_model_id,
-            schema_builder=mock_schema_builder,
-            mode=Mode.LOCAL_CONTAINER,
-        )
-
-        mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri
-
-        model = builder.build()
-        builder.serve_settings.telemetry_opt_out = True
-
-        mock_pre_trained_model.return_value.env = mock_djl_model_serving_properties
-
-        tuned_model = model.tune()
-        assert tuned_model.env == mock_djl_most_performant_model_serving_properties

0 commit comments

Comments
 (0)