19
19
from sagemaker .serve .builder .model_builder import ModelBuilder
20
20
from sagemaker .serve .mode .function_pointers import Mode
21
21
from sagemaker .serve .utils .types import ModelServer
22
+ from sagemaker .huggingface .llm_utils import get_huggingface_model_metadata
22
23
23
24
schema_builder = MagicMock ()
24
25
mock_inference_spec = Mock ()
42
43
mock_s3_model_data_url = "sample s3 data url"
43
44
mock_secret_key = "mock_secret_key"
44
45
mock_instance_type = "mock instance type"
46
+ MOCK_HF_ID = "mock_hf_id"
47
+ MOCK_HF_HUB_TOKEN = "mock_hf_hub_token"
48
+ MOCK_HF_MODEL_METADATA_JSON = {"mock_key" : "mock_value" }
45
49
46
50
supported_model_server = {
47
51
ModelServer .TORCHSERVE ,
54
58
55
59
class TestModelBuilder (unittest .TestCase ):
56
60
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
57
- def test_validation_in_progress_mode_not_supported (self , mock_serveSettings ):
61
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
62
+ @patch ("sagemaker.huggingface.llm_utils.json" )
63
+ def test_validation_in_progress_mode_not_supported (
64
+ self , mock_serveSettings , mock_urllib , mock_json
65
+ ):
66
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
67
+ mock_hf_model_metadata_url = Mock ()
68
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
69
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
70
+
58
71
builder = ModelBuilder ()
59
72
self .assertRaisesRegex (
60
73
Exception ,
@@ -66,7 +79,16 @@ def test_validation_in_progress_mode_not_supported(self, mock_serveSettings):
66
79
)
67
80
68
81
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
69
- def test_validation_cannot_set_both_model_and_inference_spec (self , mock_serveSettings ):
82
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
83
+ @patch ("sagemaker.huggingface.llm_utils.json" )
84
+ def test_validation_cannot_set_both_model_and_inference_spec (
85
+ self , mock_serveSettings , mock_urllib , mock_json
86
+ ):
87
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
88
+ mock_hf_model_metadata_url = Mock ()
89
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
90
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
91
+
70
92
builder = ModelBuilder (inference_spec = "some value" , model = Mock (spec = object ))
71
93
self .assertRaisesRegex (
72
94
Exception ,
@@ -78,7 +100,16 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
78
100
)
79
101
80
102
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
81
- def test_validation_unsupported_model_server_type (self , mock_serveSettings ):
103
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
104
+ @patch ("sagemaker.huggingface.llm_utils.json" )
105
+ def test_validation_unsupported_model_server_type (
106
+ self , mock_serveSettings , mock_urllib , mock_json
107
+ ):
108
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
109
+ mock_hf_model_metadata_url = Mock ()
110
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
111
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
112
+
82
113
builder = ModelBuilder (model_server = "invalid_model_server" )
83
114
self .assertRaisesRegex (
84
115
Exception ,
@@ -91,7 +122,16 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
91
122
)
92
123
93
124
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
94
- def test_validation_model_server_not_set_with_image_uri (self , mock_serveSettings ):
125
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
126
+ @patch ("sagemaker.huggingface.llm_utils.json" )
127
+ def test_validation_model_server_not_set_with_image_uri (
128
+ self , mock_serveSettings , mock_urllib , mock_json
129
+ ):
130
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
131
+ mock_hf_model_metadata_url = Mock ()
132
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
133
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
134
+
95
135
builder = ModelBuilder (image_uri = "image_uri" )
96
136
self .assertRaisesRegex (
97
137
Exception ,
@@ -104,9 +144,16 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
104
144
)
105
145
106
146
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
147
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
148
+ @patch ("sagemaker.huggingface.llm_utils.json" )
107
149
def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set (
108
- self , mock_serveSettings
150
+ self , mock_serveSettings , mock_urllib , mock_json
109
151
):
152
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
153
+ mock_hf_model_metadata_url = Mock ()
154
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
155
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
156
+
110
157
builder = ModelBuilder (inference_spec = None , model = None )
111
158
self .assertRaisesRegex (
112
159
Exception ,
@@ -126,8 +173,12 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
126
173
@patch ("sagemaker.serve.builder.model_builder.SageMakerEndpointMode" )
127
174
@patch ("sagemaker.serve.builder.model_builder.Model" )
128
175
@patch ("os.path.exists" )
176
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
177
+ @patch ("sagemaker.huggingface.llm_utils.json" )
129
178
def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc (
130
179
self ,
180
+ mock_urllib ,
181
+ mock_json ,
131
182
mock_path_exists ,
132
183
mock_sdk_model ,
133
184
mock_sageMakerEndpointMode ,
@@ -146,6 +197,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
146
197
else None
147
198
)
148
199
200
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
201
+ mock_hf_model_metadata_url = Mock ()
202
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
203
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
204
+
149
205
mock_detect_fw_version .return_value = framework , version
150
206
151
207
mock_prepare_for_torchserve .side_effect = (
@@ -226,8 +282,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
226
282
@patch ("sagemaker.serve.builder.model_builder.SageMakerEndpointMode" )
227
283
@patch ("sagemaker.serve.builder.model_builder.Model" )
228
284
@patch ("os.path.exists" )
285
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
286
+ @patch ("sagemaker.huggingface.llm_utils.json" )
229
287
def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc (
230
288
self ,
289
+ mock_urllib ,
290
+ mock_json ,
231
291
mock_path_exists ,
232
292
mock_sdk_model ,
233
293
mock_sageMakerEndpointMode ,
@@ -246,6 +306,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
246
306
else None
247
307
)
248
308
309
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
310
+ mock_hf_model_metadata_url = Mock ()
311
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
312
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
313
+
249
314
mock_detect_fw_version .return_value = framework , version
250
315
251
316
mock_prepare_for_torchserve .side_effect = (
@@ -326,8 +391,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
326
391
@patch ("sagemaker.serve.builder.model_builder.SageMakerEndpointMode" )
327
392
@patch ("sagemaker.serve.builder.model_builder.Model" )
328
393
@patch ("os.path.exists" )
394
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
395
+ @patch ("sagemaker.huggingface.llm_utils.json" )
329
396
def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec (
330
397
self ,
398
+ mock_urllib ,
399
+ mock_json ,
331
400
mock_path_exists ,
332
401
mock_sdk_model ,
333
402
mock_sageMakerEndpointMode ,
@@ -343,6 +412,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
343
412
lambda model_path : mock_native_model if model_path == MODEL_PATH else None
344
413
)
345
414
415
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
416
+ mock_hf_model_metadata_url = Mock ()
417
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
418
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
419
+
346
420
mock_detect_fw_version .return_value = framework , version
347
421
348
422
mock_detect_container .side_effect = (
@@ -427,8 +501,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
427
501
@patch ("sagemaker.serve.builder.model_builder.SageMakerEndpointMode" )
428
502
@patch ("sagemaker.serve.builder.model_builder.Model" )
429
503
@patch ("os.path.exists" )
504
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
505
+ @patch ("sagemaker.huggingface.llm_utils.json" )
430
506
def test_build_happy_path_with_sagemakerEndpoint_mode_and_model (
431
507
self ,
508
+ mock_urllib ,
509
+ mock_json ,
432
510
mock_path_exists ,
433
511
mock_sdk_model ,
434
512
mock_sageMakerEndpointMode ,
@@ -447,6 +525,11 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
447
525
else None
448
526
)
449
527
528
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
529
+ mock_hf_model_metadata_url = Mock ()
530
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
531
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
532
+
450
533
mock_detect_fw_version .return_value = framework , version
451
534
452
535
mock_prepare_for_torchserve .side_effect = (
@@ -530,8 +613,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
530
613
@patch ("sagemaker.serve.builder.model_builder.SageMakerEndpointMode" )
531
614
@patch ("sagemaker.serve.builder.model_builder.Model" )
532
615
@patch ("os.path.exists" )
616
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
617
+ @patch ("sagemaker.huggingface.llm_utils.json" )
533
618
def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model (
534
619
self ,
620
+ mock_urllib ,
621
+ mock_json ,
535
622
mock_path_exists ,
536
623
mock_sdk_model ,
537
624
mock_sageMakerEndpointMode ,
@@ -551,6 +638,11 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
551
638
else None
552
639
)
553
640
641
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
642
+ mock_hf_model_metadata_url = Mock ()
643
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
644
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
645
+
554
646
mock_detect_fw_version .return_value = "xgboost" , version
555
647
556
648
mock_prepare_for_torchserve .side_effect = (
@@ -635,8 +727,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
635
727
@patch ("sagemaker.serve.builder.model_builder.LocalContainerMode" )
636
728
@patch ("sagemaker.serve.builder.model_builder.Model" )
637
729
@patch ("os.path.exists" )
730
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
731
+ @patch ("sagemaker.huggingface.llm_utils.json" )
638
732
def test_build_happy_path_with_local_container_mode (
639
733
self ,
734
+ mock_urllib ,
735
+ mock_json ,
640
736
mock_path_exists ,
641
737
mock_sdk_model ,
642
738
mock_localContainerMode ,
@@ -651,6 +747,11 @@ def test_build_happy_path_with_local_container_mode(
651
747
lambda model_path : mock_native_model if model_path == MODEL_PATH else None
652
748
)
653
749
750
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
751
+ mock_hf_model_metadata_url = Mock ()
752
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
753
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
754
+
654
755
mock_detect_container .side_effect = (
655
756
lambda model , region , instance_type : mock_image_uri
656
757
if model == mock_native_model
@@ -729,8 +830,12 @@ def test_build_happy_path_with_local_container_mode(
729
830
@patch ("sagemaker.serve.builder.model_builder.LocalContainerMode" )
730
831
@patch ("sagemaker.serve.builder.model_builder.Model" )
731
832
@patch ("os.path.exists" )
833
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
834
+ @patch ("sagemaker.huggingface.llm_utils.json" )
732
835
def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mode (
733
836
self ,
837
+ mock_urllib ,
838
+ mock_json ,
734
839
mock_path_exists ,
735
840
mock_sdk_model ,
736
841
mock_localContainerMode ,
@@ -747,6 +852,11 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
747
852
lambda model_path : mock_native_model if model_path == MODEL_PATH else None
748
853
)
749
854
855
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
856
+ mock_hf_model_metadata_url = Mock ()
857
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
858
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
859
+
750
860
mock_detect_fw_version .return_value = framework , version
751
861
752
862
mock_detect_container .side_effect = (
@@ -870,8 +980,12 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
870
980
@patch ("sagemaker.serve.builder.model_builder.LocalContainerMode" )
871
981
@patch ("sagemaker.serve.builder.model_builder.Model" )
872
982
@patch ("os.path.exists" )
983
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
984
+ @patch ("sagemaker.huggingface.llm_utils.json" )
873
985
def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_container (
874
986
self ,
987
+ mock_urllib ,
988
+ mock_json ,
875
989
mock_path_exists ,
876
990
mock_sdk_model ,
877
991
mock_localContainerMode ,
@@ -885,6 +999,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
885
999
# setup mocks
886
1000
mock_detect_fw_version .return_value = framework , version
887
1001
1002
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
1003
+ mock_hf_model_metadata_url = Mock ()
1004
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
1005
+ ret_json = get_huggingface_model_metadata (MOCK_HF_ID , MOCK_HF_HUB_TOKEN )
1006
+
888
1007
mock_detect_container .side_effect = (
889
1008
lambda model , region , instance_type : mock_image_uri
890
1009
if model == mock_fw_model
0 commit comments