@@ -484,16 +484,31 @@ def test_tune_for_djl_js_local_container_invoke_ex(
484
484
"sagemaker.serve.builder.jumpstart_builder.prepare_djl_js_resources" ,
485
485
return_value = (
486
486
mock_set_serving_properties ,
487
- {"model_type" : "sharded_not_supported " , "n_head" : 71 },
487
+ {"model_type" : "sharded_not_enabled " , "n_head" : 71 },
488
488
True ,
489
489
),
490
490
)
491
491
@patch ("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb" , return_value = 1024 )
492
492
@patch (
493
493
"sagemaker.serve.builder.jumpstart_builder._get_nb_instance" , return_value = "ml.g5.24xlarge"
494
494
)
495
+ @patch (
496
+ "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees" ,
497
+ return_value = [4 , 2 , 1 ],
498
+ )
499
+ @patch (
500
+ "sagemaker.serve.utils.tuning._serial_benchmark" ,
501
+ side_effect = [(5 , 5 , 25 ), (5.4 , 5.4 , 20 ), (5.2 , 5.2 , 15 )],
502
+ )
503
+ @patch (
504
+ "sagemaker.serve.utils.tuning._concurrent_benchmark" ,
505
+ side_effect = [(0.9 , 1 ), (0.10 , 4 ), (0.13 , 2 )],
506
+ )
495
507
def test_tune_for_djl_js_local_container_sharded_not_enabled (
496
508
self ,
509
+ mock_concurrent_benchmarks ,
510
+ mock_serial_benchmarks ,
511
+ mock_admissible_tensor_parallel_degrees ,
497
512
mock_get_nb_instance ,
498
513
mock_get_ram_usage_mb ,
499
514
mock_prepare_for_tgi ,
@@ -502,7 +517,9 @@ def test_tune_for_djl_js_local_container_sharded_not_enabled(
502
517
mock_telemetry ,
503
518
):
504
519
builder = ModelBuilder (
505
- model = mock_model_id , schema_builder = mock_schema_builder , mode = Mode .LOCAL_CONTAINER
520
+ model = mock_model_id ,
521
+ schema_builder = mock_schema_builder ,
522
+ mode = Mode .LOCAL_CONTAINER ,
506
523
)
507
524
508
525
mock_pre_trained_model .return_value .image_uri = mock_djl_image_uri
@@ -513,4 +530,8 @@ def test_tune_for_djl_js_local_container_sharded_not_enabled(
513
530
mock_pre_trained_model .return_value .env = mock_djl_model_serving_properties
514
531
515
532
tuned_model = model .tune ()
516
- assert tuned_model .env == mock_djl_model_serving_properties
533
+ assert tuned_model .env == {
534
+ "SAGEMAKER_PROGRAM" : "inference.py" ,
535
+ "SAGEMAKER_MODEL_SERVER_WORKERS" : "1" ,
536
+ "OPTION_TENSOR_PARALLEL_DEGREE" : "1" ,
537
+ }
0 commit comments