20
20
21
21
from sagemaker .serve .utils .predictors import TeiLocalModePredictor
22
22
23
- mock_model_id = "bert-base-uncased"
24
- mock_prompt = "The man worked as a [MASK]."
25
- mock_sample_input = {"inputs" : mock_prompt }
26
- mock_sample_output = [
23
+ MOCK_MODEL_ID = "bert-base-uncased"
24
+ MOCK_PROMPT = "The man worked as a [MASK]."
25
+ MOCK_SAMPLE_INPUT = {"inputs" : MOCK_PROMPT }
26
+ MOCK_SAMPLE_OUTPUT = [
27
27
{
28
28
"score" : 0.0974755585193634 ,
29
29
"token" : 10533 ,
55
55
"sequence" : "the man worked as a salesman." ,
56
56
},
57
57
]
58
- mock_schema_builder = MagicMock ()
59
- mock_schema_builder .sample_input = mock_sample_input
60
- mock_schema_builder .sample_output = mock_sample_output
58
# Schema-builder double shared by the tests in this module; it only exposes the
# sample input/output attributes defined above.
MOCK_SCHEMA_BUILDER = MagicMock(
    sample_input=MOCK_SAMPLE_INPUT,
    sample_output=MOCK_SAMPLE_OUTPUT,
)
61
61
# Container image URI handed to ModelBuilder as `image_uri` by the image-override
# test. Written as two adjacent literals (registry host, then repository:tag) so
# neither source line grows too long; Python joins them into a single string.
MOCK_IMAGE_CONFIG = (
    "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
    "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0"
)

# Placeholder value passed to ModelBuilder as `model_path` by the deploy tests.
MOCK_MODEL_PATH = "mock model path"
65
66
66
67
67
68
class TestTEIBuilder (unittest .TestCase ):
@@ -70,57 +71,136 @@ class TestTEIBuilder(unittest.TestCase):
70
71
return_value = "ml.g5.24xlarge" ,
71
72
)
72
73
@patch ("sagemaker.serve.builder.tei_builder._capture_telemetry" , side_effect = None )
73
- def test_build_deploy_for_tei_local_container_and_remote_container (
74
def test_tei_builder_sagemaker_endpoint_mode_no_s3_upload_success(
    self,
    mock_get_nb_instance,
    mock_telemetry,
):
    """Build in SAGEMAKER_ENDPOINT mode and deploy with the prepare step stubbed.

    Checks that the built model's env contains ``HF_MODEL_ID``, that switching
    to ``Mode.IN_PROCESS`` at deploy time raises ``ValueError``, and that the
    most recent ``_prepare_for_mode`` call was made without arguments (which
    the test name associates with skipping the S3 upload).
    """
    tei_builder = ModelBuilder(
        model=MOCK_MODEL_ID,
        schema_builder=MOCK_SCHEMA_BUILDER,
        mode=Mode.SAGEMAKER_ENDPOINT,
        model_metadata={"HF_TASK": "sentence-similarity"},
    )
    # Replace the mode-preparation hook with a stub that returns the same
    # (None, {}) pair for every call, so the calls can be inspected afterwards.
    tei_builder._prepare_for_mode = MagicMock(return_value=(None, {}))

    tei_model = tei_builder.build()
    tei_builder.serve_settings.telemetry_opt_out = True
    # Stub out the underlying deploy so the endpoint deploy below stays mocked.
    tei_builder._original_deploy = MagicMock()

    tei_model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")

    assert "HF_MODEL_ID" in tei_model.env
    with self.assertRaises(ValueError):
        tei_model.deploy(mode=Mode.IN_PROCESS)
    tei_builder._prepare_for_mode.assert_called_with()
101
+
102
@patch(
    "sagemaker.serve.builder.tei_builder._get_nb_instance",
    return_value="ml.g5.24xlarge",
)
@patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None)
def test_tei_builder_overwritten_deploy_from_local_container_to_sagemaker_endpoint_success(
    self,
    mock_get_nb_instance,
    mock_telemetry,
):
    """Deploy to LOCAL_CONTAINER, then redeploy the same model to SAGEMAKER_ENDPOINT.

    The first half checks the local deploy: VPC config, the loading-timeout
    env var, the predictor type and the patched notebook instance type. The
    second half overrides the mode at deploy time and checks that
    ``HF_MODEL_ID`` is in the model env, that ``Mode.IN_PROCESS`` is rejected
    with ``ValueError``, and that the second recorded ``_prepare_for_mode``
    call carried the model path with ``should_upload_artifacts=True``.
    """
    # NOTE(review): stacked @patch decorators inject their mocks bottom-up, so
    # the first mock parameter receives the `_capture_telemetry` patch and the
    # second the `_get_nb_instance` patch. Neither is used in the body, so the
    # transposed names are harmless; they are left unchanged here.

    # verify LOCAL_CONTAINER deploy
    builder = ModelBuilder(
        model=MOCK_MODEL_ID,
        schema_builder=MOCK_SCHEMA_BUILDER,
        mode=Mode.LOCAL_CONTAINER,
        vpc_config=MOCK_VPC_CONFIG,
        model_metadata={
            "HF_TASK": "sentence-similarity",
        },
        model_path=MOCK_MODEL_PATH,
    )

    # Stub the mode-preparation hook so its calls can be inspected at the end.
    builder._prepare_for_mode = MagicMock()
    model = builder.build()
    builder.serve_settings.telemetry_opt_out = True
    builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()

    predictor = model.deploy(model_data_download_timeout=1800)

    self.assertEqual(model.vpc_config, MOCK_VPC_CONFIG)
    self.assertEqual(builder.env_vars["MODEL_LOADING_TIMEOUT"], "1800")
    self.assertIsInstance(predictor, TeiLocalModePredictor)
    self.assertEqual(builder.nb_instance_type, "ml.g5.24xlarge")

    # verify SAGEMAKER_ENDPOINT overwritten deploy
    builder._original_deploy = MagicMock()
    builder._prepare_for_mode.return_value = (None, {})

    model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")

    self.assertIn("HF_MODEL_ID", model.env)
    with self.assertRaises(ValueError):
        model.deploy(mode=Mode.IN_PROCESS)

    # Entries of `call_args_list` are `call` records, not mocks. Attribute
    # access on a `call` just builds another `call`, so the previous
    # `call_args_list[1].assert_called_once_with(...)` never asserted anything.
    # Unpack the record and compare its contents explicitly instead.
    # NOTE(review): index 1 is kept from the original test; confirm it is the
    # call made by the SAGEMAKER_ENDPOINT redeploy.
    redeploy_args, redeploy_kwargs = builder._prepare_for_mode.call_args_list[1]
    self.assertEqual(redeploy_args, ())
    self.assertEqual(
        redeploy_kwargs,
        {"model_path": MOCK_MODEL_PATH, "should_upload_artifacts": True},
    )
