Skip to content

Commit 42d2741

Browse files
author
Piyush Adlakha
committed
Unit tests for Bug fix for getting dataframes in TrainingJobAnalytics.
1 parent 8401afa commit 42d2741

File tree

1 file changed

+50
-7
lines changed

1 file changed

+50
-7
lines changed

tests/unit/test_analytics.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,49 @@ def sagemaker_session(describe_training_result=None, list_training_results=None,
4343
cwm_mock = Mock(name='cloudwatch_client')
4444
boto_mock.client = Mock(return_value=cwm_mock)
4545
cwm_mock.get_metric_statistics = Mock(
46-
name='get_metric_statistics',
47-
return_value=metric_stats_results,
46+
name='get_metric_statistics'
4847
)
48+
cwm_mock.get_metric_statistics.side_effect = cw_request_side_effect
4949
return sms
5050

5151

52+
def cw_request_side_effect(Namespace, MetricName, Dimensions, StartTime, EndTime, Period, Statistics):
53+
if _is_valid_request(Namespace, MetricName, Dimensions, StartTime, EndTime, Period, Statistics):
54+
return _metric_stats_results()
55+
56+
57+
def _is_valid_request(Namespace, MetricName, Dimensions, StartTime, EndTime, Period, Statistics):
58+
could_watch_request = {
59+
'Namespace': Namespace,
60+
'MetricName': MetricName,
61+
'Dimensions': Dimensions,
62+
'StartTime': StartTime,
63+
'EndTime': EndTime,
64+
'Period': Period,
65+
'Statistics': Statistics,
66+
}
67+
print(could_watch_request)
68+
return could_watch_request == cw_request()
69+
70+
71+
def cw_request():
72+
describe_training_result = _describe_training_result()
73+
return {
74+
'Namespace': '/aws/sagemaker/TrainingJobs',
75+
'MetricName': 'train:acc',
76+
'Dimensions': [
77+
{
78+
'Name': 'TrainingJobName',
79+
'Value': 'my-training-job'
80+
}
81+
],
82+
'StartTime': describe_training_result['TrainingStartTime'],
83+
'EndTime': describe_training_result['TrainingEndTime'] + datetime.timedelta(minutes=1),
84+
'Period': 60,
85+
'Statistics': ['Average'],
86+
}
87+
88+
5289
def test_abstract_base_class():
5390
# confirm that the abstract base class can't be instantiated directly
5491
with pytest.raises(TypeError) as _: # noqa: F841
@@ -161,12 +198,15 @@ def test_trainer_name():
161198
assert str(trainer).find("my-training-job") != -1
162199

163200

164-
def test_trainer_dataframe():
165-
describe_training_result = {
201+
def _describe_training_result():
202+
return {
166203
'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
167204
'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
168205
}
169-
metric_stats_results = {
206+
207+
208+
def _metric_stats_results():
209+
return {
170210
'Datapoints': [
171211
{
172212
'Average': 77.1,
@@ -182,8 +222,11 @@ def test_trainer_dataframe():
182222
},
183223
]
184224
}
185-
session = sagemaker_session(describe_training_result=describe_training_result,
186-
metric_stats_results=metric_stats_results)
225+
226+
227+
def test_trainer_dataframe():
228+
session = sagemaker_session(describe_training_result=_describe_training_result(),
229+
metric_stats_results=_metric_stats_results())
187230
trainer = TrainingJobAnalytics("my-training-job", ["train:acc"], sagemaker_session=session)
188231

189232
df = trainer.dataframe()

0 commit comments

Comments
 (0)