File tree Expand file tree Collapse file tree 2 files changed +8
-8
lines changed
src/sagemaker/serve/builder
tests/unit/sagemaker/serve/builder Expand file tree Collapse file tree 2 files changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -1433,15 +1433,15 @@ def _model_builder_optimize_wrapper(
1433
1433
1434
1434
# HF Model ID format = "meta-llama/Meta-Llama-3.1-8B"
1435
1435
# JS Model ID format = "meta-textgeneration-llama-3-1-8b"
1436
- llama_3_1_keywords = ["llama-3.1", "llama-3-1"]
1437
- is_llama_3_1 = self.model and any(
1438
- keyword in self.model.lower() for keyword in llama_3_1_keywords
1436
+ is_llama_3_plus = self.model and bool(
1437
+ re.search(r"llama-3[\.\-][1-9]\d*", self.model.lower())
1439
1438
)
1440
1439
1441
1440
if is_gpu_instance and self.model and self.is_compiled:
1442
- if is_llama_3_1:
1441
+ if is_llama_3_plus:
1443
1442
raise ValueError(
1444
- "Compilation is not supported for Llama-3.1 with a GPU instance."
1443
+ "Compilation is not supported for models greater "
1444
+ "than Llama-3.0 with a GPU instance."
1445
1445
)
1446
1446
if speculative_decoding_config :
1447
1447
raise ValueError(
Original file line number Diff line number Diff line change @@ -3270,7 +3270,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation(
3270
3270
3271
3271
mock_pysdk_model = Mock()
3272
3272
mock_pysdk_model.model_data = None
3273
- mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-1-8B-Instruct"}
3273
+ mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-2-8B-Instruct"}
3274
3274
3275
3275
sample_input = {"inputs": "dummy prompt", "parameters": {}}
3276
3276
@@ -3279,7 +3279,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation(
3279
3279
dummy_schema_builder = SchemaBuilder(sample_input, sample_output)
3280
3280
3281
3281
model_builder = ModelBuilder(
3282
- model="meta-llama/Meta-Llama-3-1-8B-Instruct",
3282
+ model="meta-llama/Meta-Llama-3-2-8B-Instruct",
3283
3283
schema_builder=dummy_schema_builder,
3284
3284
env_vars={"HF_TOKEN": "token"},
3285
3285
model_metadata = {
@@ -3293,7 +3293,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation(
3293
3293
3294
3294
self.assertRaisesRegex(
3295
3295
ValueError ,
3296
- "Compilation is not supported for Llama-3.1 with a GPU instance.",
3296
+ "Compilation is not supported for models greater than Llama-3.0 with a GPU instance.",
3297
3297
lambda: model_builder.optimize(
3298
3298
job_name="job_name-123",
3299
3299
instance_type="ml.g5.24xlarge",
You can’t perform that action at this time.
0 commit comments