
Commit 3ed8ef8: type accelerators (#6148)
1 parent b0d1996

6 files changed: +81 / -49 lines

pytorch_lightning/accelerators/accelerator.py

Lines changed: 52 additions & 35 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Iterable, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
 
 import torch
 from torch.optim import Optimizer
@@ -24,6 +24,14 @@
 from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from pytorch_lightning.utilities.enums import AMPType, LightningEnum
 
+if TYPE_CHECKING:
+    from torch.cuda.amp import GradScaler
+
+    from pytorch_lightning.trainer.trainer import Trainer
+
+
+_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
+
 
 class Accelerator(object):
     """
@@ -54,11 +62,11 @@ def __init__(
         self.precision_plugin = precision_plugin
         self.training_type_plugin = training_type_plugin
 
-        self.optimizers = None
-        self.lr_schedulers = None
-        self.optimizer_frequencies = None
+        self.optimizers: Sequence = []
+        self.lr_schedulers: Sequence = []
+        self.optimizer_frequencies: Sequence = []
 
-    def setup(self, trainer, model: LightningModule) -> None:
+    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
         """
         Connects the plugins to the training process, creates optimizers
 
@@ -70,13 +78,13 @@ def setup(self, trainer, model: LightningModule) -> None:
         self.setup_optimizers(trainer)
         self.connect_precision_plugin(self.precision_plugin)
 
-    def start_training(self, trainer):
+    def start_training(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_training(trainer)
 
-    def start_testing(self, trainer):
+    def start_testing(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_testing(trainer)
 
-    def start_predicting(self, trainer):
+    def start_predicting(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_predicting(trainer)
 
     def pre_dispatch(self) -> None:
@@ -113,7 +121,7 @@ def lightning_module(self) -> LightningModule:
     def root_device(self) -> torch.device:
         return self.training_type_plugin.root_device
 
-    def teardown(self):
+    def teardown(self) -> None:
         """This method is called to teardown the training process.
         It is the right place to release memory and free other ressources.
         """
@@ -134,11 +142,14 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) ->
 
         return move_data_to_device(batch, device)
 
-    def on_train_start(self):
+    def on_train_start(self) -> None:
         """Hook to do something upon the training start"""
         pass
 
-    def training_step(self, args):
+    def training_step(
+        self,
+        args: List[Union[Any, int]],
+    ) -> _STEP_OUTPUT_TYPE:
         """The actual training step.
 
         Args:
@@ -156,10 +167,10 @@ def training_step(self, args):
         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
             return self.training_type_plugin.training_step(*args)
 
-    def post_training_step(self):
+    def post_training_step(self) -> None:
         self.training_type_plugin.post_training_step()
 
-    def validation_step(self, args):
+    def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
         """The actual validation step.
 
         Args:
@@ -177,7 +188,7 @@ def validation_step(self, args):
         with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
             return self.training_type_plugin.validation_step(*args)
 
-    def test_step(self, args):
+    def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
         """The actual test step.
 
         Args:
@@ -195,7 +206,7 @@ def test_step(self, args):
         with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
             return self.training_type_plugin.test_step(*args)
 
-    def predict(self, args):
+    def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
         """The actual predict step.
 
         Args:
@@ -213,23 +224,29 @@ def predict(self, args):
         with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
             return self.training_type_plugin.predict(*args)
 
-    def training_step_end(self, output):
+    def training_step_end(
+        self, output: _STEP_OUTPUT_TYPE
+    ) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the training step
 
         Args:
             output: the output of the training step
         """
         return self.training_type_plugin.training_step_end(output)
 
-    def test_step_end(self, output):
+    def test_step_end(
+        self, output: _STEP_OUTPUT_TYPE
+    ) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the test step
 
         Args:
             output: the output of the test step
         """
         return self.training_type_plugin.test_step_end(output)
 
-    def validation_step_end(self, output):
+    def validation_step_end(
+        self, output: _STEP_OUTPUT_TYPE
+    ) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the validation step
 
         Args:
@@ -243,8 +260,8 @@ def backward(
         optimizer: Optimizer,
         optimizer_idx: int,
         should_accumulate: bool,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ) -> torch.Tensor:
         """Forwards backward-calls to the precision plugin.
 
@@ -262,7 +279,7 @@ def backward(
 
         return output
 
-    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
+    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
         """performs the actual optimizer step.
 
         Args:
@@ -279,7 +296,9 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal
         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)
 
-    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+    def run_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
+    ) -> None:
         self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
 
     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
@@ -292,7 +311,7 @@ def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> N
 
         self.precision_plugin.clip_gradients(optimizer, clip_val)
 
-    def on_train_epoch_end(self, outputs) -> None:
+    def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
         """Hook to do something on the end of an training epoch
 
         Args:
@@ -304,7 +323,7 @@ def on_train_end(self) -> None:
         """Hook to do something at the end of the training"""
         pass
 
-    def setup_optimizers(self, trainer):
+    def setup_optimizers(self, trainer: 'Trainer') -> None:
         """creates optimizers and schedulers
 
         Args:
@@ -327,7 +346,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: Lightn
         """
         plugin.connect(model)
 
-    def connect_precision_plugin(self, plugin: PrecisionPlugin):
+    def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
         """Attaches the precision plugin to the accelerator"""
         model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
         self.model = model
@@ -351,26 +370,22 @@ def precision(self) -> int:
         return self.precision_plugin.precision
 
     @property
-    def scaler(self):
-        if hasattr(self.precision_plugin, "scaler"):
-            return self.precision_plugin.scaler
+    def scaler(self) -> Optional['GradScaler']:
 
-        return None
+        return getattr(self.precision_plugin, 'scaler', None)
 
     @property
     def rpc_enabled(self) -> bool:
         return self.training_type_plugin.rpc_enabled
 
-    def optimizer_state(self, optimizer: Optimizer) -> dict:
+    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]:
         """
         Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
         plugins.
         """
-        if self.training_type_plugin and hasattr(self.training_type_plugin, "optimizer_state"):
-            return self.training_type_plugin.optimizer_state(optimizer)
-        return optimizer.state_dict()
+        return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)
 
-    def on_save(self, checkpoint):
+    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
         return checkpoint
 
     def barrier(self, name: Optional[str] = None) -> None:
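
The two getattr calls in this hunk fold the previous hasattr branches into a single expression: getattr returns the plugin attribute when it exists and the supplied default otherwise. For optimizer_state the default is itself a callable, so the result can be invoked either way. A standalone sketch of the idiom, with plugin classes that are made up for illustration:

import torch
from torch.optim import SGD


class PluginWithHook:

    def optimizer_state(self, optimizer: torch.optim.Optimizer) -> dict:
        # A plugin that customizes how optimizer state is collected,
        # e.g. collating it across processes before returning it.
        return optimizer.state_dict()


class PlainPlugin:
    # No optimizer_state hook; callers should fall back to state_dict().
    pass


def collect_state(plugin: object, optimizer: torch.optim.Optimizer) -> dict:
    # Same shape as the getattr(..., lambda x: x.state_dict())(optimizer) call above:
    # the lookup yields either the bound hook or the fallback lambda, then we call it.
    return getattr(plugin, "optimizer_state", lambda opt: opt.state_dict())(optimizer)


optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
assert collect_state(PluginWithHook(), optimizer) == collect_state(PlainPlugin(), optimizer)
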
@@ -385,7 +400,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         """
         return self.training_type_plugin.broadcast(obj, src)
 
-    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+    def all_gather(
+        self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
+    ) -> torch.Tensor:
         """
         Function to gather a tensor from several distributed processes.
 
pytorch_lightning/accelerators/cpu.py

Lines changed: 7 additions & 1 deletion

@@ -1,11 +1,17 @@
+from typing import TYPE_CHECKING
+
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+if TYPE_CHECKING:
+    from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.trainer.trainer import Trainer
+
 
 class CPUAccelerator(Accelerator):
 
-    def setup(self, trainer, model):
+    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
 
pytorch_lightning/accelerators/gpu.py

Lines changed: 9 additions & 4 deletions

@@ -1,38 +1,43 @@
 import logging
 import os
+from typing import TYPE_CHECKING
 
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+if TYPE_CHECKING:
+    from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.trainer.trainer import Trainer
+
 _log = logging.getLogger(__name__)
 
 
 class GPUAccelerator(Accelerator):
 
-    def setup(self, trainer, model):
+    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
         if "cuda" not in str(self.root_device):
             raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         self.set_nvidia_flags()
         torch.cuda.set_device(self.root_device)
         return super().setup(trainer, model)
 
-    def on_train_start(self):
+    def on_train_start(self) -> None:
         # clear cache before training
         # use context because of:
         # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
         with torch.cuda.device(self.root_device):
             torch.cuda.empty_cache()
 
-    def on_train_end(self):
+    def on_train_end(self) -> None:
         # clean up memory
         self.model.cpu()
         with torch.cuda.device(self.root_device):
             torch.cuda.empty_cache()
 
     @staticmethod
-    def set_nvidia_flags():
+    def set_nvidia_flags() -> None:
         # set the correct cuda visible devices (using pci order)
         os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
         all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
 
pytorch_lightning/accelerators/tpu.py

Lines changed: 12 additions & 4 deletions

@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING
 
 import torch
 from torch.optim import Optimizer
@@ -13,10 +13,14 @@
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
 
+if TYPE_CHECKING:
+    from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.trainer.trainer import Trainer
+
 
 class TPUAccelerator(Accelerator):
 
-    def setup(self, trainer, model):
+    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException(
                 "amp + tpu is not supported. "
@@ -27,10 +31,14 @@ def setup(self, trainer, model):
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)
 
-    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+    def run_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
+    ) -> None:
         xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
 
-    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+    def all_gather(
+        self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
+    ) -> torch.Tensor:
         """
         Function to gather a tensor from several distributed processes
         Args:
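
The run_optimizer_step override above hands the training closure to xm.optimizer_step, the torch_xla helper that wraps the optimizer step for TPU execution; the closure mechanism itself is plain PyTorch. A minimal CPU-only sketch of that closure pattern, independent of XLA:

import torch
from torch.optim import SGD

param = torch.nn.Parameter(torch.tensor([1.0]))
optimizer = SGD([param], lr=0.1)


def closure() -> torch.Tensor:
    # Recomputes the loss and its gradients; the optimizer invokes this
    # inside step(), just as the TPU path passes it via optimizer_args.
    optimizer.zero_grad()
    loss = (param ** 2).sum()
    loss.backward()
    return loss


loss = optimizer.step(closure)
print(loss.item(), param.item())  # 1.0 and 0.8 after one step with lr=0.1
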

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 1 deletion

@@ -96,7 +96,7 @@ def model(self, new_model: Module) -> None:
         self._model = new_model
 
     @property
-    def lightning_module(self) -> Optional[LightningModule]:
+    def lightning_module(self) -> LightningModule:
         """Returns the pure LightningModule without potential wrappers"""
         return unwrap_lightning_module(self._model)

setup.cfg

Lines changed: 0 additions & 4 deletions

@@ -128,10 +128,6 @@ warn_redundant_casts = True
 warn_unused_configs = True
 warn_unused_ignores = True
 
-# todo: add proper typing to this module...
-[mypy-pytorch_lightning.accelerators.*]
-ignore_errors = True
-
 # todo: add proper typing to this module...
 [mypy-pytorch_lightning.callbacks.*]
 ignore_errors = True
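
With the ignore_errors override for pytorch_lightning.accelerators.* removed, mypy now applies the settings at the top of setup.cfg to that package. The sketch below is one way to reproduce the check locally; it assumes mypy is installed and the script is run from the repository root, and it uses mypy's documented api.run entry point rather than whatever exact invocation CI uses, which this commit does not show.

# Run the same kind of check locally via mypy's programmatic API.
from mypy import api

stdout, stderr, exit_code = api.run(
    ["--config-file", "setup.cfg", "pytorch_lightning/accelerators"]
)
print(stdout or stderr)
print(f"mypy exit code: {exit_code}")
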
