Skip to content

Commit faa93e5

Browse files
author
EC2 Default User
committed
Fix build
1 parent 617dc82 commit faa93e5

File tree

3 files changed: 125 additions and 9 deletions

src/sagemaker/serve/builder/model_builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
ModelServer.TORCHSERVE,
6363
ModelServer.TRITON,
6464
ModelServer.DJL_SERVING,
65-
ModelServer.MMS,
6665
}
6766

6867

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 124 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sagemaker.serve.builder.model_builder import ModelBuilder
2020
from sagemaker.serve.mode.function_pointers import Mode
2121
from sagemaker.serve.utils.types import ModelServer
22+
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
2223

2324
schema_builder = MagicMock()
2425
mock_inference_spec = Mock()
@@ -42,6 +43,9 @@
4243
mock_s3_model_data_url = "sample s3 data url"
4344
mock_secret_key = "mock_secret_key"
4445
mock_instance_type = "mock instance type"
46+
MOCK_HF_ID = "mock_hf_id"
47+
MOCK_HF_HUB_TOKEN = "mock_hf_hub_token"
48+
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
4549

4650
supported_model_server = {
4751
ModelServer.TORCHSERVE,
@@ -54,7 +58,16 @@
5458

5559
class TestModelBuilder(unittest.TestCase):
5660
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
57-
def test_validation_in_progress_mode_not_supported(self, mock_serveSettings):
61+
@patch("sagemaker.huggingface.llm_utils.urllib")
62+
@patch("sagemaker.huggingface.llm_utils.json")
63+
def test_validation_in_progress_mode_not_supported(
64+
self, mock_serveSettings, mock_urllib, mock_json
65+
):
66+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
67+
mock_hf_model_metadata_url = Mock()
68+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
69+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
70+
5871
builder = ModelBuilder()
5972
self.assertRaisesRegex(
6073
Exception,
@@ -66,7 +79,16 @@ def test_validation_in_progress_mode_not_supported(self, mock_serveSettings):
6679
)
6780

6881
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
69-
def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSettings):
82+
@patch("sagemaker.huggingface.llm_utils.urllib")
83+
@patch("sagemaker.huggingface.llm_utils.json")
84+
def test_validation_cannot_set_both_model_and_inference_spec(
85+
self, mock_serveSettings, mock_urllib, mock_json
86+
):
87+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
88+
mock_hf_model_metadata_url = Mock()
89+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
90+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
91+
7092
builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object))
7193
self.assertRaisesRegex(
7294
Exception,
@@ -78,7 +100,16 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
78100
)
79101

80102
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
81-
def test_validation_unsupported_model_server_type(self, mock_serveSettings):
103+
@patch("sagemaker.huggingface.llm_utils.urllib")
104+
@patch("sagemaker.huggingface.llm_utils.json")
105+
def test_validation_unsupported_model_server_type(
106+
self, mock_serveSettings, mock_urllib, mock_json
107+
):
108+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
109+
mock_hf_model_metadata_url = Mock()
110+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
111+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
112+
82113
builder = ModelBuilder(model_server="invalid_model_server")
83114
self.assertRaisesRegex(
84115
Exception,
@@ -91,7 +122,16 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
91122
)
92123

93124
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
94-
def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings):
125+
@patch("sagemaker.huggingface.llm_utils.urllib")
126+
@patch("sagemaker.huggingface.llm_utils.json")
127+
def test_validation_model_server_not_set_with_image_uri(
128+
self, mock_serveSettings, mock_urllib, mock_json
129+
):
130+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
131+
mock_hf_model_metadata_url = Mock()
132+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
133+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
134+
95135
builder = ModelBuilder(image_uri="image_uri")
96136
self.assertRaisesRegex(
97137
Exception,
@@ -104,9 +144,16 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
104144
)
105145

