Skip to content

Commit 46b10fa

Browse files
jesterhazylaurenyu
authored andcommitted
fix test for pytest 4 (#482)
1 parent 10a8598 commit 46b10fa

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

tests/unit/test_analytics.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@
2626

2727

2828
@pytest.fixture()
29-
def sagemaker_session(describe_training_result=None, list_training_results=None, metric_stats_results=None,
30-
describe_tuning_result=None):
29+
def sagemaker_session():
30+
return create_sagemaker_session()
31+
32+
33+
def create_sagemaker_session(describe_training_result=None, list_training_results=None, metric_stats_results=None,
34+
describe_tuning_result=None):
3135
boto_mock = Mock(name='boto_session', region_name=REGION)
3236
sms = Mock(name='sagemaker_session', boto_session=boto_mock,
3337
boto_region_name=REGION, config=None, local_mode=False)
@@ -77,7 +81,7 @@ def mock_summary(name="job-name", value=0.9):
7781
"layers": 137,
7882
},
7983
}
80-
session = sagemaker_session(list_training_results={
84+
session = create_sagemaker_session(list_training_results={
8185
"TrainingJobSummaries": [
8286
mock_summary(),
8387
mock_summary(),
@@ -116,7 +120,7 @@ def mock_summary(name="job-name", value=0.9):
116120

117121

118122
def test_description():
119-
session = sagemaker_session(describe_tuning_result={
123+
session = create_sagemaker_session(describe_tuning_result={
120124
'HyperParameterTuningJobConfig': {
121125
'ParameterRanges': {
122126
'CategoricalParameterRanges': [],
@@ -155,7 +159,7 @@ def test_trainer_name():
155159
'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
156160
'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
157161
}
158-
session = sagemaker_session(describe_training_result)
162+
session = create_sagemaker_session(describe_training_result)
159163
trainer = TrainingJobAnalytics("my-training-job", ["metric"], sagemaker_session=session)
160164
assert trainer.name == "my-training-job"
161165
assert str(trainer).find("my-training-job") != -1
@@ -182,8 +186,8 @@ def test_trainer_dataframe():
182186
},
183187
]
184188
}
185-
session = sagemaker_session(describe_training_result=describe_training_result,
186-
metric_stats_results=metric_stats_results)
189+
session = create_sagemaker_session(describe_training_result=describe_training_result,
190+
metric_stats_results=metric_stats_results)
187191
trainer = TrainingJobAnalytics("my-training-job", ["train:acc"], sagemaker_session=session)
188192

189193
df = trainer.dataframe()

0 commit comments

Comments
 (0)