Skip to content

Commit 1f6f876

Browse files
grenmesterJacky Lee
andauthored
feat: add quicksilver telemetry (aws#1482)
* feat: add quicksilver telemetry fields * pylint * add UTs * pylint * Refactor * add gated and fine-tuned to telemetry * fix: typo * fix: jumpstart var * refactor model_hub * pylint * update TEI/TGI to remove jumpstart field * reorder telemetry schema * refactor --------- Co-authored-by: Jacky Lee <[email protected]>
1 parent f3b3504 commit 1f6f876

File tree

8 files changed

+156
-23
lines changed

8 files changed

+156
-23
lines changed

src/sagemaker/enums.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ class Tag(str, Enum):
4646
"""Enum class for tag keys to apply to models."""
4747

4848
OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name"
49-
SPECULATIVE_DRAFT_MODL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider"
49+
SPECULATIVE_DRAFT_MODEL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider"
5050
FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path"
5151
FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name"

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,10 @@ def __init__(self):
108108
self.schema_builder = None
109109
self.nb_instance_type = None
110110
self.ram_usage_model_load = None
111-
self.jumpstart = None
111+
self.model_hub = None
112112
self.model_metadata = None
113+
self.is_fine_tuned = None
114+
self.is_gated = None
113115

114116
@abstractmethod
115117
def _prepare_for_mode(self):
@@ -580,14 +582,14 @@ def _build_for_jumpstart(self):
580582

581583
# we do not pickle for jumpstart. set to none
582584
self.secret_key = None
583-
self.jumpstart = True
584585

585586
pysdk_model = self._create_pre_trained_js_model()
586587
image_uri = pysdk_model.image_uri
587588

588589
logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
589590

590591
if self._is_fine_tuned_model():
592+
self.is_fine_tuned = True
591593
pysdk_model = self._update_model_data_for_fine_tuned_model(pysdk_model)
592594

593595
if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
@@ -754,8 +756,10 @@ def _is_gated_model(self, model=None) -> bool:
754756
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
755757

756758
if s3_uri is None:
757-
return False
758-
return "private" in s3_uri
759+
self.is_gated = False
760+
else:
761+
self.is_gated = "private" in s3_uri
762+
return self.is_gated
759763

760764
def _set_additional_model_source(
761765
self,
@@ -792,7 +796,7 @@ def _set_additional_model_source(
792796
)
793797

794798
self.pysdk_model.add_tags(
795-
{"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "sagemaker"},
799+
{"key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "value": "sagemaker"},
796800
)
797801
else:
798802
s3_uri = speculative_decoding_config.get("ModelSource")
@@ -811,7 +815,7 @@ def _set_additional_model_source(
811815

812816
self.pysdk_model.additional_model_data_sources = [additional_model_data_source]
813817
self.pysdk_model.add_tags(
814-
{"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "customer"},
818+
{"key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "value": "customer"},
815819
)
816820

817821
def _find_compatible_deployment_config(

src/sagemaker/serve/builder/model_builder.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve
7979
from sagemaker.serve.model_server.triton.triton_builder import Triton
8080
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
81-
from sagemaker.serve.utils.types import ModelServer
81+
from sagemaker.serve.utils.types import ModelServer, ModelHub
8282
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
8383
from sagemaker.serve.save_retrive.version_1_0_0.save.save_handler import SaveHandler
8484
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import get_metadata
@@ -400,7 +400,7 @@ def _prepare_for_mode(self):
400400
self.serve_settings.s3_model_data_url,
401401
self.sagemaker_session,
402402
self.image_uri,
403-
self.jumpstart if hasattr(self, "jumpstart") else False,
403+
getattr(self, "model_hub", None) == ModelHub.JUMPSTART,
404404
)
405405
self.env_vars.update(env_vars_sagemaker)
406406
return self.s3_upload_path, env_vars_sagemaker
@@ -754,10 +754,14 @@ def build( # pylint: disable=R0911
754754

755755
if isinstance(self.model, str):
756756
model_task = None
757-
if self.model_metadata:
758-
model_task = self.model_metadata.get("HF_TASK")
759757
if self._is_jumpstart_model_id():
758+
self.model_hub = ModelHub.JUMPSTART
760759
return self._build_for_jumpstart()
760+
self.model_hub = ModelHub.HUGGINGFACE
761+
762+
if self.model_metadata:
763+
model_task = self.model_metadata.get("HF_TASK")
764+
761765
if self._is_djl():
762766
return self._build_for_djl()
763767
else:

src/sagemaker/serve/builder/tei_builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def __init__(self):
6363
self.nb_instance_type = None
6464
self.ram_usage_model_load = None
6565
self.secret_key = None
66-
self.jumpstart = None
6766
self.role_arn = None
6867

6968
@abstractmethod

src/sagemaker/serve/builder/tgi_builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ def __init__(self):
9090
self.nb_instance_type = None
9191
self.ram_usage_model_load = None
9292
self.secret_key = None
93-
self.jumpstart = None
9493
self.role_arn = None
9594

9695
@abstractmethod

src/sagemaker/serve/utils/telemetry_logger.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@
2929
MLFLOW_REGISTRY_PATH,
3030
)
3131
from sagemaker.serve.utils.lineage_utils import _get_mlflow_model_path_type
32-
from sagemaker.serve.utils.types import ModelServer, ImageUriOption
32+
from sagemaker.serve.utils.types import (
33+
ModelServer,
34+
ImageUriOption,
35+
ModelHub,
36+
SpeculativeDecodingDraftModelSource,
37+
)
3338
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
3439
from sagemaker.user_agent import SDK_VERSION
3540

@@ -69,6 +74,16 @@
6974
MLFLOW_REGISTRY_PATH: 5,
7075
}
7176

77+
MODEL_HUB_TO_CODE = {
78+
str(ModelHub.JUMPSTART): 1,
79+
str(ModelHub.HUGGINGFACE): 2,
80+
}
81+
82+
SD_DRAFT_MODEL_SOURCE_TO_CODE = {
83+
str(SpeculativeDecodingDraftModelSource.SAGEMAKER): 1,
84+
str(SpeculativeDecodingDraftModelSource.CUSTOM): 2,
85+
}
86+
7287

7388
def _capture_telemetry(func_name: str):
7489
"""Placeholder docstring"""
@@ -108,6 +123,28 @@ def wrapper(self, *args, **kwargs):
108123
mlflow_model_path_type = _get_mlflow_model_path_type(mlflow_model_path)
109124
extra += f"&x-mlflowModelPathType={MLFLOW_MODEL_PATH_CODE[mlflow_model_path_type]}"
110125

126+
if getattr(self, "model_hub", False):
127+
extra += f"&x-modelHub={MODEL_HUB_TO_CODE[str(self.model_hub)]}"
128+
129+
if getattr(self, "is_fine_tuned", False):
130+
extra += "&x-fineTuned=1"
131+
if getattr(self, "is_gated", False):
132+
extra += "&x-gated=1"
133+
134+
if kwargs.get("compilation_config"):
135+
extra += "&x-compiled=1"
136+
if kwargs.get("quantization_config"):
137+
extra += "&x-quantized=1"
138+
if kwargs.get("speculative_decoding_config"):
139+
model_provider = kwargs["speculative_decoding_config"]["ModelProvider"]
140+
model_provider_enum = (
141+
SpeculativeDecodingDraftModelSource.SAGEMAKER
142+
if model_provider.lower() == "sagemaker"
143+
else SpeculativeDecodingDraftModelSource.CUSTOM
144+
)
145+
model_provider_value = SD_DRAFT_MODEL_SOURCE_TO_CODE[str(model_provider_enum)]
146+
extra += f"&x-sdDraftModelSource={model_provider_value}"
147+
111148
start_timer = perf_counter()
112149
try:
113150
response = func(self, *args, **kwargs)

src/sagemaker/serve/utils/types.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,25 @@ def __str__(self) -> str:
5757
CUSTOM_IMAGE = 1
5858
CUSTOM_1P_IMAGE = 2
5959
DEFAULT_IMAGE = 3
60+
61+
62+
class ModelHub(Enum):
63+
"""Enum type for model hub source"""
64+
65+
def __str__(self) -> str:
66+
"""Convert enum to string"""
67+
return str(self.name)
68+
69+
JUMPSTART = 1
70+
HUGGINGFACE = 2
71+
72+
73+
class SpeculativeDecodingDraftModelSource(Enum):
74+
"""Enum type for speculative decoding draft model source"""
75+
76+
def __str__(self) -> str:
77+
"""Convert enum to string"""
78+
return str(self.name)
79+
80+
SAGEMAKER = 1
81+
CUSTOM = 2

tests/unit/sagemaker/serve/utils/test_telemetry_logger.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414
import unittest
15-
from unittest.mock import Mock, patch
15+
from unittest.mock import Mock, patch, MagicMock
1616
from sagemaker.serve import Mode, ModelServer
1717
from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH
1818
from sagemaker.serve.utils.telemetry_logger import (
@@ -25,7 +25,8 @@
2525
from sagemaker.user_agent import SDK_VERSION
2626

2727
MOCK_SESSION = Mock()
28-
MOCK_FUNC_NAME = "Mock.deploy"
28+
MOCK_DEPLOY_FUNC_NAME = "Mock.deploy"
29+
MOCK_OPTIMIZE_FUNC_NAME = "Mock.optimize"
2930
MOCK_DJL_CONTAINER = (
3031
"763104351884.dkr.ecr.us-west-2.amazonaws.com/" "djl-inference:0.25.0-deepspeed0.11.0-cu118"
3132
)
@@ -47,11 +48,15 @@ def __init__(self):
4748
self.serve_settings = Mock()
4849
self.sagemaker_session = MOCK_SESSION
4950

50-
@_capture_telemetry(MOCK_FUNC_NAME)
51+
@_capture_telemetry(MOCK_DEPLOY_FUNC_NAME)
5152
def mock_deploy(self, mock_exception_func=None):
5253
if mock_exception_func:
5354
mock_exception_func()
5455

56+
@_capture_telemetry(MOCK_OPTIMIZE_FUNC_NAME)
57+
def mock_optimize(self, *args, **kwargs):
58+
pass
59+
5560

5661
class TestTelemetryLogger(unittest.TestCase):
5762
@patch("sagemaker.serve.utils.telemetry_logger._requests_helper")
@@ -88,7 +93,7 @@ def test_capture_telemetry_decorator_djl_success(self, mock_send_telemetry):
8893
args = mock_send_telemetry.call_args.args
8994
latency = str(args[5]).split("latency=")[1]
9095
expected_extra_str = (
91-
f"{MOCK_FUNC_NAME}"
96+
f"{MOCK_DEPLOY_FUNC_NAME}"
9297
"&x-modelServer=4"
9398
"&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
9499
f"&x-sdkVersion={SDK_VERSION}"
@@ -118,7 +123,7 @@ def test_capture_telemetry_decorator_djl_success_with_custom_image(self, mock_se
118123
args = mock_send_telemetry.call_args.args
119124
latency = str(args[5]).split("latency=")[1]
120125
expected_extra_str = (
121-
f"{MOCK_FUNC_NAME}"
126+
f"{MOCK_DEPLOY_FUNC_NAME}"
122127
"&x-modelServer=4"
123128
"&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
124129
f"&x-sdkVersion={SDK_VERSION}"
@@ -148,7 +153,7 @@ def test_capture_telemetry_decorator_tgi_success(self, mock_send_telemetry):
148153
args = mock_send_telemetry.call_args.args
149154
latency = str(args[5]).split("latency=")[1]
150155
expected_extra_str = (
151-
f"{MOCK_FUNC_NAME}"
156+
f"{MOCK_DEPLOY_FUNC_NAME}"
152157
"&x-modelServer=6"
153158
"&x-imageTag=huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04"
154159
f"&x-sdkVersion={SDK_VERSION}"
@@ -196,7 +201,7 @@ def test_capture_telemetry_decorator_handle_exception_success(self, mock_send_te
196201
args = mock_send_telemetry.call_args.args
197202
latency = str(args[5]).split("latency=")[1]
198203
expected_extra_str = (
199-
f"{MOCK_FUNC_NAME}"
204+
f"{MOCK_DEPLOY_FUNC_NAME}"
200205
"&x-modelServer=4"
201206
"&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
202207
f"&x-sdkVersion={SDK_VERSION}"
@@ -243,7 +248,7 @@ def test_construct_url_with_failure_reason_and_extra_info(self):
243248
f"&x-failureType={mock_failure_type}"
244249
f"&x-extra={mock_extra_info}"
245250
)
246-
self.assertEquals(ret_url, expected_base_url)
251+
self.assertEqual(ret_url, expected_base_url)
247252

248253
@patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
249254
def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry):
@@ -262,7 +267,7 @@ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry):
262267
args = mock_send_telemetry.call_args.args
263268
latency = str(args[5]).split("latency=")[1]
264269
expected_extra_str = (
265-
f"{MOCK_FUNC_NAME}"
270+
f"{MOCK_DEPLOY_FUNC_NAME}"
266271
"&x-modelServer=1"
267272
"&x-imageTag=pytorch-inference:2.0.1-cpu-py310"
268273
f"&x-sdkVersion={SDK_VERSION}"
@@ -275,3 +280,66 @@ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry):
275280
mock_send_telemetry.assert_called_once_with(
276281
"1", 3, MOCK_SESSION, None, None, expected_extra_str
277282
)
283+
284+
@patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
285+
def test_capture_telemetry_decorator_optimize_with_default_configs(self, mock_send_telemetry):
286+
mock_model_builder = ModelBuilderMock()
287+
mock_model_builder.serve_settings.telemetry_opt_out = False
288+
mock_model_builder.image_uri = None
289+
mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT
290+
mock_model_builder.model_server = ModelServer.TORCHSERVE
291+
mock_model_builder.sagemaker_session.endpoint_arn = None
292+
293+
mock_model_builder.mock_optimize()
294+
295+
args = mock_send_telemetry.call_args.args
296+
latency = str(args[5]).split("latency=")[1]
297+
expected_extra_str = (
298+
f"{MOCK_OPTIMIZE_FUNC_NAME}"
299+
"&x-modelServer=1"
300+
f"&x-sdkVersion={SDK_VERSION}"
301+
f"&x-latency={latency}"
302+
)
303+
304+
mock_send_telemetry.assert_called_once_with(
305+
"1", 3, MOCK_SESSION, None, None, expected_extra_str
306+
)
307+
308+
@patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
309+
def test_capture_telemetry_decorator_optimize_with_custom_configs(self, mock_send_telemetry):
310+
mock_model_builder = ModelBuilderMock()
311+
mock_model_builder.serve_settings.telemetry_opt_out = False
312+
mock_model_builder.image_uri = None
313+
mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT
314+
mock_model_builder.model_server = ModelServer.TORCHSERVE
315+
mock_model_builder.sagemaker_session.endpoint_arn = None
316+
mock_model_builder.is_fine_tuned = True
317+
mock_model_builder.is_gated = True
318+
319+
mock_speculative_decoding_config = MagicMock()
320+
mock_config = {"ModelProvider": "sagemaker"}
321+
mock_speculative_decoding_config.__getitem__.side_effect = mock_config.__getitem__
322+
323+
mock_model_builder.mock_optimize(
324+
quantization_config=Mock(),
325+
compilation_config=Mock(),
326+
speculative_decoding_config=mock_speculative_decoding_config,
327+
)
328+
329+
args = mock_send_telemetry.call_args.args
330+
latency = str(args[5]).split("latency=")[1]
331+
expected_extra_str = (
332+
f"{MOCK_OPTIMIZE_FUNC_NAME}"
333+
"&x-modelServer=1"
334+
f"&x-sdkVersion={SDK_VERSION}"
335+
f"&x-fineTuned=1"
336+
f"&x-gated=1"
337+
f"&x-compiled=1"
338+
f"&x-quantized=1"
339+
f"&x-sdDraftModelSource=1"
340+
f"&x-latency={latency}"
341+
)
342+
343+
mock_send_telemetry.assert_called_once_with(
344+
"1", 3, MOCK_SESSION, None, None, expected_extra_str
345+
)

0 commit comments

Comments
 (0)