@@ -1247,7 +1247,7 @@ def test_build_for_transformers_happy_case_with_values(
1247
1247
mock_build_for_transformers .assert_called_once ()
1248
1248
1249
1249
@patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl" , Mock ())
1250
- @patch ("sagemaker.serve.utils.hardware_detector ._get_gpu_info" )
1250
+ @patch ("sagemaker.serve.builder.model_builder ._get_gpu_info" )
1251
1251
@patch ("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib" )
1252
1252
@patch ("sagemaker.image_uris.retrieve" )
1253
1253
@patch ("sagemaker.djl_inference.model.urllib" )
@@ -1257,16 +1257,16 @@ def test_build_for_transformers_happy_case_with_values(
1257
1257
@patch ("sagemaker.model_uris.retrieve" )
1258
1258
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1259
1259
def test_build_for_transformers_happy_case_with_valid_gpu_info (
1260
- self ,
1261
- mock_serveSettings ,
1262
- mock_model_uris_retrieve ,
1263
- mock_llm_utils_json ,
1264
- mock_llm_utils_urllib ,
1265
- mock_model_json ,
1266
- mock_model_urllib ,
1267
- mock_image_uris_retrieve ,
1268
- mock_total_inference_model_size_mib ,
1269
- mock_try_fetch_gpu_info ,
1260
+ self ,
1261
+ mock_serveSettings ,
1262
+ mock_model_uris_retrieve ,
1263
+ mock_llm_utils_json ,
1264
+ mock_llm_utils_urllib ,
1265
+ mock_model_json ,
1266
+ mock_model_urllib ,
1267
+ mock_image_uris_retrieve ,
1268
+ mock_total_inference_model_size_mib ,
1269
+ mock_try_fetch_gpu_info ,
1270
1270
):
1271
1271
mock_setting_object = mock_serveSettings .return_value
1272
1272
mock_setting_object .role_arn = mock_role_arn
@@ -1285,7 +1285,9 @@ def test_build_for_transformers_happy_case_with_valid_gpu_info(
1285
1285
1286
1286
model_builder = ModelBuilder (model = "stable-diffusion" )
1287
1287
model_builder .build (sagemaker_session = mock_session )
1288
-
1288
+ self .assertEqual (
1289
+ model_builder ._try_fetch_gpu_info (), INSTANCE_GPU_INFO [1 ] / INSTANCE_GPU_INFO [0 ]
1290
+ )
1289
1291
self .assertEqual (model_builder ._can_fit_on_single_gpu (), False )
1290
1292
1291
1293
@patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers" , Mock ())
@@ -1300,17 +1302,17 @@ def test_build_for_transformers_happy_case_with_valid_gpu_info(
1300
1302
@patch ("sagemaker.model_uris.retrieve" )
1301
1303
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1302
1304
def test_build_for_transformers_happy_case_with_valid_gpu_fallback (
1303
- self ,
1304
- mock_serveSettings ,
1305
- mock_model_uris_retrieve ,
1306
- mock_llm_utils_json ,
1307
- mock_llm_utils_urllib ,
1308
- mock_model_json ,
1309
- mock_model_urllib ,
1310
- mock_image_uris_retrieve ,
1311
- mock_total_inference_model_size_mib ,
1312
- mock_gpu_fallback ,
1313
- mock_try_fetch_gpu_info ,
1305
+ self ,
1306
+ mock_serveSettings ,
1307
+ mock_model_uris_retrieve ,
1308
+ mock_llm_utils_json ,
1309
+ mock_llm_utils_urllib ,
1310
+ mock_model_json ,
1311
+ mock_model_urllib ,
1312
+ mock_image_uris_retrieve ,
1313
+ mock_total_inference_model_size_mib ,
1314
+ mock_gpu_fallback ,
1315
+ mock_try_fetch_gpu_info ,
1314
1316
):
1315
1317
mock_setting_object = mock_serveSettings .return_value
1316
1318
mock_setting_object .role_arn = mock_role_arn
@@ -1324,12 +1326,20 @@ def test_build_for_transformers_happy_case_with_valid_gpu_fallback(
1324
1326
mock_model_urllib .request .Request .side_effect = Mock ()
1325
1327
mock_try_fetch_gpu_info .side_effect = ValueError
1326
1328
mock_gpu_fallback .return_value = INSTANCE_GPU_INFO
1327
- mock_total_inference_model_size_mib .return_value = INSTANCE_GPU_INFO [1 ]/ INSTANCE_GPU_INFO [0 ] - 1
1329
+ mock_total_inference_model_size_mib .return_value = (
1330
+ INSTANCE_GPU_INFO [1 ] / INSTANCE_GPU_INFO [0 ] - 1
1331
+ )
1328
1332
1329
1333
mock_image_uris_retrieve .return_value = "https://some-image-uri"
1330
1334
1331
- model_builder = ModelBuilder (model = "stable-diffusion" , sagemaker_session = mock_session , instance_type = mock_instance_type )
1332
- self .assertEqual (model_builder ._try_fetch_gpu_info (), INSTANCE_GPU_INFO [1 ]/ INSTANCE_GPU_INFO [0 ])
1335
+ model_builder = ModelBuilder (
1336
+ model = "stable-diffusion" ,
1337
+ sagemaker_session = mock_session ,
1338
+ instance_type = mock_instance_type ,
1339
+ )
1340
+ self .assertEqual (
1341
+ model_builder ._try_fetch_gpu_info (), INSTANCE_GPU_INFO [1 ] / INSTANCE_GPU_INFO [0 ]
1342
+ )
1333
1343
self .assertEqual (model_builder ._can_fit_on_single_gpu (), True )
1334
1344
1335
1345
@patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers" , Mock ())
@@ -1343,16 +1353,16 @@ def test_build_for_transformers_happy_case_with_valid_gpu_fallback(
1343
1353
@patch ("sagemaker.model_uris.retrieve" )
1344
1354
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1345
1355
def test_build_for_transformers_happy_case_hugging_face_responses (
1346
- self ,
1347
- mock_serveSettings ,
1348
- mock_model_uris_retrieve ,
1349
- mock_llm_utils_json ,
1350
- mock_llm_utils_urllib ,
1351
- mock_model_json ,
1352
- mock_model_urllib ,
1353
- mock_image_uris_retrieve ,
1354
- mock_gather_data ,
1355
- mock_parser ,
1356
+ self ,
1357
+ mock_serveSettings ,
1358
+ mock_model_uris_retrieve ,
1359
+ mock_llm_utils_json ,
1360
+ mock_llm_utils_urllib ,
1361
+ mock_model_json ,
1362
+ mock_model_urllib ,
1363
+ mock_image_uris_retrieve ,
1364
+ mock_gather_data ,
1365
+ mock_parser ,
1356
1366
):
1357
1367
mock_setting_object = mock_serveSettings .return_value
1358
1368
mock_setting_object .role_arn = mock_role_arn
@@ -1370,12 +1380,20 @@ def test_build_for_transformers_happy_case_hugging_face_responses(
1370
1380
mock_gather_data .return_value = [[1 , 1 , 1 , 1 ]]
1371
1381
product = MIB_CONVERSION_FACTOR * 1 * MEMORY_BUFFER_MULTIPLIER
1372
1382
1373
- model_builder = ModelBuilder (model = "stable-diffusion" , sagemaker_session = mock_session , instance_type = mock_instance_type )
1383
+ model_builder = ModelBuilder (
1384
+ model = "stable-diffusion" ,
1385
+ sagemaker_session = mock_session ,
1386
+ instance_type = mock_instance_type ,
1387
+ )
1374
1388
self .assertEqual (model_builder ._total_inference_model_size_mib (), product )
1375
1389
1376
1390
mock_parser .return_value = Mock ()
1377
1391
mock_gather_data .return_value = None
1378
- model_builder = ModelBuilder (model = "stable-diffusion" , sagemaker_session = mock_session , instance_type = mock_instance_type )
1392
+ model_builder = ModelBuilder (
1393
+ model = "stable-diffusion" ,
1394
+ sagemaker_session = mock_session ,
1395
+ instance_type = mock_instance_type ,
1396
+ )
1379
1397
with self .assertRaises (ValueError ) as _ :
1380
1398
model_builder ._total_inference_model_size_mib ()
1381
1399
0 commit comments