19
19
import tests .helpers .utils as tutils
20
20
from pytorch_lightning .callbacks import EarlyStopping
21
21
from pytorch_lightning .core import memory
22
- from tests .base import EvalModelTemplate
23
22
from tests .helpers import BoringModel
24
23
from tests .helpers .datamodules import ClassifDataModule
25
24
from tests .helpers .simple_models import ClassificationModel
@@ -76,7 +75,8 @@ def test_dp_test(tmpdir):
76
75
import os
77
76
os .environ ['CUDA_VISIBLE_DEVICES' ] = '0,1'
78
77
79
- model = EvalModelTemplate ()
78
+ dm = ClassifDataModule ()
79
+ model = ClassificationModel ()
80
80
trainer = pl .Trainer (
81
81
default_root_dir = tmpdir ,
82
82
max_epochs = 2 ,
@@ -85,14 +85,14 @@ def test_dp_test(tmpdir):
85
85
gpus = [0 , 1 ],
86
86
accelerator = 'dp' ,
87
87
)
88
- trainer .fit (model )
88
+ trainer .fit (model , datamodule = dm )
89
89
assert 'ckpt' in trainer .checkpoint_callback .best_model_path
90
- results = trainer .test ()
90
+ results = trainer .test (datamodule = dm )
91
91
assert 'test_acc' in results [0 ]
92
92
93
93
old_weights = model .c_d1 .weight .clone ().detach ().cpu ()
94
94
95
- results = trainer .test (model )
95
+ results = trainer .test (model , datamodule = dm )
96
96
assert 'test_acc' in results [0 ]
97
97
98
98
# make sure weights didn't change
0 commit comments