Skip to content

Commit 23fad5f

Browse files
authored
change: Enable telemetry logging for Remote function (aws#4729)
* change: Enhance telemetry logging module and feature coverage
* Fix default session issue
* fix unit-tests
1 parent 237456b commit 23fad5f

File tree

5 files changed

+127
-79
lines changed

5 files changed

+127
-79
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
from sagemaker.utils import name_from_base, base_from_name
4141
from sagemaker.remote_function.spark_config import SparkConfig
4242
from sagemaker.remote_function.custom_file_filter import CustomFileFilter
43+
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
44+
from sagemaker.telemetry.constants import Feature
4345

4446
_API_CALL_LIMIT = {
4547
"SubmittingIntervalInSecs": 1,
@@ -57,6 +59,7 @@
5759
logger = logging_config.get_logger()
5860

5961

62+
@_telemetry_emitter(feature=Feature.REMOTE_FUNCTION, func_name="remote_function.remote")
6063
def remote(
6164
_func=None,
6265
*,

src/sagemaker/telemetry/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Feature(Enum):
2424

2525
SDK_DEFAULTS = 1
2626
LOCAL_MODE = 2
27+
REMOTE_FUNCTION = 3
2728

2829
def __str__(self): # pylint: disable=E0307
2930
"""Return the feature name."""

src/sagemaker/telemetry/telemetry_logging.py

Lines changed: 102 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
import sys
1818
from time import perf_counter
1919
from typing import List
20+
import functools
21+
import requests
2022

23+
import boto3
24+
from sagemaker.session import Session
2125
from sagemaker.utils import resolve_value_from_config
2226
from sagemaker.config.config_schema import TELEMETRY_OPT_OUT_PATH
2327
from sagemaker.telemetry.constants import (
@@ -47,6 +51,7 @@
4751
FEATURE_TO_CODE = {
4852
str(Feature.SDK_DEFAULTS): 1,
4953
str(Feature.LOCAL_MODE): 2,
54+
str(Feature.REMOTE_FUNCTION): 3,
5055
}
5156

5257
STATUS_TO_CODE = {
@@ -59,86 +64,103 @@ def _telemetry_emitter(feature: str, func_name: str):
5964
"""Decorator to emit telemetry logs for SageMaker Python SDK functions"""
6065

6166
def decorator(func):
62-
def wrapper(self, *args, **kwargs):
63-
logger.info(TELEMETRY_OPT_OUT_MESSAGING)
64-
response = None
65-
caught_ex = None
66-
studio_app_type = process_studio_metadata_file()
67-
68-
# Check if telemetry is opted out
69-
telemetry_opt_out_flag = resolve_value_from_config(
70-
direct_input=None,
71-
config_path=TELEMETRY_OPT_OUT_PATH,
72-
default_value=False,
73-
sagemaker_session=self.sagemaker_session,
74-
)
75-
logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag)
76-
77-
# Construct the feature list to track feature combinations
78-
feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]
79-
if self.sagemaker_session:
80-
if self.sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
67+
@functools.wraps(func)
68+
def wrapper(*args, **kwargs):
69+
sagemaker_session = None
70+
if len(args) > 0 and hasattr(args[0], "sagemaker_session"):
71+
# Get the sagemaker_session from the instance method args
72+
sagemaker_session = args[0].sagemaker_session
73+
elif feature == Feature.REMOTE_FUNCTION:
74+
# Get the sagemaker_session from the function keyword arguments for remote function
75+
sagemaker_session = kwargs.get(
76+
"sagemaker_session", _get_default_sagemaker_session()
77+
)
78+
79+
if sagemaker_session:
80+
logger.debug("sagemaker_session found, preparing to emit telemetry...")
81+
logger.info(TELEMETRY_OPT_OUT_MESSAGING)
82+
response = None
83+
caught_ex = None
84+
studio_app_type = process_studio_metadata_file()
85+
86+
# Check if telemetry is opted out
87+
telemetry_opt_out_flag = resolve_value_from_config(
88+
direct_input=None,
89+
config_path=TELEMETRY_OPT_OUT_PATH,
90+
default_value=False,
91+
sagemaker_session=sagemaker_session,
92+
)
93+
logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag)
94+
95+
# Construct the feature list to track feature combinations
96+
feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]
97+
98+
if sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
8199
feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)])
82100

