Skip to content

Commit f3a08dd

Browse files
authored
Update test_catalyst. (#1134)
1 parent b055d91 commit f3a08dd

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

tests/test_catalyst.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,18 @@ def test_mnist(self):
141141
logdir=logdir,
142142
num_epochs=num_epochs,
143143
verbose=False,
144-
callbacks=[CheckpointCallback(save_n_best=3, use_runner_logdir=True)]
144+
callbacks=[CheckpointCallback(
145+
logdir,
146+
topk=3,
147+
save_best=True,
148+
loader_key="valid",
149+
metric_key="loss",
150+
minimize=True)]
145151
)
146-
147-
with open('./logs/_metrics.json') as f:
152+
153+
with open('./logs/model.storage.json') as f:
148154
metrics = json.load(f)
149-
self.assertTrue(metrics['train.3']['valid']['loss'] < metrics['train.1']['valid']['loss'])
150-
self.assertTrue(metrics['best']['valid']['loss'] < 0.35)
155+
storage = metrics['storage']
156+
self.assertEqual(3, len(storage))
157+
self.assertTrue(storage[0]['metric'] < storage[2]['metric'])
158+
self.assertTrue(storage[0]['metric']< 0.35)

0 commit comments

Comments
 (0)