106146
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
147+
@patch("sagemaker.huggingface.llm_utils.urllib")
148+
@patch("sagemaker.huggingface.llm_utils.json")
107149
def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set(
108-
self, mock_serveSettings
150+
self, mock_serveSettings, mock_urllib, mock_json
109151
):
152+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
153+
mock_hf_model_metadata_url = Mock()
154+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
155+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
156+
110157
builder = ModelBuilder(inference_spec=None, model=None)
111158
self.assertRaisesRegex(
112159
Exception,
@@ -126,8 +173,12 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
126173
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
127174
@patch("sagemaker.serve.builder.model_builder.Model")
128175
@patch("os.path.exists")
176+
@patch("sagemaker.huggingface.llm_utils.urllib")
177+
@patch("sagemaker.huggingface.llm_utils.json")
129178
def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
130179
self,
180+
mock_urllib,
181+
mock_json,
131182
mock_path_exists,
132183
mock_sdk_model,
133184
mock_sageMakerEndpointMode,
@@ -146,6 +197,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
146197
else None
147198
)
148199

200+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
201+
mock_hf_model_metadata_url = Mock()
202+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
203+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
204+
149205
mock_detect_fw_version.return_value = framework, version
150206

151207
mock_prepare_for_torchserve.side_effect = (
@@ -226,8 +282,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
226282
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
227283
@patch("sagemaker.serve.builder.model_builder.Model")
228284
@patch("os.path.exists")
285+
@patch("sagemaker.huggingface.llm_utils.urllib")
286+
@patch("sagemaker.huggingface.llm_utils.json")
229287
def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
230288
self,
289+
mock_urllib,
290+
mock_json,
231291
mock_path_exists,
232292
mock_sdk_model,
233293
mock_sageMakerEndpointMode,
@@ -246,6 +306,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
246306
else None
247307
)
248308

309+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
310+
mock_hf_model_metadata_url = Mock()
311+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
312+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
313+
249314
mock_detect_fw_version.return_value = framework, version
250315

251316
mock_prepare_for_torchserve.side_effect = (
@@ -326,8 +391,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
326391
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
327392
@patch("sagemaker.serve.builder.model_builder.Model")
328393
@patch("os.path.exists")
394+
@patch("sagemaker.huggingface.llm_utils.urllib")
395+
@patch("sagemaker.huggingface.llm_utils.json")
329396
def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
330397
self,
398+
mock_urllib,
399+
mock_json,
331400
mock_path_exists,
332401
mock_sdk_model,
333402
mock_sageMakerEndpointMode,
@@ -343,6 +412,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
343412
lambda model_path: mock_native_model if model_path == MODEL_PATH else None
344413
)
345414

415+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
416+
mock_hf_model_metadata_url = Mock()
417+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
418+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
419+
346420
mock_detect_fw_version.return_value = framework, version
347421

348422
mock_detect_container.side_effect = (
@@ -427,8 +501,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
427501
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
428502
@patch("sagemaker.serve.builder.model_builder.Model")
429503
@patch("os.path.exists")
504+
@patch("sagemaker.huggingface.llm_utils.urllib")
505+
@patch("sagemaker.huggingface.llm_utils.json")
430506
def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
431507
self,
508+
mock_urllib,
509+
mock_json,
432510
mock_path_exists,
433511
mock_sdk_model,
434512
mock_sageMakerEndpointMode,
@@ -447,6 +525,11 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
447525
else None
448526
)
449527

528+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
529+
mock_hf_model_metadata_url = Mock()
530+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
531+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
532+
450533
mock_detect_fw_version.return_value = framework, version
451534

452535
mock_prepare_for_torchserve.side_effect = (
@@ -530,8 +613,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
530613
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
531614
@patch("sagemaker.serve.builder.model_builder.Model")
532615
@patch("os.path.exists")
616+
@patch("sagemaker.huggingface.llm_utils.urllib")
617+
@patch("sagemaker.huggingface.llm_utils.json")
533618
def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
534619
self,
620+
mock_urllib,
621+
mock_json,
535622
mock_path_exists,
536623
mock_sdk_model,
537624
mock_sageMakerEndpointMode,
@@ -551,6 +638,11 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
551638
else None
552639
)
553640