83-
if self.sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
101+
if sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
84102
feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)])
85103

86-
# Construct the extra info to track platform and environment usage metadata
87-
extra = (
88-
f"{func_name}"
89-
f"&x-sdkVersion={SDK_VERSION}"
90-
f"&x-env={PYTHON_VERSION}"
91-
f"&x-sys={OS_NAME_VERSION}"
92-
f"&x-platform={studio_app_type}"
93-
)
94-
95-
# Add endpoint ARN to the extra info if available
96-
if self.sagemaker_session and self.sagemaker_session.endpoint_arn:
97-
extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}"
98-
99-
start_timer = perf_counter()
100-
try:
101-
# Call the original function
102-
response = func(self, *args, **kwargs)
103-
stop_timer = perf_counter()
104-
elapsed = stop_timer - start_timer
105-
extra += f"&x-latency={round(elapsed, 2)}"
106-
if not telemetry_opt_out_flag:
107-
_send_telemetry_request(
108-
STATUS_TO_CODE[str(Status.SUCCESS)],
109-
feature_list,
110-
self.sagemaker_session,
111-
None,
112-
None,
113-
extra,
114-
)
115-
except Exception as e: # pylint: disable=W0703
116-
stop_timer = perf_counter()
117-
elapsed = stop_timer - start_timer
118-
extra += f"&x-latency={round(elapsed, 2)}"
119-
if not telemetry_opt_out_flag:
120-
_send_telemetry_request(
121-
STATUS_TO_CODE[str(Status.FAILURE)],
122-
feature_list,
123-
self.sagemaker_session,
124-
str(e),
125-
e.__class__.__name__,
126-
extra,
127-
)
128-
caught_ex = e
129-
finally:
130-
if caught_ex:
131-
raise caught_ex
132-
return response # pylint: disable=W0150
104+
# Construct the extra info to track platform and environment usage metadata
105+
extra = (
106+
f"{func_name}"
107+
f"&x-sdkVersion={SDK_VERSION}"
108+
f"&x-env={PYTHON_VERSION}"
109+
f"&x-sys={OS_NAME_VERSION}"
110+
f"&x-platform={studio_app_type}"
111+
)
112+
113+
# Add endpoint ARN to the extra info if available
114+
if sagemaker_session.endpoint_arn:
115+
extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"
116+
117+
start_timer = perf_counter()
118+
try:
119+
# Call the original function
120+
response = func(*args, **kwargs)
121+
stop_timer = perf_counter()
122+
elapsed = stop_timer - start_timer
123+
extra += f"&x-latency={round(elapsed, 2)}"
124+
if not telemetry_opt_out_flag:
125+
_send_telemetry_request(
126+
STATUS_TO_CODE[str(Status.SUCCESS)],
127+
feature_list,
128+
sagemaker_session,
129+
None,
130+
None,
131+
extra,
132+
)
133+
except Exception as e: # pylint: disable=W0703
134+
stop_timer = perf_counter()
135+
elapsed = stop_timer - start_timer
136+
extra += f"&x-latency={round(elapsed, 2)}"
137+
if not telemetry_opt_out_flag:
138+
_send_telemetry_request(
139+
STATUS_TO_CODE[str(Status.FAILURE)],
140+
feature_list,
141+
sagemaker_session,
142+
str(e),
143+
e.__class__.__name__,
144+
extra,
145+
)
146+
caught_ex = e
147+
finally:
148+
if caught_ex:
149+
raise caught_ex
150+
return response # pylint: disable=W0150
151+
else:
152+
logger.debug(
153+
"Unable to send telemetry for function %s. "
154+
"sagemaker_session is not provided or not valid.",
155+
func_name,
156+
)
157+
return func(*args, **kwargs)
133158

134159
return wrapper
135160

136161
return decorator
137162

138163

