Skip to content

Commit 6c3fb39

Browse files
committed
rm EvalModel
1 parent 3c87cc4 commit 6c3fb39

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

tests/accelerators/test_ddp_spawn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from pytorch_lightning.core import memory
2121
from pytorch_lightning.trainer import Trainer
2222
from pytorch_lightning.trainer.states import TrainerState
23-
from tests.base import EvalModelTemplate
2423
from tests.helpers import BoringModel
2524
from tests.helpers.datamodules import ClassifDataModule
2625
from tests.helpers.simple_models import ClassificationModel
@@ -72,7 +71,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
7271
"""Make sure DDP works with dataloaders passed to fit()"""
7372
tutils.set_random_master_port()
7473

75-
model = EvalModelTemplate()
74+
model = BoringModel()
7675
fit_options = dict(train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
7776

7877
trainer = Trainer(

tests/accelerators/test_dp.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import tests.helpers.utils as tutils
2020
from pytorch_lightning.callbacks import EarlyStopping
2121
from pytorch_lightning.core import memory
22-
from tests.base import EvalModelTemplate
2322
from tests.helpers import BoringModel
2423
from tests.helpers.datamodules import ClassifDataModule
2524
from tests.helpers.simple_models import ClassificationModel
@@ -76,7 +75,8 @@ def test_dp_test(tmpdir):
7675
import os
7776
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
7877

79-
model = EvalModelTemplate()
78+
dm = ClassifDataModule()
79+
model = ClassificationModel()
8080
trainer = pl.Trainer(
8181
default_root_dir=tmpdir,
8282
max_epochs=2,
@@ -85,14 +85,14 @@ def test_dp_test(tmpdir):
8585
gpus=[0, 1],
8686
accelerator='dp',
8787
)
88-
trainer.fit(model)
88+
trainer.fit(model, datamodule=dm)
8989
assert 'ckpt' in trainer.checkpoint_callback.best_model_path
90-
results = trainer.test()
90+
results = trainer.test(datamodule=dm)
9191
assert 'test_acc' in results[0]
9292

9393
old_weights = model.c_d1.weight.clone().detach().cpu()
9494

95-
results = trainer.test(model)
95+
results = trainer.test(model, datamodule=dm)
9696
assert 'test_acc' in results[0]
9797

9898
# make sure weights didn't change

0 commit comments

Comments
 (0)