|
11 | 11 | # ANY KIND, either express or implied. See the License for the specific
|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 |
| -from unittest.mock import MagicMock, patch |
| 14 | +from unittest.mock import MagicMock, patch, Mock |
15 | 15 |
|
16 | 16 | import unittest
|
17 | 17 |
|
| 18 | +from sagemaker.enums import Tag |
| 19 | +from sagemaker.serve import SchemaBuilder |
18 | 20 | from sagemaker.serve.builder.model_builder import ModelBuilder
|
19 | 21 | from sagemaker.serve.mode.function_pointers import Mode
|
20 | 22 | from sagemaker.serve.utils.exceptions import (
|
@@ -961,3 +963,119 @@ def test_display_benchmark_metrics_initial(
|
961 | 963 | builder.display_benchmark_metrics()
|
962 | 964 |
|
963 | 965 | mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once()
|
| 966 | + |
| 967 | + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) |
| 968 | + @patch( |
| 969 | + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", |
| 970 | + return_value=True, |
| 971 | + ) |
| 972 | + @patch( |
| 973 | + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", |
| 974 | + return_value=MagicMock(), |
| 975 | + ) |
| 976 | + @patch( |
| 977 | + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", |
| 978 | + return_value=({"model_type": "t5", "n_head": 71}, True), |
| 979 | + ) |
| 980 | + def test_fine_tuned_model_with_fine_tuning_model_path( |
| 981 | + self, |
| 982 | + mock_prepare_for_tgi, |
| 983 | + mock_pre_trained_model, |
| 984 | + mock_is_jumpstart_model, |
| 985 | + mock_telemetry, |
| 986 | + ): |
| 987 | + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri |
| 988 | + mock_fine_tuning_model_path = "s3://test" |
| 989 | + |
| 990 | + sample_input = { |
| 991 | + "inputs": "The diamondback terrapin or simply terrapin is a species of turtle native to the brackish " |
| 992 | + "coastal tidal marshes of the", |
| 993 | + "parameters": {"max_new_tokens": 1024}, |
| 994 | + } |
| 995 | + sample_output = [ |
| 996 | + { |
| 997 | + "generated_text": "The diamondback terrapin or simply terrapin is a species of turtle native to the " |
| 998 | + "brackish coastal tidal marshes of the east coast." |
| 999 | + } |
| 1000 | + ] |
| 1001 | + builder = ModelBuilder( |
| 1002 | + model="meta-textgeneration-llama-3-70b", |
| 1003 | + schema_builder=SchemaBuilder(sample_input, sample_output), |
| 1004 | + model_metadata={ |
| 1005 | + "FINE_TUNING_MODEL_PATH": mock_fine_tuning_model_path, |
| 1006 | + }, |
| 1007 | + ) |
| 1008 | + model = builder.build() |
| 1009 | + |
| 1010 | + model.model_data["S3DataSource"].__setitem__.assert_called_with( |
| 1011 | + "S3Uri", mock_fine_tuning_model_path |
| 1012 | + ) |
| 1013 | + mock_pre_trained_model.return_value.add_tags.assert_called_with( |
| 1014 | + {"key": Tag.FINE_TUNING_MODEL_PATH, "value": mock_fine_tuning_model_path} |
| 1015 | + ) |
| 1016 | + |
| 1017 | + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) |
| 1018 | + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) |
| 1019 | + @patch( |
| 1020 | + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", |
| 1021 | + return_value=True, |
| 1022 | + ) |
| 1023 | + @patch( |
| 1024 | + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", |
| 1025 | + return_value=MagicMock(), |
| 1026 | + ) |
| 1027 | + @patch( |
| 1028 | + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", |
| 1029 | + return_value=({"model_type": "t5", "n_head": 71}, True), |
| 1030 | + ) |
| 1031 | + def test_fine_tuned_model_with_fine_tuning_job_name( |
| 1032 | + self, |
| 1033 | + mock_prepare_for_tgi, |
| 1034 | + mock_pre_trained_model, |
| 1035 | + mock_is_jumpstart_model, |
| 1036 | + mock_serve_settings, |
| 1037 | + mock_telemetry, |
| 1038 | + ): |
| 1039 | + mock_fine_tuning_model_path = "s3://test" |
| 1040 | + mock_sagemaker_session = Mock() |
| 1041 | + mock_sagemaker_session.sagemaker_client.describe_training_job.return_value = { |
| 1042 | + "OutputDataConfig": { |
| 1043 | + "S3OutputPath": mock_fine_tuning_model_path, |
| 1044 | + "CompressionType": "None", |
| 1045 | + } |
| 1046 | + } |
| 1047 | + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri |
| 1048 | + mock_fine_tuning_job_name = "mock-job" |
| 1049 | + |
| 1050 | + sample_input = { |
| 1051 | + "inputs": "The diamondback terrapin or simply terrapin is a species of turtle native to the brackish " |
| 1052 | + "coastal tidal marshes of the", |
| 1053 | + "parameters": {"max_new_tokens": 1024}, |
| 1054 | + } |
| 1055 | + sample_output = [ |
| 1056 | + { |
| 1057 | + "generated_text": "The diamondback terrapin or simply terrapin is a species of turtle native to the " |
| 1058 | + "brackish coastal tidal marshes of the east coast." |
| 1059 | + } |
| 1060 | + ] |
| 1061 | + builder = ModelBuilder( |
| 1062 | + model="meta-textgeneration-llama-3-70b", |
| 1063 | + schema_builder=SchemaBuilder(sample_input, sample_output), |
| 1064 | + model_metadata={"FINE_TUNING_JOB_NAME": mock_fine_tuning_job_name}, |
| 1065 | + sagemaker_session=mock_sagemaker_session, |
| 1066 | + ) |
| 1067 | + model = builder.build(sagemaker_session=mock_sagemaker_session) |
| 1068 | + |
| 1069 | + mock_sagemaker_session.sagemaker_client.describe_training_job.assert_called_once_with( |
| 1070 | + TrainingJobName=mock_fine_tuning_job_name |
| 1071 | + ) |
| 1072 | + |
| 1073 | + model.model_data["S3DataSource"].__setitem__.assert_any_call( |
| 1074 | + "S3Uri", mock_fine_tuning_model_path |
| 1075 | + ) |
| 1076 | + mock_pre_trained_model.return_value.add_tags.assert_called_with( |
| 1077 | + [ |
| 1078 | + {"key": Tag.FINE_TUNING_JOB_NAME, "value": mock_fine_tuning_job_name}, |
| 1079 | + {"key": Tag.FINE_TUNING_MODEL_PATH, "value": mock_fine_tuning_model_path}, |
| 1080 | + ] |
| 1081 | + ) |
0 commit comments