29
29
# Any subsequent attempt to use pandas will raise the ImportError
30
30
pd = DeferredError (e )
31
31
32
+ METRICS_PERIOD_DEFAULT = 60 # seconds
33
+
32
34
33
35
class AnalyticsMetricsBase (with_metaclass (ABCMeta , object )):
34
36
"""Base class for tuning job or training job analytics classes.
@@ -200,7 +202,8 @@ class TrainingJobAnalytics(AnalyticsMetricsBase):
200
202
201
203
CLOUDWATCH_NAMESPACE = '/aws/sagemaker/TrainingJobs'
202
204
203
- def __init__ (self , training_job_name , metric_names = None , sagemaker_session = None ):
205
+ def __init__ (self , training_job_name , metric_names = None , sagemaker_session = None ,
206
+ start_time = None , end_time = None , period = None ):
204
207
"""Initialize a ``TrainingJobAnalytics`` instance.
205
208
206
209
Args:
@@ -215,6 +218,10 @@ def __init__(self, training_job_name, metric_names=None, sagemaker_session=None)
215
218
self ._sage_client = sagemaker_session .sagemaker_client
216
219
self ._cloudwatch = sagemaker_session .boto_session .client ('cloudwatch' )
217
220
self ._training_job_name = training_job_name
221
+ self ._start_time = start_time
222
+ self ._end_time = end_time
223
+ self ._period = period or METRICS_PERIOD_DEFAULT
224
+
218
225
if metric_names :
219
226
self ._metric_names = metric_names
220
227
else :
@@ -244,13 +251,15 @@ def _determine_timeinterval(self):
244
251
covering the interval of the training job
245
252
"""
246
253
description = self ._sage_client .describe_training_job (TrainingJobName = self .name )
247
- start_time = description [u'TrainingStartTime' ] # datetime object
254
+ start_time = self . _start_time or description [u'TrainingStartTime' ] # datetime object
248
255
# Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs.
249
256
# Without the extra minute, the searched time range would end before the final log line and miss it.
250
257
# Example - Log time - 2018-10-22 08:25:55
251
258
# Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min addition)
252
259
# CW will consider end time as 2018-10-22 08:25 and will not be able to search the correct log.
253
- end_time = description .get (u'TrainingEndTime' , datetime .datetime .utcnow ()) + datetime .timedelta (minutes = 1 )
260
+ end_time = self ._end_time or description .get (
261
+ u'TrainingEndTime' , datetime .datetime .utcnow ()) + datetime .timedelta (minutes = 1 )
262
+
254
263
return {
255
264
'start_time' : start_time ,
256
265
'end_time' : end_time ,
@@ -275,7 +284,7 @@ def _fetch_metric(self, metric_name):
275
284
],
276
285
'StartTime' : self ._time_interval ['start_time' ],
277
286
'EndTime' : self ._time_interval ['end_time' ],
278
- 'Period' : 60 ,
287
+ 'Period' : self . _period ,
279
288
'Statistics' : ['Average' ],
280
289
}
281
290
raw_cwm_data = self ._cloudwatch .get_metric_statistics (** request )['Datapoints' ]
0 commit comments