149
+
150
@patch(
    "sagemaker.serve.builder.tei_builder._get_nb_instance",
    return_value="ml.g5.24xlarge",
)
@patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None)
@patch("sagemaker.serve.builder.tei_builder._is_optimized", return_value=True)
def test_tei_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success(
    self,
    mock_is_optimized,
    mock_get_nb_instance,
    mock_telemetry,
):
    """Redeploy a local-container build to an endpoint while `_is_optimized` is True.

    With ``_is_optimized`` patched to return ``True``, the model is deployed to
    LOCAL_CONTAINER and then again with ``mode=Mode.SAGEMAKER_ENDPOINT``. The
    test passes when the most recent ``_prepare_for_mode`` call was made with
    no arguments, which the test treats as "no S3 upload occurred".
    """
    # Local-container phase.
    tei_builder = ModelBuilder(
        model=MOCK_MODEL_ID,
        schema_builder=MOCK_SCHEMA_BUILDER,
        mode=Mode.LOCAL_CONTAINER,
        vpc_config=MOCK_VPC_CONFIG,
        model_metadata={"HF_TASK": "sentence-similarity"},
        model_path=MOCK_MODEL_PATH,
    )
    prepare_stub = MagicMock(side_effect=None)
    tei_builder._prepare_for_mode = prepare_stub

    tei_model = tei_builder.build()
    tei_builder.serve_settings.telemetry_opt_out = True
    tei_builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()

    tei_model.deploy(model_data_download_timeout=1800)

    # Endpoint phase: override the mode at deploy time.
    tei_builder._original_deploy = MagicMock()
    prepare_stub.return_value = (None, {})

    tei_model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")

    # When the model is optimized, the last prepare call carries no upload arguments.
    prepare_stub.assert_called_with()
110
190
111
191
@patch (
112
192
"sagemaker.serve.builder.tei_builder._get_nb_instance" ,
113
193
return_value = "ml.g5.24xlarge" ,
114
194
)
115
195
@patch ("sagemaker.serve.builder.tei_builder._capture_telemetry" , side_effect = None )
116
- def test_image_uri_override (
196
+ def test_tei_builder_image_uri_override_success (
117
197
self ,
118
198
mock_get_nb_instance ,
119
199
mock_telemetry ,
120
200
):
121
201
builder = ModelBuilder (
122
- model = mock_model_id ,
123
- schema_builder = mock_schema_builder ,
202
+ model = MOCK_MODEL_ID ,
203
+ schema_builder = MOCK_SCHEMA_BUILDER ,
124
204
mode = Mode .LOCAL_CONTAINER ,
125
205
image_uri = MOCK_IMAGE_CONFIG ,
126
206
model_metadata = {
0 commit comments