641+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
642+
mock_hf_model_metadata_url = Mock()
643+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
644+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
645+
554646
mock_detect_fw_version.return_value = "xgboost", version
555647

556648
mock_prepare_for_torchserve.side_effect = (
@@ -635,8 +727,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
635727
@patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
636728
@patch("sagemaker.serve.builder.model_builder.Model")
637729
@patch("os.path.exists")
730+
@patch("sagemaker.huggingface.llm_utils.urllib")
731+
@patch("sagemaker.huggingface.llm_utils.json")
638732
def test_build_happy_path_with_local_container_mode(
639733
self,
734+
mock_urllib,
735+
mock_json,
640736
mock_path_exists,
641737
mock_sdk_model,
642738
mock_localContainerMode,
@@ -651,6 +747,11 @@ def test_build_happy_path_with_local_container_mode(
651747
lambda model_path: mock_native_model if model_path == MODEL_PATH else None
652748
)
653749

750+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
751+
mock_hf_model_metadata_url = Mock()
752+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
753+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
754+
654755
mock_detect_container.side_effect = (
655756
lambda model, region, instance_type: mock_image_uri
656757
if model == mock_native_model
@@ -729,8 +830,12 @@ def test_build_happy_path_with_local_container_mode(
729830
@patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
730831
@patch("sagemaker.serve.builder.model_builder.Model")
731832
@patch("os.path.exists")
833+
@patch("sagemaker.huggingface.llm_utils.urllib")
834+
@patch("sagemaker.huggingface.llm_utils.json")
732835
def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mode(
733836
self,
837+
mock_urllib,
838+
mock_json,
734839
mock_path_exists,
735840
mock_sdk_model,
736841
mock_localContainerMode,
@@ -747,6 +852,11 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
747852
lambda model_path: mock_native_model if model_path == MODEL_PATH else None
748853
)
749854

855+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
856+
mock_hf_model_metadata_url = Mock()
857+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
858+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
859+
750860
mock_detect_fw_version.return_value = framework, version
751861

752862
mock_detect_container.side_effect = (
@@ -870,8 +980,12 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
870980
@patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
871981
@patch("sagemaker.serve.builder.model_builder.Model")
872982
@patch("os.path.exists")
983+
@patch("sagemaker.huggingface.llm_utils.urllib")
984+
@patch("sagemaker.huggingface.llm_utils.json")
873985
def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_container(
874986
self,
987+
mock_urllib,
988+
mock_json,
875989
mock_path_exists,
876990
mock_sdk_model,
877991
mock_localContainerMode,
@@ -885,6 +999,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
885999
# setup mocks
8861000
mock_detect_fw_version.return_value = framework, version
8871001

1002+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
1003+
mock_hf_model_metadata_url = Mock()
1004+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
1005+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
1006+
8881007
mock_detect_container.side_effect = (
8891008
lambda model, region, instance_type: mock_image_uri
8901009
if model == mock_fw_model

tests/unit/sagemaker/serve/builder/test_transformers_builder.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,14 @@
6060

6161

6262
class TestTransformersBuilder(unittest.TestCase):
63-
@patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
64-
@patch("sagemaker.serve.builder.transformers_builder._get_ram_usage_mb", return_value=1024)
6563
@patch(
6664
"sagemaker.serve.builder.transformers_builder._get_nb_instance",
6765
return_value="ml.g5.24xlarge",
6866
)
67+
@patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
6968
def test_build_deploy_for_transformers_local_container_and_remote_container(
7069
self,
7170
mock_get_nb_instance,
72-
mock_get_ram_usage_mb,
7371
mock_telemetry,
7472
):
7573
builder = ModelBuilder(

0 commit comments

Comments
 (0)