Skip to content

Commit 466f37e

Browse files
authored
Merge branch 'master' into fix-tags
2 parents a5e1f53 + b15c05a commit 466f37e

File tree

3 files changed

+34
-4
lines changed

3 files changed

+34
-4
lines changed

src/sagemaker/analytics.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
# Any subsequent attempt to use pandas will raise the ImportError
3030
pd = DeferredError(e)
3131

32+
METRICS_PERIOD_DEFAULT = 60 # seconds
33+
3234

3335
class AnalyticsMetricsBase(with_metaclass(ABCMeta, object)):
3436
"""Base class for tuning job or training job analytics classes.
@@ -200,7 +202,8 @@ class TrainingJobAnalytics(AnalyticsMetricsBase):
200202

201203
CLOUDWATCH_NAMESPACE = '/aws/sagemaker/TrainingJobs'
202204

203-
def __init__(self, training_job_name, metric_names=None, sagemaker_session=None):
205+
def __init__(self, training_job_name, metric_names=None, sagemaker_session=None,
206+
start_time=None, end_time=None, period=None):
204207
"""Initialize a ``TrainingJobAnalytics`` instance.
205208
206209
Args:
@@ -215,6 +218,10 @@ def __init__(self, training_job_name, metric_names=None, sagemaker_session=None)
215218
self._sage_client = sagemaker_session.sagemaker_client
216219
self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch')
217220
self._training_job_name = training_job_name
221+
self._start_time = start_time
222+
self._end_time = end_time
223+
self._period = period or METRICS_PERIOD_DEFAULT
224+
218225
if metric_names:
219226
self._metric_names = metric_names
220227
else:
@@ -244,13 +251,15 @@ def _determine_timeinterval(self):
244251
covering the interval of the training job
245252
"""
246253
description = self._sage_client.describe_training_job(TrainingJobName=self.name)
247-
start_time = description[u'TrainingStartTime'] # datetime object
254+
start_time = self._start_time or description[u'TrainingStartTime'] # datetime object
248255
# Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs.
249256
# Without the extra minute, logs would be searched in a time range that does not contain the final log line.
250257
# Example - Log time - 2018-10-22 08:25:55
251258
# Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min addition)
252259
# CloudWatch (CW) will treat the end time as 2018-10-22 08:25 and so will not find the log line written at 08:25:55.
253-
end_time = description.get(u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)
260+
end_time = self._end_time or description.get(
261+
u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)
262+
254263
return {
255264
'start_time': start_time,
256265
'end_time': end_time,
@@ -275,7 +284,7 @@ def _fetch_metric(self, metric_name):
275284
],
276285
'StartTime': self._time_interval['start_time'],
277286
'EndTime': self._time_interval['end_time'],
278-
'Period': 60,
287+
'Period': self._period,
279288
'Statistics': ['Average'],
280289
}
281290
raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)['Datapoints']

tests/integ/test_tf_script_mode.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def test_mnist(sagemaker_session, instance_type):
4747
sagemaker_session=sagemaker_session,
4848
py_version='py3',
4949
framework_version=TensorFlow.LATEST_VERSION,
50+
metric_definitions=[{'Name': 'train:global_steps', 'Regex': r'global_step\/sec:\s(.*)'}],
5051
base_job_name='test-tf-sm-mnist')
5152
inputs = estimator.sagemaker_session.upload_data(
5253
path=os.path.join(RESOURCE_PATH, 'data'),
@@ -56,6 +57,9 @@ def test_mnist(sagemaker_session, instance_type):
5657
estimator.fit(inputs)
5758
_assert_s3_files_exist(estimator.model_dir,
5859
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta'])
60+
df = estimator.training_job_analytics.dataframe()
61+
print(df)
62+
assert df.size > 0
5963

6064

6165
def test_server_side_encryption(sagemaker_session):

tests/unit/test_analytics.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,20 @@ def test_trainer_dataframe():
245245
trainer.export_csv(tmp_name)
246246
assert os.path.isfile(tmp_name)
247247
os.unlink(tmp_name)
248+
249+
250+
def test_start_time_end_time_and_period_specified():
    """Check that explicit start_time, end_time and period are used as given."""
    # The requested window deliberately differs from the job's own timestamps
    # so the assertions can tell caller-supplied values from derived ones.
    requested_start = datetime.datetime(2018, 5, 16, 1, 3, 4)
    requested_end = datetime.datetime(2018, 5, 16, 5, 1, 1)
    requested_period = 300

    job_description = {
        'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
        'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
    }
    analytics = TrainingJobAnalytics(
        'my-training-job',
        ['metric'],
        sagemaker_session=create_sagemaker_session(job_description),
        start_time=requested_start,
        end_time=requested_end,
        period=requested_period
    )

    interval = analytics._time_interval
    assert interval['start_time'] == requested_start
    assert interval['end_time'] == requested_end
    assert analytics._period == requested_period

0 commit comments

Comments
 (0)