@@ -47,12 +47,49 @@ def create_sagemaker_session(describe_training_result=None, list_training_result
47
47
cwm_mock = Mock (name = 'cloudwatch_client' )
48
48
boto_mock .client = Mock (return_value = cwm_mock )
49
49
cwm_mock .get_metric_statistics = Mock (
50
- name = 'get_metric_statistics' ,
51
- return_value = metric_stats_results ,
50
+ name = 'get_metric_statistics'
52
51
)
52
+ cwm_mock .get_metric_statistics .side_effect = cw_request_side_effect
53
53
return sms
54
54
55
55
56
+ def cw_request_side_effect (Namespace , MetricName , Dimensions , StartTime , EndTime , Period , Statistics ):
57
+ if _is_valid_request (Namespace , MetricName , Dimensions , StartTime , EndTime , Period , Statistics ):
58
+ return _metric_stats_results ()
59
+
60
+
61
+ def _is_valid_request (Namespace , MetricName , Dimensions , StartTime , EndTime , Period , Statistics ):
62
+ could_watch_request = {
63
+ 'Namespace' : Namespace ,
64
+ 'MetricName' : MetricName ,
65
+ 'Dimensions' : Dimensions ,
66
+ 'StartTime' : StartTime ,
67
+ 'EndTime' : EndTime ,
68
+ 'Period' : Period ,
69
+ 'Statistics' : Statistics ,
70
+ }
71
+ print (could_watch_request )
72
+ return could_watch_request == cw_request ()
73
+
74
+
75
+ def cw_request ():
76
+ describe_training_result = _describe_training_result ()
77
+ return {
78
+ 'Namespace' : '/aws/sagemaker/TrainingJobs' ,
79
+ 'MetricName' : 'train:acc' ,
80
+ 'Dimensions' : [
81
+ {
82
+ 'Name' : 'TrainingJobName' ,
83
+ 'Value' : 'my-training-job'
84
+ }
85
+ ],
86
+ 'StartTime' : describe_training_result ['TrainingStartTime' ],
87
+ 'EndTime' : describe_training_result ['TrainingEndTime' ] + datetime .timedelta (minutes = 1 ),
88
+ 'Period' : 60 ,
89
+ 'Statistics' : ['Average' ],
90
+ }
91
+
92
+
56
93
def test_abstract_base_class ():
57
94
# confirm that the abstract base class can't be instantiated directly
58
95
with pytest .raises (TypeError ) as _ : # noqa: F841
@@ -165,12 +202,15 @@ def test_trainer_name():
165
202
assert str (trainer ).find ("my-training-job" ) != - 1
166
203
167
204
168
- def test_trainer_dataframe ():
169
- describe_training_result = {
205
+ def _describe_training_result ():
206
+ return {
170
207
'TrainingStartTime' : datetime .datetime (2018 , 5 , 16 , 1 , 2 , 3 ),
171
208
'TrainingEndTime' : datetime .datetime (2018 , 5 , 16 , 5 , 6 , 7 ),
172
209
}
173
- metric_stats_results = {
210
+
211
+
212
+ def _metric_stats_results ():
213
+ return {
174
214
'Datapoints' : [
175
215
{
176
216
'Average' : 77.1 ,
@@ -186,8 +226,11 @@ def test_trainer_dataframe():
186
226
},
187
227
]
188
228
}
189
- session = create_sagemaker_session (describe_training_result = describe_training_result ,
190
- metric_stats_results = metric_stats_results )
229
+
230
+
231
+ def test_trainer_dataframe ():
232
+ session = create_sagemaker_session (describe_training_result = _describe_training_result (),
233
+ metric_stats_results = _metric_stats_results ())
191
234
trainer = TrainingJobAnalytics ("my-training-job" , ["train:acc" ], sagemaker_session = session )
192
235
193
236
df = trainer .dataframe ()
0 commit comments