
Commit 6166f46

drop unused variable in API (Lightning-AI#6308)
* drop unused pl model in ckpt
* irrelevant
* on_evaluation_batch_start
* evaluation_epoch_end
* attach_datamodule
1 parent 484dce1 commit 6166f46

File tree: 9 files changed (+22, -33 lines)


pytorch_lightning/callbacks/model_checkpoint.py
Lines changed: 6 additions & 7 deletions

@@ -239,7 +239,7 @@ def save_checkpoint(self, trainer, pl_module):
         self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)

         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, monitor_candidates)
+        self._save_last_checkpoint(trainer, monitor_candidates)

     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:

@@ -291,8 +291,7 @@ def _del_model(self, filepath: str):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")

-    def _save_model(self, filepath: str, trainer, pl_module):
-        # Todo: required argument `pl_module` is not used
+    def _save_model(self, filepath: str, trainer):
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)

@@ -481,7 +480,7 @@ def _monitor_candidates(self, trainer):
         monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
         return monitor_candidates

-    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
+    def _save_last_checkpoint(self, trainer, ckpt_name_metrics):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return

@@ -505,9 +504,9 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):

         if trainer.training_type_plugin.rpc_enabled:
             # RPCPlugin manages saving all model states
-            trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
+            trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer)
         else:
-            self._save_model(last_filepath, trainer, pl_module)
+            self._save_model(last_filepath, trainer)
         if (
             self.last_model_path and self.last_model_path != last_filepath
             and (self.save_top_k != -1 or self.save_last) and trainer.is_global_zero

@@ -574,7 +573,7 @@ def _update_best_and_save(
                 f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
                 f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
             )
-        self._save_model(filepath, trainer, pl_module)
+        self._save_model(filepath, trainer)

         if del_filepath is not None and filepath != del_filepath:
             self._del_model(del_filepath)
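Because `pl_module` is gone from the private save path, any code that overrides `ModelCheckpoint._save_model` must match the new two-argument form. A minimal sketch, assuming a hypothetical subclass (`VerboseCheckpoint` and its print statement are not part of Lightning):

from pytorch_lightning.callbacks import ModelCheckpoint

class VerboseCheckpoint(ModelCheckpoint):
    # before this commit the override also had to accept `pl_module`:
    #     def _save_model(self, filepath, trainer, pl_module): ...
    def _save_model(self, filepath: str, trainer):
        # purely illustrative: log the path, then defer to the stock behaviour
        print(f"writing checkpoint to {filepath}")
        super()._save_model(filepath, trainer)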

pytorch_lightning/plugins/training_type/rpc.py
Lines changed: 1 addition & 2 deletions

@@ -63,7 +63,7 @@ def init_rpc_connection(self, global_rank: int, world_size: int) -> None:
         rpc._set_rpc_timeout(self.rpc_timeout_sec)
         self._is_rpc_initialized = True

-    def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
+    def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
         """
         Override to save model to disk.
         This is required as the main process will be required to handle aggregating model states from RPC processes.

@@ -72,7 +72,6 @@ def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
             save_model_fn: The saving function to save final model.
             last_filepath: The filepath to save the model to.
             trainer: The trainer object.
-            pl_module: The LightningModule.
         """
         raise NotImplementedError
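Custom training-type plugins that implement `rpc_save_model` need the same trim. A rough sketch, assuming `RPCPlugin` is importable from the module above; the subclass and its body are invented for illustration:

from pytorch_lightning.plugins.training_type.rpc import RPCPlugin

class MyRPCPlugin(RPCPlugin):
    def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
        # the checkpoint callback now hands over only the saving function, the path and the trainer
        save_model_fn(last_filepath, trainer)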

pytorch_lightning/plugins/training_type/rpc_sequential.py
Lines changed: 5 additions & 5 deletions

@@ -266,17 +266,17 @@ def configure_ddp(self):
         self._model.require_backward_grad_sync = False

     @rank_zero_only
-    def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
+    def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
         model = self.lightning_module
         if not hasattr(model.sequential_module, "foreach_worker"):
             return
-        current_layers = pl_module.sequential_module
+        current_layers = model.sequential_module
         model.sequential_module.foreach_worker(
             save_layers_on_all_rank_zero_workers, {"gpus_per_model": self.gpus_per_model}, include_self=True
         )
-        pl_module.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
-        save_model_fn(last_filepath, trainer, pl_module)
-        pl_module.sequential_module = current_layers
+        model.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
+        save_model_fn(last_filepath, trainer)
+        model.sequential_module = current_layers

     def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None:
         model.sequential_module.foreach_worker(

pytorch_lightning/trainer/connectors/data_connector.py
Lines changed: 2 additions & 3 deletions

@@ -81,7 +81,7 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):

         # set up the passed in dataloaders (if needed)
         self.attach_dataloaders(model, train_dataloader, val_dataloaders)
-        self.attach_datamodule(model, datamodule, 'fit')
+        self.attach_datamodule(model, datamodule)

     def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders

@@ -112,8 +112,7 @@ def attach_dataloaders(
         if predict_dataloaders is not None:
             model.predict_dataloader = _PatchDataLoader(predict_dataloaders)

-    def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], stage: str) -> None:
-        # Todo: required argument `stage` is not used
+    def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None:

         # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
         datamodule = datamodule or getattr(model, 'datamodule', None)
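With the unused `stage` string removed, `attach_datamodule` resolves the datamodule the same way for fit, test and predict: an explicitly passed datamodule wins, otherwise the model's own `datamodule` attribute is used (the `getattr` fallback in the context line above). A stand-alone sketch of that rule; the class and helper names here are illustrative only:

class DummyModel:
    datamodule = "dm_from_model"

def resolve_datamodule(model, datamodule=None):
    # prefer the datamodule passed to .fit/.test/.predict, else fall back to the model's attribute
    return datamodule or getattr(model, "datamodule", None)

assert resolve_datamodule(DummyModel(), "dm_passed") == "dm_passed"
assert resolve_datamodule(DummyModel()) == "dm_from_model"
assert resolve_datamodule(object()) is None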

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
Lines changed: 2 additions & 4 deletions

@@ -101,8 +101,7 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch
             current_hook_fx_name=hook_fx_name, on_step=on_step, on_epoch=on_epoch
         )

-    def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders):
-        # Todo: required argument `testing` is not used
+    def on_evaluation_batch_start(self, batch, dataloader_idx, num_dataloaders):
         model = self.trainer.lightning_module
         # set dataloader_idx only if multiple ones
         model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

@@ -260,8 +259,7 @@ def track_metrics_deprecated(self, deprecated_eval_results):
         self._track_callback_metrics(deprecated_eval_results)
         self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results)

-    def evaluation_epoch_end(self, testing):
-        # Todo: required argument `testing` is not used
+    def evaluation_epoch_end(self):
         # reset dataloader idx
         model_ref = self.trainer.lightning_module
         model_ref._current_dataloader_idx = None
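Dropping the `testing` flag does not change the dataloader-index bookkeeping: the index is attached to the module only when several evaluation dataloaders exist, and cleared again at epoch end. A tiny stand-alone sketch of that rule (the function name is illustrative only):

def current_dataloader_idx(dataloader_idx, num_dataloaders):
    # mirrors the connector: expose the index only when more than one eval dataloader is present
    return dataloader_idx if num_dataloaders > 1 else None

assert current_dataloader_idx(0, 1) is None   # a single loader needs no index
assert current_dataloader_idx(1, 3) == 1      # multiple loaders keep theirs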

pytorch_lightning/trainer/connectors/slurm_connector.py
Lines changed: 1 addition & 5 deletions

@@ -28,8 +28,6 @@ def register_slurm_signal_handlers(self):
         signal.signal(signal.SIGTERM, self.term_handler)

     def sig_handler(self, signum, frame):  # pragma: no-cover
-        # Todo: required argument `signum` is not used
-        # Todo: required argument `frame` is not used
         if self.trainer.is_global_zero:
             # save weights
             log.info('handling SIGUSR1')

@@ -59,7 +57,5 @@ def sig_handler(self, signum, frame):  # pragma: no-cover
         # close experiment to avoid issues
         self.trainer.logger.close()

-    def term_handler(self, signum, frame):
-        # Todo: required argument `signum` is not used
-        # Todo: required argument `frame` is not used
+    def term_handler(self, signum, frame):  # pragma: no-cover
         log.info("bypassing sigterm")

pytorch_lightning/trainer/evaluation_loop.py
Lines changed: 2 additions & 4 deletions

@@ -181,7 +181,7 @@ def evaluation_step_end(self, *args, **kwargs):

     def evaluation_epoch_end(self):
         # unset dataloder_idx in model
-        self.trainer.logger_connector.evaluation_epoch_end(self.trainer.testing)
+        self.trainer.logger_connector.evaluation_epoch_end()

         # call the model epoch end
         deprecated_results = self.__run_eval_epoch_end(self.num_dataloaders)

@@ -283,9 +283,7 @@ def _convert_to_numpy(v):

     def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx):
         # set dataloader_idx to model and track batch_size
-        self.trainer.logger_connector.on_evaluation_batch_start(
-            self.trainer.testing, batch, dataloader_idx, self.num_dataloaders
-        )
+        self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders)

         if self.trainer.testing:
             self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)

pytorch_lightning/trainer/trainer.py
Lines changed: 2 additions & 2 deletions

@@ -880,7 +880,7 @@ def test(
         )

         # Attach datamodule to get setup/prepare_data added to model before the call to it below
-        self.data_connector.attach_datamodule(model or self.lightning_module, datamodule, 'test')
+        self.data_connector.attach_datamodule(model or self.lightning_module, datamodule)

         if model is not None:
             results = self.__test_given_model(model, test_dataloaders)

@@ -989,7 +989,7 @@ def predict(

         if datamodule is not None:
             # Attach datamodule to get setup/prepare_data added to model before the call to it below
-            self.data_connector.attach_datamodule(model, datamodule, 'predict')
+            self.data_connector.attach_datamodule(model, datamodule)

         # attach data
         if dataloaders is not None:

tests/plugins/test_rpc_plugin.py
Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ def __init__(self, **kwargs):
         self.rpc_save_model_count = 0
         self.worker_optimizer_step_count = 0

-    def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
+    def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
         self.rpc_save_model_count += 1

     def barrier(self, name: Optional[str] = None) -> None:
