# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List, Optional, Tuple
import hashlib
import logging
import os
from pathlib import Path

import boto3
import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)

cw = boto3.client("cloudwatch")
sm = boto3.client("sagemaker")


def disk_cache(outer: Callable) -> Callable:
    """A decorator that implements disk-based caching for CloudWatch metrics data.

    This decorator caches the output of the wrapped function to disk in JSON Lines
    format. It creates a cache key from an MD5 hash of the function arguments and
    stores the data in the user's home directory under .amtviz/cw_metrics_cache/.

    Args:
        outer (Callable): The function to be wrapped. Must return a pandas DataFrame
            containing CloudWatch metrics data.

    Returns:
        Callable: A wrapper function that implements the caching logic.
    """

    def inner(*args: Any, **kwargs: Any) -> pd.DataFrame:
        key_input = str(args) + str(kwargs)
        # nosec b303 - MD5 is not used for cryptography here, only to build a lookup key
        key = hashlib.md5(key_input.encode("utf-8")).hexdigest()
        cache_dir = Path.home().joinpath(".amtviz/cw_metrics_cache")
        fn = f"{cache_dir}/req_{key}.jsonl.gz"
        if Path(fn).exists():
            try:
                df = pd.read_json(fn, lines=True)
                logger.debug("H")  # cache hit
                df["ts"] = pd.to_datetime(df["ts"])
                df["ts"] = df["ts"].dt.tz_localize(None)
                df["rel_ts"] = pd.to_datetime(df["rel_ts"])  # pyright: ignore [reportIndexIssue, reportOptionalSubscript]
                df["rel_ts"] = df["rel_ts"].dt.tz_localize(None)
                return df
            except KeyError:
                # An empty file loads as an empty DataFrame, so df["ts"] does not exist
                pass
            # nosec b110 - it does not matter why the cache entry could not be loaded;
            # fall through and fetch the data again
            except BaseException as e:
                logger.error("\nException: %s - %s", type(e), e)

        logger.debug("M")  # cache miss
        df = outer(*args, **kwargs)
        assert isinstance(df, pd.DataFrame), "Only caching Pandas DataFrames."

        os.makedirs(cache_dir, exist_ok=True)
        df.to_json(fn, orient="records", date_format="iso", lines=True)

        return df

    return inner
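
# A minimal usage sketch for the decorator (the function and job name below are
# hypothetical, not part of this module): any function whose arguments have a
# stable string representation and that returns a DataFrame can be cached this way.
#
#     @disk_cache
#     def expensive_lookup(job_name: str) -> pd.DataFrame:
#         ...  # e.g. a slow CloudWatch or SageMaker API call
#
#     df = expensive_lookup("my-training-job")  # first call hits the API
#     df = expensive_lookup("my-training-job")  # second call is served from disk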


def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> Dict[str, Any]:
    return {
        "Id": metric_name.lower().replace(":", "_").replace("-", "_"),
        "MetricStat": {
            "Stat": "Average",
            "Metric": {
                "Namespace": "/aws/sagemaker/TrainingJobs",
                "MetricName": metric_name,
                "Dimensions": [
                    {"Name": dim_name, "Value": dim_value},
                ],
            },
            "Period": 60,
        },
        "ReturnData": True,
    }
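
# For illustration (hypothetical arguments): the template above turns
# _metric_data_query_tpl("CPUUtilization", "Host", "my-job/algo-1") into a
# GetMetricData query for the 1-minute average of CPUUtilization on that host,
# with the metric name sanitized into the query id "cpuutilization".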


def _get_metric_data(
    queries: List[Dict[str, Any]],
    start_time: datetime,
    end_time: datetime,
) -> pd.DataFrame:
    # Widen the query window by an hour on each side, as metrics can be
    # reported slightly outside the job's nominal start and end times.
    start_time = start_time - timedelta(hours=1)
    end_time = end_time + timedelta(hours=1)
    response = cw.get_metric_data(MetricDataQueries=queries, StartTime=start_time, EndTime=end_time)

    df = pd.DataFrame()
    if "MetricDataResults" not in response:
        return df

    for metric_data in response["MetricDataResults"]:
        values = metric_data["Values"]
        ts = np.array(metric_data["Timestamps"], dtype=np.datetime64)
        labels = [metric_data["Label"]] * len(values)

        df = pd.concat([df, pd.DataFrame({"value": values, "ts": ts, "label": labels})])

    # We calculate the relative time based on the first actually observed
    # timestamp, not on the potentially earlier start time used to scope the
    # CloudWatch API call. The difference can stem, for example, from startup
    # times or from waiting for Spot capacity.
    if not df.empty:
        df["rel_ts"] = datetime.fromtimestamp(1) + (df["ts"] - df["ts"].min())  # pyright: ignore
    return df
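
# A worked example of the rel_ts anchoring (hypothetical timestamps, assuming a
# UTC machine): if the first observed sample is at 12:07:30 and the next one at
# 12:08:30, their rel_ts values become 1970-01-01 00:00:01 and
# 1970-01-01 00:01:01, i.e. a timeline that starts just after the epoch
# regardless of when the job actually ran.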


@disk_cache
def _collect_metrics(
    dimensions: List[Tuple[str, str]],
    start_time: datetime,
    end_time: datetime,
) -> pd.DataFrame:
    df = pd.DataFrame()
    for dim_name, dim_value in dimensions:
        response = cw.list_metrics(
            Namespace="/aws/sagemaker/TrainingJobs",
            Dimensions=[
                {"Name": dim_name, "Value": dim_value},
            ],
        )
        metric_names = [metric["MetricName"] for metric in response["Metrics"]]
        if not metric_names:
            # No metric data yet, or no longer, because the data has aged out
            continue
        metric_data_queries = [
            _metric_data_query_tpl(metric_name, dim_name, dim_value) for metric_name in metric_names
        ]
        df = pd.concat([df, _get_metric_data(metric_data_queries, start_time, end_time)])

    return df


def get_cw_job_metrics(
    job_name: str,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
) -> pd.DataFrame:
    """Retrieves CloudWatch metrics for a SageMaker training job.

    Args:
        job_name (str): Name of the SageMaker training job.
        start_time (datetime, optional): Start time for metrics collection.
            Defaults to now - 4 hours.
        end_time (datetime, optional): End time for metrics collection.
            Defaults to start_time + 4 hours.

    Returns:
        pd.DataFrame: Metrics data with columns for value, timestamp, and metric name.
            Results are cached to disk for improved performance.
    """
    dimensions = [
        ("TrainingJobName", job_name),
        ("Host", job_name + "/algo-1"),
    ]
    # If not given, use reasonable defaults for start and end time
    start_time = start_time or datetime.now() - timedelta(hours=4)
    end_time = end_time or start_time + timedelta(hours=4)
    return _collect_metrics(dimensions, start_time, end_time)
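
# A minimal usage sketch (the job name is hypothetical; it must refer to a
# training job visible to the configured AWS credentials and region):
#
#     df = get_cw_job_metrics("my-training-job")
#     cpu = df[df["label"] == "CPUUtilization"]
#     print(cpu[["rel_ts", "value"]].head())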