@@ -43,12 +43,49 @@ def sagemaker_session(describe_training_result=None, list_training_results=None,
43
43
cwm_mock = Mock (name = 'cloudwatch_client' )
44
44
boto_mock .client = Mock (return_value = cwm_mock )
45
45
cwm_mock .get_metric_statistics = Mock (
46
- name = 'get_metric_statistics' ,
47
- return_value = metric_stats_results ,
46
+ name = 'get_metric_statistics'
48
47
)
48
+ cwm_mock .get_metric_statistics .side_effect = cw_request_side_effect
49
49
return sms
50
50
51
51
52
+ def cw_request_side_effect (Namespace , MetricName , Dimensions , StartTime , EndTime , Period , Statistics ):
53
+ if _is_valid_request (Namespace , MetricName , Dimensions , StartTime , EndTime , Period , Statistics ):
54
+ return _metric_stats_results ()
55
+
56
+
57
+ def _is_valid_request (Namespace , MetricName , Dimensions , StartTime , EndTime , Period , Statistics ):
58
+ could_watch_request = {
59
+ 'Namespace' : Namespace ,
60
+ 'MetricName' : MetricName ,
61
+ 'Dimensions' : Dimensions ,
62
+ 'StartTime' : StartTime ,
63
+ 'EndTime' : EndTime ,
64
+ 'Period' : Period ,
65
+ 'Statistics' : Statistics ,
66
+ }
67
+ print (could_watch_request )
68
+ return could_watch_request == cw_request ()
69
+
70
+
71
+ def cw_request ():
72
+ describe_training_result = _describe_training_result ()
73
+ return {
74
+ 'Namespace' : '/aws/sagemaker/TrainingJobs' ,
75
+ 'MetricName' : 'train:acc' ,
76
+ 'Dimensions' : [
77
+ {
78
+ 'Name' : 'TrainingJobName' ,
79
+ 'Value' : 'my-training-job'
80
+ }
81
+ ],
82
+ 'StartTime' : describe_training_result ['TrainingStartTime' ],
83
+ 'EndTime' : describe_training_result ['TrainingEndTime' ] + datetime .timedelta (minutes = 1 ),
84
+ 'Period' : 60 ,
85
+ 'Statistics' : ['Average' ],
86
+ }
87
+
88
+
52
89
def test_abstract_base_class ():
53
90
# confirm that the abstract base class can't be instantiated directly
54
91
with pytest .raises (TypeError ) as _ : # noqa: F841
@@ -161,12 +198,15 @@ def test_trainer_name():
161
198
assert str (trainer ).find ("my-training-job" ) != - 1
162
199
163
200
164
- def test_trainer_dataframe ():
165
- describe_training_result = {
201
+ def _describe_training_result ():
202
+ return {
166
203
'TrainingStartTime' : datetime .datetime (2018 , 5 , 16 , 1 , 2 , 3 ),
167
204
'TrainingEndTime' : datetime .datetime (2018 , 5 , 16 , 5 , 6 , 7 ),
168
205
}
169
- metric_stats_results = {
206
+
207
+
208
+ def _metric_stats_results ():
209
+ return {
170
210
'Datapoints' : [
171
211
{
172
212
'Average' : 77.1 ,
@@ -182,8 +222,11 @@ def test_trainer_dataframe():
182
222
},
183
223
]
184
224
}
185
- session = sagemaker_session (describe_training_result = describe_training_result ,
186
- metric_stats_results = metric_stats_results )
225
+
226
+
227
+ def test_trainer_dataframe ():
228
+ session = sagemaker_session (describe_training_result = _describe_training_result (),
229
+ metric_stats_results = _metric_stats_results ())
187
230
trainer = TrainingJobAnalytics ("my-training-job" , ["train:acc" ], sagemaker_session = session )
188
231
189
232
df = trainer .dataframe ()
0 commit comments