Skip to content

Commit b07f210

Browse files
grenmesterJacky Lee
andauthored
unit: tests for fine tuned JS model support (aws#1481)
* UTs * flake8 --------- Co-authored-by: Jacky Lee <[email protected]>
1 parent 1f6f876 commit b07f210

File tree

1 file changed

+119
-1
lines changed

1 file changed

+119
-1
lines changed

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14-
from unittest.mock import MagicMock, patch
14+
from unittest.mock import MagicMock, patch, Mock
1515

1616
import unittest
1717

18+
from sagemaker.enums import Tag
19+
from sagemaker.serve import SchemaBuilder
1820
from sagemaker.serve.builder.model_builder import ModelBuilder
1921
from sagemaker.serve.mode.function_pointers import Mode
2022
from sagemaker.serve.utils.exceptions import (
@@ -961,3 +963,119 @@ def test_display_benchmark_metrics_initial(
961963
builder.display_benchmark_metrics()
962964

963965
mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once()
966+
967+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
968+
@patch(
969+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
970+
return_value=True,
971+
)
972+
@patch(
973+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
974+
return_value=MagicMock(),
975+
)
976+
@patch(
977+
"sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
978+
return_value=({"model_type": "t5", "n_head": 71}, True),
979+
)
980+
def test_fine_tuned_model_with_fine_tuning_model_path(
981+
self,
982+
mock_prepare_for_tgi,
983+
mock_pre_trained_model,
984+
mock_is_jumpstart_model,
985+
mock_telemetry,
986+
):
987+
mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri
988+
mock_fine_tuning_model_path = "s3://test"
989+
990+
sample_input = {
991+
"inputs": "The diamondback terrapin or simply terrapin is a species of turtle native to the brackish "
992+
"coastal tidal marshes of the",
993+
"parameters": {"max_new_tokens": 1024},
994+
}
995+
sample_output = [
996+
{
997+
"generated_text": "The diamondback terrapin or simply terrapin is a species of turtle native to the "
998+
"brackish coastal tidal marshes of the east coast."
999+
}
1000+
]
1001+
builder = ModelBuilder(
1002+
model="meta-textgeneration-llama-3-70b",
1003+
schema_builder=SchemaBuilder(sample_input, sample_output),
1004+
model_metadata={
1005+
"FINE_TUNING_MODEL_PATH": mock_fine_tuning_model_path,
1006+
},
1007+
)
1008+
model = builder.build()
1009+
1010+
model.model_data["S3DataSource"].__setitem__.assert_called_with(
1011+
"S3Uri", mock_fine_tuning_model_path
1012+
)
1013+
mock_pre_trained_model.return_value.add_tags.assert_called_with(
1014+
{"key": Tag.FINE_TUNING_MODEL_PATH, "value": mock_fine_tuning_model_path}
1015+
)
1016+
1017+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
1018+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
1019+
@patch(
1020+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
1021+
return_value=True,
1022+
)
1023+
@patch(
1024+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
1025+
return_value=MagicMock(),
1026+
)
1027+
@patch(
1028+
"sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
1029+
return_value=({"model_type": "t5", "n_head": 71}, True),
1030+
)
1031+
def test_fine_tuned_model_with_fine_tuning_job_name(
1032+
self,
1033+
mock_prepare_for_tgi,
1034+
mock_pre_trained_model,
1035+
mock_is_jumpstart_model,
1036+
mock_serve_settings,
1037+
mock_telemetry,
1038+
):
1039+
mock_fine_tuning_model_path = "s3://test"
1040+
mock_sagemaker_session = Mock()
1041+
mock_sagemaker_session.sagemaker_client.describe_training_job.return_value = {
1042+
"OutputDataConfig": {
1043+
"S3OutputPath": mock_fine_tuning_model_path,
1044+
"CompressionType": "None",
1045+
}
1046+
}
1047+
mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri
1048+
mock_fine_tuning_job_name = "mock-job"
1049+
1050+
sample_input = {
1051+
"inputs": "The diamondback terrapin or simply terrapin is a species of turtle native to the brackish "
1052+
"coastal tidal marshes of the",
1053+
"parameters": {"max_new_tokens": 1024},
1054+
}
1055+
sample_output = [
1056+
{
1057+
"generated_text": "The diamondback terrapin or simply terrapin is a species of turtle native to the "
1058+
"brackish coastal tidal marshes of the east coast."
1059+
}
1060+
]
1061+
builder = ModelBuilder(
1062+
model="meta-textgeneration-llama-3-70b",
1063+
schema_builder=SchemaBuilder(sample_input, sample_output),
1064+
model_metadata={"FINE_TUNING_JOB_NAME": mock_fine_tuning_job_name},
1065+
sagemaker_session=mock_sagemaker_session,
1066+
)
1067+
model = builder.build(sagemaker_session=mock_sagemaker_session)
1068+
1069+
mock_sagemaker_session.sagemaker_client.describe_training_job.assert_called_once_with(
1070+
TrainingJobName=mock_fine_tuning_job_name
1071+
)
1072+
1073+
model.model_data["S3DataSource"].__setitem__.assert_any_call(
1074+
"S3Uri", mock_fine_tuning_model_path
1075+
)
1076+
mock_pre_trained_model.return_value.add_tags.assert_called_with(
1077+
[
1078+
{"key": Tag.FINE_TUNING_JOB_NAME, "value": mock_fine_tuning_job_name},
1079+
{"key": Tag.FINE_TUNING_MODEL_PATH, "value": mock_fine_tuning_model_path},
1080+
]
1081+
)

0 commit comments

Comments
 (0)