Skip to content

Commit 24c21bd

Browse files
author
Uemit Yoldas
committed
feature: integrate amtviz for visualization of tuning jobs
1 parent 18897d7 commit 24c21bd

File tree

7 files changed

+1473
-0
lines changed

7 files changed

+1473
-0
lines changed

src/sagemaker/amtviz/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: MIT-0
3+
4+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this
5+
# software and associated documentation files (the "Software"), to deal in the Software
6+
# without restriction, including without limitation the rights to use, copy, modify,
7+
# merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
8+
# permit persons to whom the Software is furnished to do so.
9+
10+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
11+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
12+
# PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
13+
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
14+
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
15+
16+
from sagemaker.amtviz.visualization import visualize_tuning_job
17+
__all__ = ['visualize_tuning_job']

src/sagemaker/amtviz/job_metrics.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: MIT-0
3+
4+
# Permission is hereby granted, free of charge, to any person obtaining a copy
5+
# of this software and associated documentation files (the "Software"), to deal
6+
# in the Software without restriction, including without limitation the rights
7+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
# copies of the Software, and to permit persons to whom the Software is
9+
# furnished to do so.
10+
11+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
13+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
14+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
15+
# IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16+
17+
from datetime import datetime, timedelta
18+
from typing import Callable, List, Optional, Tuple, Dict, Any
19+
import hashlib
20+
import os
21+
from pathlib import Path
22+
23+
import pandas as pd
24+
import numpy as np
25+
import boto3
26+
import logging
27+
28+
logger = logging.getLogger(__name__)
29+
30+
cw = boto3.client("cloudwatch")
31+
sm = boto3.client("sagemaker")
32+
33+
34+
def disk_cache(outer: Callable) -> Callable:
    """A decorator that implements disk-based caching for CloudWatch metrics data.

    This decorator caches the output of the wrapped function to disk in JSON Lines
    format. It creates a cache key using an MD5 hash of the function arguments and
    stores the data in the user's home directory under .amtviz/cw_metrics_cache/.

    Args:
        outer (Callable): The function to be wrapped. Must return a pandas DataFrame
            containing CloudWatch metrics data.

    Returns:
        Callable: A wrapper function that implements the caching logic.
    """

    def inner(*args: Any, **kwargs: Any) -> pd.DataFrame:
        key_input = str(args) + str(kwargs)
        # nosec b303 - Not used for cryptography, but to create lookup key
        key = hashlib.md5(key_input.encode("utf-8")).hexdigest()
        cache_dir = Path.home().joinpath(".amtviz/cw_metrics_cache")
        # pandas infers gzip compression from the .gz suffix on read and write.
        fn = f"{cache_dir}/req_{key}.jsonl.gz"
        if Path(fn).exists():
            try:
                df = pd.read_json(fn, lines=True)
                # Cache hit. (Fixed: logging methods do not accept print()'s
                # `end=` keyword; passing it raised TypeError.)
                logger.debug("H")
                # Timestamps round-trip through JSON as tz-aware; normalize to naive.
                df["ts"] = pd.to_datetime(df["ts"])
                df["ts"] = df["ts"].dt.tz_localize(None)
                df["rel_ts"] = pd.to_datetime(df["rel_ts"])  # pyright: ignore [reportIndexIssue, reportOptionalSubscript]
                df["rel_ts"] = df["rel_ts"].dt.tz_localize(None)
                return df
            except KeyError:
                # Empty file leads to empty df, hence no df['ts'] possible
                pass
            # nosec b110 - doesn't matter why we could not load it.
            except Exception as e:  # pylint: disable=broad-except
                # Lazy %-formatting; the original passed extra positional args
                # without placeholders, which breaks log record formatting.
                logger.error("Could not load cache file %s: %s %s", fn, type(e), e)
                # continue with calling the outer function

        # Cache miss: compute, then persist for next time.
        logger.debug("M")
        df = outer(*args, **kwargs)
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if not isinstance(df, pd.DataFrame):
            raise TypeError("Only caching Pandas DataFrames.")

        os.makedirs(cache_dir, exist_ok=True)
        df.to_json(fn, orient="records", date_format="iso", lines=True)

        return df

    return inner