139-
from sagemaker.session import Session # noqa: E402 pylint: disable=C0413
140-
141-
142164
def _send_telemetry_request(
143165
status: int,
144166
feature_list: List[int],
@@ -165,9 +187,9 @@ def _send_telemetry_request(
165187
# Send the telemetry request
166188
logger.debug("Sending telemetry request to [%s]", url)
167189
_requests_helper(url, 2)
168-
logger.debug("SageMaker Python SDK telemetry successfully emitted!")
190+
logger.debug("SageMaker Python SDK telemetry successfully emitted.")
169191
except Exception: # pylint: disable=W0703
170-
logger.debug("SageMaker Python SDK telemetry not emitted!!")
192+
logger.debug("SageMaker Python SDK telemetry not emitted!")
171193

172194

173195
def _construct_url(
@@ -196,9 +218,6 @@ def _construct_url(
196218
return base_url
197219

198220

199-
import requests # noqa: E402 pylint: disable=C0413,C0411
200-
201-
202221
def _requests_helper(url, timeout):
203222
"""Make a GET request to the given URL"""
204223

@@ -227,3 +246,11 @@ def _get_region_or_default(session):
227246
return session.boto_session.region_name
228247
except Exception: # pylint: disable=W0703
229248
return DEFAULT_AWS_REGION
249+
250+
251+
def _get_default_sagemaker_session():
252+
"""Return the default sagemaker session"""
253+
boto_session = boto3.Session(region_name=DEFAULT_AWS_REGION)
254+
sagemaker_session = Session(boto_session=boto_session)
255+
256+
return sagemaker_session

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import threading
1717
import time
18+
import inspect
1819

1920
import pytest
2021
from mock import MagicMock, patch, Mock, ANY, call
@@ -1498,7 +1499,6 @@ def test_consistency_between_remote_and_step_decorator():
14981499
from sagemaker.workflow.function_step import step
14991500

15001501
remote_args_to_ignore = [
1501-
"_remote",
15021502
"include_local_workdir",
15031503
"custom_file_filter",
15041504
"s3_kms_key",
@@ -1508,7 +1508,7 @@ def test_consistency_between_remote_and_step_decorator():
15081508

15091509
step_args_to_ignore = ["_step", "name", "display_name", "description", "retry_policies"]
15101510

1511-
remote_decorator_args = remote.__code__.co_varnames
1511+
remote_decorator_args = inspect.signature(remote).parameters.keys()
15121512
common_remote_decorator_args = set(remote_args_to_ignore) ^ set(remote_decorator_args)
15131513

15141514
step_decorator_args = step.__code__.co_varnames
@@ -1522,8 +1522,7 @@ def test_consistency_between_remote_and_executor():
15221522
executor_arg_list.remove("self")
15231523
executor_arg_list.remove("max_parallel_jobs")
15241524

1525-
remote_args_list = list(remote.__code__.co_varnames)
1526-
remote_args_list.remove("_remote")
1525+
remote_args_list = list(inspect.signature(remote).parameters.keys())
15271526
remote_args_list.remove("_func")
15281527

15291528
assert executor_arg_list == remote_args_list

tests/unit/sagemaker/telemetry/test_telemetry_logging.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616
import requests
1717
from unittest.mock import Mock, patch, MagicMock
18+
import boto3
1819
import sagemaker
1920
from sagemaker.telemetry.constants import Feature
2021
from sagemaker.telemetry.telemetry_logging import (
@@ -24,6 +25,7 @@
2425
_get_accountId,
2526
_requests_helper,
2627
_get_region_or_default,
28+
_get_default_sagemaker_session,
2729
OS_NAME_VERSION,
2830
PYTHON_VERSION,
2931
)
@@ -282,3 +284,19 @@ def test_get_region_or_default_exception(self):
282284
region = _get_region_or_default(mock_session)
283285
assert region == "us-west-2"
284286
assert "Error creating boto session" in str(exception)
287+
288+
@patch.object(boto3.Session, "region_name", "us-west-2")
289+
def test_get_default_sagemaker_session(self):
290+
sagemaker_session = _get_default_sagemaker_session()
291+
292+
assert isinstance(sagemaker_session, sagemaker.Session) is True
293+
assert sagemaker_session.boto_session.region_name == "us-west-2"
294+
295+
@patch.object(boto3.Session, "region_name", None)
296+
def test_get_default_sagemaker_session_with_no_region(self):
297+
with self.assertRaises(ValueError) as context:
298+
_get_default_sagemaker_session()
299+
300+
assert "Must setup local AWS configuration with a region supported by SageMaker." in str(
301+
context.exception
302+
)

0 commit comments

Comments
 (0)