Skip to content

Commit 38eed2f

Browse files
author
Jonathan Makunga
committed
Add Unit tests
1 parent 99d11cb commit 38eed2f

File tree

2 files changed

+93
-1
lines changed

2 files changed

+93
-1
lines changed

tests/integ/sagemaker/serve/test_schema_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_model_builder_happy_path_with_only_model_id_question_answering(
6666
caught_ex = None
6767
try:
6868
iam_client = sagemaker_session.boto_session.client("iam")
69-
role_arn = iam_client.get_role(RoleName="JarvisTest")["Role"]["Arn"]
69+
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
7070

7171
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
7272
predictor = model.deploy(

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
from sagemaker.serve.builder.model_builder import ModelBuilder
2020
from sagemaker.serve.mode.function_pointers import Mode
21+
from sagemaker.serve.utils import task
22+
from sagemaker.serve.utils.exceptions import TaskNotFoundException
2123
from sagemaker.serve.utils.types import ModelServer
2224
from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG
2325

@@ -985,3 +987,93 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
985987
build_result.deploy(mode=Mode.LOCAL_CONTAINER)
986988

987989
self.assertEqual(builder.mode, Mode.LOCAL_CONTAINER)
990+
991+
@patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel")
992+
@patch("sagemaker.image_uris.retrieve")
993+
@patch("sagemaker.djl_inference.model.urllib")
994+
@patch("sagemaker.djl_inference.model.json")
995+
@patch("sagemaker.huggingface.llm_utils.urllib")
996+
@patch("sagemaker.huggingface.llm_utils.json")
997+
@patch("sagemaker.model_uris.retrieve")
998+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
999+
def test_build_happy_path_when_schema_builder_not_present(
1000+
self,
1001+
mock_serveSettings,
1002+
mock_model_uris_retrieve,
1003+
mock_llm_utils_json,
1004+
mock_llm_utils_urllib,
1005+
mock_model_json,
1006+
mock_model_urllib,
1007+
mock_image_uris_retrieve,
1008+
mock_hf_model,
1009+
):
1010+
# Setup mocks
1011+
1012+
mock_setting_object = mock_serveSettings.return_value
1013+
mock_setting_object.role_arn = mock_role_arn
1014+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
1015+
1016+
# HF Pipeline Tag
1017+
mock_model_uris_retrieve.side_effect = KeyError
1018+
mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-generation"}
1019+
mock_llm_utils_urllib.request.Request.side_effect = Mock()
1020+
1021+
# HF Model config
1022+
mock_model_json.load.return_value = {"some": "config"}
1023+
mock_model_urllib.request.Request.side_effect = Mock()
1024+
1025+
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1026+
1027+
model_builder = ModelBuilder(model="meta-llama/Llama-2-7b-hf")
1028+
model_builder.build(sagemaker_session=mock_session)
1029+
1030+
self.assertIsNotNone(model_builder.schema_builder)
1031+
sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation")
1032+
self.assertEqual(
1033+
sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"]
1034+
)
1035+
self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output)
1036+
1037+
@patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel")
1038+
@patch("sagemaker.image_uris.retrieve")
1039+
@patch("sagemaker.djl_inference.model.urllib")
1040+
@patch("sagemaker.djl_inference.model.json")
1041+
@patch("sagemaker.huggingface.llm_utils.urllib")
1042+
@patch("sagemaker.huggingface.llm_utils.json")
1043+
@patch("sagemaker.model_uris.retrieve")
1044+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
1045+
def test_build_negative_path_when_schema_builder_not_present(
1046+
self,
1047+
mock_serveSettings,
1048+
mock_model_uris_retrieve,
1049+
mock_llm_utils_json,
1050+
mock_llm_utils_urllib,
1051+
mock_model_json,
1052+
mock_model_urllib,
1053+
mock_image_uris_retrieve,
1054+
mock_hf_model,
1055+
):
1056+
# Setup mocks
1057+
1058+
mock_setting_object = mock_serveSettings.return_value
1059+
mock_setting_object.role_arn = mock_role_arn
1060+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
1061+
1062+
# HF Pipeline Tag
1063+
mock_model_uris_retrieve.side_effect = KeyError
1064+
mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-to-image"}
1065+
mock_llm_utils_urllib.request.Request.side_effect = Mock()
1066+
1067+
# HF Model config
1068+
mock_model_json.load.return_value = {"some": "config"}
1069+
mock_model_urllib.request.Request.side_effect = Mock()
1070+
1071+
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1072+
1073+
model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")
1074+
1075+
self.assertRaisesRegexp(
1076+
TaskNotFoundException,
1077+
"Error Message: Schema builder for text-to-image could not be found.",
1078+
lambda: model_builder.build(sagemaker_session=mock_session),
1079+
)

0 commit comments

Comments
 (0)