
Commit f5f4f03

kaushikb11 authored and lexierule committed

Fix TPU tests for checkpoint

- Skip advanced profiler for torch > 1.8
- Skip pytorch profiler for torch > 1.8
- Fix save checkpoint logic for TPUs
1 parent 123e20d · commit f5f4f03

File tree: 4 files changed (+11, -5 lines)

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 4 additions & 2 deletions
@@ -203,12 +203,14 @@ def test_step(self, *args, **kwargs):
     def predict(self, *args, **kwargs):
         return self.lightning_module.predict(*args, **kwargs)

-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
         Args:
-            checkpoint: dict containing model and trainer state
             filepath: write-target file's path
+            weights_only: saving model weights only
         """
+        # dump states as a checkpoint dictionary object
+        checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
         if _OMEGACONF_AVAILABLE:
             checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
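The change inverts the old flow: instead of receiving a pre-built checkpoint dict, the plugin now asks the trainer's checkpoint connector to dump one, so callers pass only a filepath and a weights_only flag. A minimal sketch of that call pattern, using hypothetical stand-in classes (DummyCheckpointConnector and DummyPlugin are illustrative, not Lightning's real ones):

```python
from typing import Any, Dict


class DummyCheckpointConnector:
    """Illustrative stand-in for trainer.checkpoint_connector."""

    def dump_checkpoint(self, weights_only: bool = False) -> Dict[str, Any]:
        # Build the checkpoint dict on demand, as the plugin now does.
        state: Dict[str, Any] = {"state_dict": {"layer.weight": [0.0, 1.0]}}
        if not weights_only:
            state["optimizer_states"] = [{"step": 42}]
        return state


class DummyPlugin:
    """Illustrative stand-in for the TPU spawn training-type plugin."""

    def __init__(self, connector: DummyCheckpointConnector) -> None:
        self.connector = connector

    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
        # New signature: no pre-built checkpoint dict is passed in.
        checkpoint = self.connector.dump_checkpoint(weights_only)
        print(f"would write keys {sorted(checkpoint)} to {filepath}")


DummyPlugin(DummyCheckpointConnector()).save_checkpoint("model.ckpt", weights_only=True)
```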

tests/models/test_tpu.py

Lines changed: 3 additions & 3 deletions
@@ -122,7 +122,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=2,
         tpu_cores=1,
-        limit_train_batches=8,
+        limit_train_batches=0.7,
         limit_val_batches=2,
     )

@@ -210,8 +210,8 @@ def test_tpu_grad_norm(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=4,
         tpu_cores=1,
-        limit_train_batches=0.4,
-        limit_val_batches=0.4,
+        limit_train_batches=10,
+        limit_val_batches=10,
         gradient_clip_val=0.5,
     )
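For context, Lightning's limit_train_batches and limit_val_batches accept either a float (a fraction of the available batches) or an int (an absolute batch count), so the swaps above switch each test between the two forms. A small illustration, constructing Trainers without running them:

```python
from pytorch_lightning import Trainer

# A float limits training to a fraction of the available batches per epoch...
trainer_fractional = Trainer(limit_train_batches=0.7, limit_val_batches=2)
# ...while an int limits it to an absolute number of batches.
trainer_absolute = Trainer(limit_train_batches=10, limit_val_batches=10)
```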

tests/test_profiler.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
 import pytest

 from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
+from tests.helpers.runif import RunIf

 PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005


@@ -165,6 +166,7 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
     assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE


+@RunIf(max_torch="1.8.1")
 def test_advanced_profiler_describe(tmpdir, advanced_profiler):
     """
     ensure the profiler won't fail when reporting the summary
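RunIf(max_torch="1.8.1") skips the test when the installed torch is newer than 1.8.1. A hedged sketch of what such a version gate boils down to; the real helper lives in tests/helpers/runif.py and supports many more conditions, and run_if_max_torch below is a hypothetical simplification:

```python
import pytest
import torch
from packaging.version import Version


def run_if_max_torch(max_torch: str):
    """Skip the decorated test when the installed torch is newer than max_torch."""
    # Strip local version suffixes such as "+cu102" before comparing.
    too_new = Version(torch.__version__.split("+")[0]) > Version(max_torch)
    return pytest.mark.skipif(too_new, reason=f"requires torch <= {max_torch}")


@run_if_max_torch("1.8.1")
def test_only_on_old_torch():
    assert True
```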

tests/trainer/test_trainer.py

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf


 @pytest.fixture

@@ -1499,6 +1500,7 @@ def test_trainer_predict_ddp_cpu(tmpdir):
     predict(tmpdir, "ddp_cpu", 0, 2)


+@RunIf(max_torch="1.8.1")
 def test_pytorch_profiler_describe(pytorch_profiler):
     """Ensure the profiler won't fail when reporting the summary."""
     with pytorch_profiler.profile("test_step"):
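The gated test only enters the profiler's profile() context and then reports a summary. The same profile()/describe() contract can be exercised with SimpleProfiler; a minimal sketch, assuming the profiler API of this Lightning release:

```python
from pytorch_lightning.profiler import SimpleProfiler

profiler = SimpleProfiler()
with profiler.profile("test_step"):
    sum(i * i for i in range(1000))  # stand-in workload
profiler.describe()  # report the recorded durations
```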