82+
83+
84+
def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> Dict[str, Any]:
85+
return {
86+
"Id": metric_name.lower().replace(":", "_").replace("-", "_"),
87+
"MetricStat": {
88+
"Stat": "Average",
89+
"Metric": {
90+
"Namespace": "/aws/sagemaker/TrainingJobs",
91+
"MetricName": metric_name,
92+
"Dimensions": [
93+
{"Name": dim_name, "Value": dim_value},
94+
],
95+
},
96+
"Period": 60,
97+
},
98+
"ReturnData": True,
99+
}
100+
101+
102+
def _get_metric_data(
    queries: List[Dict[str, Any]],
    start_time: datetime,
    end_time: datetime
) -> pd.DataFrame:
    """Fetch CloudWatch metric data for the given queries as a single DataFrame.

    The requested window is widened by one hour on each side so that datapoints
    emitted slightly outside the nominal job runtime are still captured.

    Args:
        queries (List[Dict[str, Any]]): MetricDataQueries entries for get_metric_data.
        start_time (datetime): Start of the window of interest.
        end_time (datetime): End of the window of interest.

    Returns:
        pd.DataFrame: Columns "value", "ts", "label"; plus "rel_ts" when non-empty.
    """
    response = cw.get_metric_data(
        MetricDataQueries=queries,
        StartTime=start_time - timedelta(hours=1),
        EndTime=end_time + timedelta(hours=1),
    )

    # NOTE(review): pagination (NextToken) is not followed here — presumably the
    # result set always fits one response for these jobs; confirm for long runs.
    frames = []
    for result in response.get("MetricDataResults", []):
        vals = result["Values"]
        frames.append(
            pd.DataFrame(
                {
                    "value": vals,
                    "ts": np.array(result["Timestamps"], dtype=np.datetime64),
                    "label": [result["Label"]] * len(vals),
                }
            )
        )
    df = pd.concat(frames) if frames else pd.DataFrame()

    # We now calculate the relative time based on the first actual observed
    # time stamps, not the potentially start time that we used to scope our CW
    # API call. The difference could be for example startup times or waiting
    # for Spot.
    if not df.empty:
        df["rel_ts"] = datetime.fromtimestamp(1) + (df["ts"] - df["ts"].min())  # pyright: ignore
    return df
129+
130+
131+
@disk_cache
def _collect_metrics(
    dimensions: List[Tuple[str, str]],
    start_time: datetime,
    end_time: Optional[datetime]
) -> pd.DataFrame:
    """Collect every CloudWatch metric published for the given dimensions.

    For each (name, value) dimension pair, all metrics listed under the
    SageMaker training-job namespace are fetched and concatenated. Results are
    cached on disk by the @disk_cache decorator.

    Args:
        dimensions (List[Tuple[str, str]]): (dimension name, dimension value) pairs.
        start_time (datetime): Start of the metrics window.
        end_time (Optional[datetime]): End of the metrics window.

    Returns:
        pd.DataFrame: Concatenated metric data; empty if nothing was found.
    """
    collected = pd.DataFrame()
    for dim_name, dim_value in dimensions:
        listing = cw.list_metrics(
            Namespace="/aws/sagemaker/TrainingJobs",
            Dimensions=[
                {"Name": dim_name, "Value": dim_value},
            ],
        )
        metric_names = [metric["MetricName"] for metric in listing["Metrics"]]
        if not metric_names:
            # No metric data yet, or not any longer, because the data were aged out
            continue
        queries = [
            _metric_data_query_tpl(name, dim_name, dim_value) for name in metric_names
        ]
        collected = pd.concat([collected, _get_metric_data(queries, start_time, end_time)])

    return collected
158+
159+
160+
def get_cw_job_metrics(
    job_name: str,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None
) -> pd.DataFrame:
    """Retrieves CloudWatch metrics for a SageMaker training job.

    Args:
        job_name (str): Name of the SageMaker training job.
        start_time (datetime, optional): Start time for metrics collection.
            Defaults to now - 4 hours.
        end_time (datetime, optional): End time for metrics collection.
            Defaults to start_time + 4 hours.

    Returns:
        pd.DataFrame: Metrics data with columns for value, timestamp, and metric name.
            Results are cached to disk for improved performance.
    """
    # Metrics are published per job and per host; presumably single-host jobs
    # only expose "algo-1" — confirm for multi-instance training.
    dimensions = [
        ("TrainingJobName", job_name),
        ("Host", f"{job_name}/algo-1"),
    ]
    # If not given, use reasonable defaults for start and end time
    if start_time is None:
        start_time = datetime.now() - timedelta(hours=4)
    if end_time is None:
        end_time = start_time + timedelta(hours=4)
    return _collect_metrics(dimensions, start_time, end_time)

0 commit comments

Comments
 (0)