12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
14
import unittest
15
- from unittest .mock import Mock , patch
15
+ from unittest .mock import Mock , patch , MagicMock
16
16
from sagemaker .serve import Mode , ModelServer
17
17
from sagemaker .serve .model_format .mlflow .constants import MLFLOW_MODEL_PATH
18
18
from sagemaker .serve .utils .telemetry_logger import (
25
25
from sagemaker .user_agent import SDK_VERSION
26
26
27
27
MOCK_SESSION = Mock ()
28
- MOCK_FUNC_NAME = "Mock.deploy"
28
+ MOCK_DEPLOY_FUNC_NAME = "Mock.deploy"
29
+ MOCK_OPTIMIZE_FUNC_NAME = "Mock.optimize"
29
30
MOCK_DJL_CONTAINER = (
30
31
"763104351884.dkr.ecr.us-west-2.amazonaws.com/" "djl-inference:0.25.0-deepspeed0.11.0-cu118"
31
32
)
@@ -47,11 +48,15 @@ def __init__(self):
47
48
self .serve_settings = Mock ()
48
49
self .sagemaker_session = MOCK_SESSION
49
50
50
- @_capture_telemetry (MOCK_FUNC_NAME )
51
@_capture_telemetry(MOCK_DEPLOY_FUNC_NAME)
def mock_deploy(self, mock_exception_func=None):
    """Stand-in for a deploy call, wrapped by the telemetry decorator.

    Args:
        mock_exception_func: Optional callable. When truthy it is invoked
            with no arguments and its return value is discarded; anything
            it raises propagates to the decorator.
    """
    if not mock_exception_func:
        return
    mock_exception_func()
54
55
56
@_capture_telemetry(MOCK_OPTIMIZE_FUNC_NAME)
def mock_optimize(self, *args, **kwargs):
    """No-op stand-in for an optimize call, wrapped by the telemetry decorator.

    Accepts any positional and keyword arguments, ignores them, and
    returns None.
    """
    return None
59
+
55
60
56
61
class TestTelemetryLogger (unittest .TestCase ):
57
62
@patch ("sagemaker.serve.utils.telemetry_logger._requests_helper" )
@@ -88,7 +93,7 @@ def test_capture_telemetry_decorator_djl_success(self, mock_send_telemetry):
88
93
args = mock_send_telemetry .call_args .args
89
94
latency = str (args [5 ]).split ("latency=" )[1 ]
90
95
expected_extra_str = (
91
- f"{ MOCK_FUNC_NAME } "
96
+ f"{ MOCK_DEPLOY_FUNC_NAME } "
92
97
"&x-modelServer=4"
93
98
"&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
94
99
f"&x-sdkVersion={ SDK_VERSION } "
@@ -118,7 +123,7 @@ def test_capture_telemetry_decorator_djl_success_with_custom_image(self, mock_se
118
123
args = mock_send_telemetry .call_args .args
119
124
latency = str (args [5 ]).split ("latency=" )[1 ]
120
125
expected_extra_str = (
121
- f"{ MOCK_FUNC_NAME } "
126
+ f"{ MOCK_DEPLOY_FUNC_NAME } "
122
127
"&x-modelServer=4"
123
128
"&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
124
129
f"&x-sdkVersion={ SDK_VERSION } "
@@ -148,7 +153,7 @@ def test_capture_telemetry_decorator_tgi_success(self, mock_send_telemetry):
148
153
args = mock_send_telemetry .call_args .args
149
154
latency = str (args [5 ]).split ("latency=" )[1 ]
150
155
expected_extra_str = (
151
- f"{ MOCK_FUNC_NAME } "
156
+ f"{ MOCK_DEPLOY_FUNC_NAME } "
152
157
"&x-modelServer=6"
153
158
"&x-imageTag=huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04"
154
159
f"&x-sdkVersion={ SDK_VERSION } "
@@ -196,7 +201,7 @@ def test_capture_telemetry_decorator_handle_exception_success(self, mock_send_te
196
201
args = mock_send_telemetry .call_args .args
197
202
latency = str (args [5 ]).split ("latency=" )[1 ]
198
203
expected_extra_str = (
199
- f"{ MOCK_FUNC_NAME } "
204
+ f"{ MOCK_DEPLOY_FUNC_NAME } "
200
205
"&x-modelServer=4"
201
206
"&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
202
207
f"&x-sdkVersion={ SDK_VERSION } "
@@ -243,7 +248,7 @@ def test_construct_url_with_failure_reason_and_extra_info(self):
243
248
f"&x-failureType={ mock_failure_type } "
244
249
f"&x-extra={ mock_extra_info } "
245
250
)
246
- self .assertEquals (ret_url , expected_base_url )
251
+ self .assertEqual (ret_url , expected_base_url )
247
252
248
253
@patch ("sagemaker.serve.utils.telemetry_logger._send_telemetry" )
249
254
def test_capture_telemetry_decorator_mlflow_success (self , mock_send_telemetry ):
@@ -262,7 +267,7 @@ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry):
262
267
args = mock_send_telemetry .call_args .args
263
268
latency = str (args [5 ]).split ("latency=" )[1 ]
264
269
expected_extra_str = (
265
- f"{ MOCK_FUNC_NAME } "
270
+ f"{ MOCK_DEPLOY_FUNC_NAME } "
266
271
"&x-modelServer=1"
267
272
"&x-imageTag=pytorch-inference:2.0.1-cpu-py310"
268
273
f"&x-sdkVersion={ SDK_VERSION } "
@@ -275,3 +280,66 @@ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry):
275
280
mock_send_telemetry .assert_called_once_with (
276
281
"1" , 3 , MOCK_SESSION , None , None , expected_extra_str
277
282
)
283
+
284
@patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
def test_capture_telemetry_decorator_optimize_with_default_configs(self, mock_send_telemetry):
    """Check the telemetry sent when optimize is called with no arguments.

    Args:
        mock_send_telemetry: Patched ``_send_telemetry`` whose single call
            is inspected.
    """
    mock_model_builder = ModelBuilderMock()
    mock_model_builder.serve_settings.telemetry_opt_out = False
    mock_model_builder.image_uri = None
    mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT
    mock_model_builder.model_server = ModelServer.TORCHSERVE
    mock_model_builder.sagemaker_session.endpoint_arn = None

    mock_model_builder.mock_optimize()

    # The latency value is read back from the recorded call so the expected
    # string can be compared in full.
    args = mock_send_telemetry.call_args.args
    latency = str(args[5]).split("latency=")[1]
    # No whitespace may follow an interpolation: the adjacent literals are
    # concatenated and a trailing blank would make the comparison unsatisfiable.
    expected_extra_str = (
        f"{MOCK_OPTIMIZE_FUNC_NAME}"
        "&x-modelServer=1"
        f"&x-sdkVersion={SDK_VERSION}"
        f"&x-latency={latency}"
    )

    mock_send_telemetry.assert_called_once_with(
        "1", 3, MOCK_SESSION, None, None, expected_extra_str
    )
307
+
308
@patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
def test_capture_telemetry_decorator_optimize_with_custom_configs(self, mock_send_telemetry):
    """Check the telemetry sent when optimize is called with custom configs.

    The builder is flagged as fine-tuned and gated, and optimize receives
    quantization, compilation and speculative-decoding configs.

    Args:
        mock_send_telemetry: Patched ``_send_telemetry`` whose single call
            is inspected.
    """
    mock_model_builder = ModelBuilderMock()
    mock_model_builder.serve_settings.telemetry_opt_out = False
    mock_model_builder.image_uri = None
    mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT
    mock_model_builder.model_server = ModelServer.TORCHSERVE
    mock_model_builder.sagemaker_session.endpoint_arn = None
    mock_model_builder.is_fine_tuned = True
    mock_model_builder.is_gated = True

    # Route subscript access on the MagicMock to a real dict, so
    # config["ModelProvider"] yields "sagemaker".
    mock_speculative_decoding_config = MagicMock()
    mock_config = {"ModelProvider": "sagemaker"}
    mock_speculative_decoding_config.__getitem__.side_effect = mock_config.__getitem__

    mock_model_builder.mock_optimize(
        quantization_config=Mock(),
        compilation_config=Mock(),
        speculative_decoding_config=mock_speculative_decoding_config,
    )

    # The latency value is read back from the recorded call so the expected
    # string can be compared in full.
    args = mock_send_telemetry.call_args.args
    latency = str(args[5]).split("latency=")[1]
    # Only fragments that interpolate a value are f-strings; no whitespace may
    # follow an interpolation because the adjacent literals are concatenated.
    expected_extra_str = (
        f"{MOCK_OPTIMIZE_FUNC_NAME}"
        "&x-modelServer=1"
        f"&x-sdkVersion={SDK_VERSION}"
        "&x-fineTuned=1"
        "&x-gated=1"
        "&x-compiled=1"
        "&x-quantized=1"
        "&x-sdDraftModelSource=1"
        f"&x-latency={latency}"
    )

    mock_send_telemetry.assert_called_once_with(
        "1", 3, MOCK_SESSION, None, None, expected_extra_str
    )
0 commit comments