26
26
27
27
28
28
@pytest .fixture ()
29
- def sagemaker_session (describe_training_result = None , list_training_results = None , metric_stats_results = None ,
30
- describe_tuning_result = None ):
29
+ def sagemaker_session ():
30
+ return create_sagemaker_session ()
31
+
32
+
33
+ def create_sagemaker_session (describe_training_result = None , list_training_results = None , metric_stats_results = None ,
34
+ describe_tuning_result = None ):
31
35
boto_mock = Mock (name = 'boto_session' , region_name = REGION )
32
36
sms = Mock (name = 'sagemaker_session' , boto_session = boto_mock ,
33
37
boto_region_name = REGION , config = None , local_mode = False )
@@ -77,7 +81,7 @@ def mock_summary(name="job-name", value=0.9):
77
81
"layers" : 137 ,
78
82
},
79
83
}
80
- session = sagemaker_session (list_training_results = {
84
+ session = create_sagemaker_session (list_training_results = {
81
85
"TrainingJobSummaries" : [
82
86
mock_summary (),
83
87
mock_summary (),
@@ -116,7 +120,7 @@ def mock_summary(name="job-name", value=0.9):
116
120
117
121
118
122
def test_description ():
119
- session = sagemaker_session (describe_tuning_result = {
123
+ session = create_sagemaker_session (describe_tuning_result = {
120
124
'HyperParameterTuningJobConfig' : {
121
125
'ParameterRanges' : {
122
126
'CategoricalParameterRanges' : [],
@@ -155,7 +159,7 @@ def test_trainer_name():
155
159
'TrainingStartTime' : datetime .datetime (2018 , 5 , 16 , 1 , 2 , 3 ),
156
160
'TrainingEndTime' : datetime .datetime (2018 , 5 , 16 , 5 , 6 , 7 ),
157
161
}
158
- session = sagemaker_session (describe_training_result )
162
+ session = create_sagemaker_session (describe_training_result )
159
163
trainer = TrainingJobAnalytics ("my-training-job" , ["metric" ], sagemaker_session = session )
160
164
assert trainer .name == "my-training-job"
161
165
assert str (trainer ).find ("my-training-job" ) != - 1
@@ -182,8 +186,8 @@ def test_trainer_dataframe():
182
186
},
183
187
]
184
188
}
185
- session = sagemaker_session (describe_training_result = describe_training_result ,
186
- metric_stats_results = metric_stats_results )
189
+ session = create_sagemaker_session (describe_training_result = describe_training_result ,
190
+ metric_stats_results = metric_stats_results )
187
191
trainer = TrainingJobAnalytics ("my-training-job" , ["train:acc" ], sagemaker_session = session )
188
192
189
193
df = trainer .dataframe ()
0 commit comments