Commit e57c850

Updating Inference Optimization Validations (#4971)
* Updating Inference Optimization Validations
* Linting
1 parent f34b41b

File tree

2 files changed: 8 additions and 8 deletions

src/sagemaker/serve/builder/model_builder.py
tests/unit/sagemaker/serve/builder/test_model_builder.py

src/sagemaker/serve/builder/model_builder.py

Lines changed: 5 additions & 5 deletions
@@ -1433,15 +1433,15 @@ def _model_builder_optimize_wrapper(
 
         # HF Model ID format = "meta-llama/Meta-Llama-3.1-8B"
         # JS Model ID format = "meta-textgeneration-llama-3-1-8b"
-        llama_3_1_keywords = ["llama-3.1", "llama-3-1"]
-        is_llama_3_1 = self.model and any(
-            keyword in self.model.lower() for keyword in llama_3_1_keywords
+        is_llama_3_plus = self.model and bool(
+            re.search(r"llama-3[\.\-][1-9]\d*", self.model.lower())
         )
 
         if is_gpu_instance and self.model and self.is_compiled:
-            if is_llama_3_1:
+            if is_llama_3_plus:
                 raise ValueError(
-                    "Compilation is not supported for Llama-3.1 with a GPU instance."
+                    "Compilation is not supported for models greater "
+                    "than Llama-3.0 with a GPU instance."
                 )
             if speculative_decoding_config:
                 raise ValueError(
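For context, a minimal standalone sketch (not part of the commit) of how the new pattern behaves. The regex and the two ID formats in the comments come from the diff above; the helper name and the Llama-2 / "3.0" example IDs are illustrative assumptions.

import re

# Sketch only: mirrors the check added above. The pattern accepts both the
# HF spelling ("llama-3.1") and the JumpStart spelling ("llama-3-1"), and
# requires the minor version to start with 1-9, so a literal "3.0" does not match.
LLAMA_3_PLUS = re.compile(r"llama-3[\.\-][1-9]\d*")

def looks_like_llama_3_plus(model_id: str) -> bool:
    # Case-insensitive search, matching the diff's use of self.model.lower().
    return bool(LLAMA_3_PLUS.search(model_id.lower()))

print(looks_like_llama_3_plus("meta-llama/Meta-Llama-3.1-8B"))           # True  (HF format)
print(looks_like_llama_3_plus("meta-textgeneration-llama-3-1-8b"))       # True  (JumpStart format)
print(looks_like_llama_3_plus("meta-llama/Meta-Llama-3-2-8B-Instruct"))  # True  (ID used in the updated test)
print(looks_like_llama_3_plus("meta-llama/Llama-2-7b-hf"))               # False (Llama 2)
print(looks_like_llama_3_plus("example-llama-3.0"))                      # False (minor version 0 is excluded)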

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 3 additions & 3 deletions
@@ -3270,7 +3270,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation(
 
         mock_pysdk_model = Mock()
         mock_pysdk_model.model_data = None
-        mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-1-8B-Instruct"}
+        mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-2-8B-Instruct"}
 
         sample_input = {"inputs": "dummy prompt", "parameters": {}}
 
@@ -3279,7 +3279,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation(
         dummy_schema_builder = SchemaBuilder(sample_input, sample_output)
 
         model_builder = ModelBuilder(
-            model="meta-llama/Meta-Llama-3-1-8B-Instruct",
+            model="meta-llama/Meta-Llama-3-2-8B-Instruct",
             schema_builder=dummy_schema_builder,
             env_vars={"HF_TOKEN": "token"},
             model_metadata={
@@ -3293,7 +3293,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation(
 
         self.assertRaisesRegex(
             ValueError,
-            "Compilation is not supported for Llama-3.1 with a GPU instance.",
+            "Compilation is not supported for models greater than Llama-3.0 with a GPU instance.",
             lambda: model_builder.optimize(
                 job_name="job_name-123",
                 instance_type="ml.g5.24xlarge",
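A note on the assertion style used here: unittest's assertRaisesRegex treats the expected message as a regular expression that is searched against the text of the raised exception, so the literal "." in "Llama-3.0" also matches any single character at that position. A minimal standalone sketch of that behavior (illustrative only, not the SDK test):

import unittest

class ErrorMessageRegexExample(unittest.TestCase):
    def test_message_is_matched_as_a_regex(self):
        # assertRaisesRegex applies re.search to str(exception).
        with self.assertRaisesRegex(
            ValueError,
            "Compilation is not supported for models greater than Llama-3.0 with a GPU instance.",
        ):
            raise ValueError(
                "Compilation is not supported for models greater "
                "than Llama-3.0 with a GPU instance."
            )

if __name__ == "__main__":
    unittest.main()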
