Skip to content

Commit 5db2603

Browse files
piyushadlakha and laurenyu
authored and committed
Bug fix for getting dataframes in TrainingJobAnalytics. (#441)
1 parent 46b10fa commit 5db2603

File tree

3 files changed

+57
-8
lines changed

3 files changed

+57
-8
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ CHANGELOG
1515
* enhancement: Frameworks: update warning for not setting framework_version as we aren't planning a breaking change anymore
1616
* enhancement: Session: remove hardcoded 'training' from job status error message
1717
* bug-fix: Updated Cloudwatch namespace for metrics in TrainingJobsAnalytics
18+
* bug-fix: Changes to use correct s3 bucket and time range for dataframes in TrainingJobAnalytics.
1819

1920

2021
1.14.1

src/sagemaker/analytics.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,12 @@ def _determine_timeinterval(self):
246246
"""
247247
description = self._sage_client.describe_training_job(TrainingJobName=self.name)
248248
start_time = description[u'TrainingStartTime'] # datetime object
249-
end_time = description.get(u'TrainingEndTime', datetime.datetime.utcnow())
249+
# Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs.
250+
# This results in logs being searched in the time range in which the correct log line was not present.
251+
# Example - Log time - 2018-10-22 08:25:55
252+
# Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min addition)
253+
# CW will consider end time as 2018-10-22 08:25 and will not be able to search the correct log.
254+
end_time = description.get(u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)
250255
return {
251256
'start_time': start_time,
252257
'end_time': end_time,

tests/unit/test_analytics.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,49 @@ def create_sagemaker_session(describe_training_result=None, list_training_result
4747
cwm_mock = Mock(name='cloudwatch_client')
4848
boto_mock.client = Mock(return_value=cwm_mock)
4949
cwm_mock.get_metric_statistics = Mock(
50-
name='get_metric_statistics',
51-
return_value=metric_stats_results,
50+
name='get_metric_statistics'
5251
)
52+
cwm_mock.get_metric_statistics.side_effect = cw_request_side_effect
5353
return sms
5454

5555

56+
def cw_request_side_effect(Namespace, MetricName, Dimensions, StartTime, EndTime, Period, Statistics):
    """Side effect for the mocked CloudWatch ``get_metric_statistics`` call.

    Returns the canned metric statistics when the supplied keyword arguments
    match the expected request (as judged by ``_is_valid_request``); for any
    other request the function returns ``None``.
    """
    request_matches = _is_valid_request(
        Namespace, MetricName, Dimensions, StartTime, EndTime, Period, Statistics
    )
    if not request_matches:
        # Same outcome as falling off the end of the function.
        return None
    return _metric_stats_results()
59+
60+
61+
def _is_valid_request(Namespace, MetricName, Dimensions, StartTime, EndTime, Period, Statistics):
    """Tell whether the given arguments equal the expected CloudWatch request.

    The arguments are gathered into a dict keyed by parameter name, the dict
    is printed, and it is then compared for equality against ``cw_request()``.
    """
    received_request = {
        'Namespace': Namespace,
        'MetricName': MetricName,
        'Dimensions': Dimensions,
        'StartTime': StartTime,
        'EndTime': EndTime,
        'Period': Period,
        'Statistics': Statistics,
    }
    # Emit the received request before comparing it.
    print(received_request)
    expected_request = cw_request()
    return received_request == expected_request
73+
74+
75+
def cw_request():
    """Build the ``get_metric_statistics`` keyword arguments the test expects.

    Start and end times come from ``_describe_training_result()``; the end
    time is the training end time plus one minute.
    """
    job_description = _describe_training_result()
    window_start = job_description['TrainingStartTime']
    window_end = job_description['TrainingEndTime'] + datetime.timedelta(minutes=1)
    job_dimension = {
        'Name': 'TrainingJobName',
        'Value': 'my-training-job'
    }
    return {
        'Namespace': '/aws/sagemaker/TrainingJobs',
        'MetricName': 'train:acc',
        'Dimensions': [job_dimension],
        'StartTime': window_start,
        'EndTime': window_end,
        'Period': 60,
        'Statistics': ['Average'],
    }
91+
92+
5693
def test_abstract_base_class():
5794
# confirm that the abstract base class can't be instantiated directly
5895
with pytest.raises(TypeError) as _: # noqa: F841
@@ -165,12 +202,15 @@ def test_trainer_name():
165202
assert str(trainer).find("my-training-job") != -1
166203

167204

168-
def test_trainer_dataframe():
169-
describe_training_result = {
205+
def _describe_training_result():
206+
return {
170207
'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
171208
'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
172209
}
173-
metric_stats_results = {
210+
211+
212+
def _metric_stats_results():
213+
return {
174214
'Datapoints': [
175215
{
176216
'Average': 77.1,
@@ -186,8 +226,11 @@ def test_trainer_dataframe():
186226
},
187227
]
188228
}
189-
session = create_sagemaker_session(describe_training_result=describe_training_result,
190-
metric_stats_results=metric_stats_results)
229+
230+
231+
def test_trainer_dataframe():
232+
session = create_sagemaker_session(describe_training_result=_describe_training_result(),
233+
metric_stats_results=_metric_stats_results())
191234
trainer = TrainingJobAnalytics("my-training-job", ["train:acc"], sagemaker_session=session)
192235

193236
df = trainer.dataframe()

0 commit comments

Comments
 (0)