Commit 0647340

Authored by justusschock, Borda, and akihironitta
Add mypy typing to precision plugins. (#6149)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
1 parent e7298b5 commit 0647340

File tree: 11 files changed (+134 / -90 lines)

pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion
@@ -366,7 +366,7 @@ def amp_backend(self) -> Optional[LightningEnum]:
         return None
 
     @property
-    def precision(self) -> int:
+    def precision(self) -> Union[str, int]:
         return self.precision_plugin.precision
 
     @property
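For context, a minimal sketch (not part of this commit) of why the annotation widens to Union[str, int]: mixed-precision plugins report the string "mixed" while other plugins report an integer bit width, so callers must accept both. The describe_precision helper below is hypothetical.

from typing import Union

def describe_precision(precision: Union[str, int]) -> str:
    # Precision may arrive as an int (e.g. 16, 32, 64) or as a string
    # such as "mixed"; the widened annotation forces callers to handle both.
    if isinstance(precision, str):
        return f"precision mode: {precision}"
    return f"{precision}-bit precision"

# e.g. describe_precision("mixed") -> "precision mode: mixed"
#      describe_precision(32)      -> "32-bit precision"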

pytorch_lightning/plugins/base_plugin.py

Lines changed: 2 additions & 15 deletions
@@ -12,26 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
-from abc import ABC, abstractmethod
-from typing import Generator, Optional, Sequence, Tuple
-
-from torch.nn import Module
+from abc import ABC
+from typing import Generator
 
 
 class Plugin(ABC):
     """Basic Plugin class to derive precision and training type plugins from."""
 
-    @abstractmethod
-    def connect(
-        self,
-        model: Module,
-        *args: Sequence,
-        **kwargs: Sequence,
-    ) -> Optional[Tuple[Module, Sequence, Sequence]]:
-        """Connects the plugin with the accelerator (and thereby with trainer and model).
-        Will be called by the accelerator.
-        """
-
     def pre_dispatch(self) -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
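The abstract connect is dropped because precision and training-type subclasses need different signatures, which mypy would otherwise flag as incompatible overrides. A rough sketch of the resulting shape; the subclass name and its arguments are illustrative, not taken from the codebase.

from abc import ABC
from typing import Any, Sequence, Tuple

class Plugin(ABC):
    """Slimmed-down base class: no abstract `connect`, so subclasses
    declare whatever signature they actually need."""

    def pre_dispatch(self) -> None:
        """Hook to do something before training/evaluation/prediction starts."""


class ExamplePrecisionPlugin(Plugin):
    # Hypothetical subclass: its `connect` takes optimizers and schedulers,
    # a signature that could not satisfy one shared abstract definition.
    def connect(
        self, model: Any, optimizers: Sequence[Any], lr_schedulers: Sequence[Any]
    ) -> Tuple[Any, Sequence[Any], Sequence[Any]]:
        return model, optimizers, lr_schedulers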

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 24 additions & 16 deletions
@@ -11,10 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, List, Tuple
+from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING
 
 import torch
-from torch.optim import Optimizer
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
@@ -23,37 +22,41 @@
 if _APEX_AVAILABLE:
     from apex import amp
 
+if TYPE_CHECKING:
+    from torch.optim import Optimizer
+
 
 class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""
 
-    def __init__(self, amp_level: str):
+    def __init__(self, amp_level: str) -> None:
        self.backend = AMPType.APEX
        self.amp_level = amp_level
 
-    def master_params(self, optimizer: torch.optim.Optimizer):
+    def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
         return amp.master_params(optimizer)
 
-    def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
+    def connect(self, model: torch.nn.Module, optimizers: Sequence['Optimizer'],
+                lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence['Optimizer'], Sequence[Any]]:
         """Connects the precision plugin to the training process,
         configures apex and reinits the schedulers
         """
         if model.device.type != "cuda":
             return model, optimizers, lr_schedulers
-        model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
+        model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level)
         self.reinit_scheduler_properties(optimizers, lr_schedulers)
         return model, optimizers, lr_schedulers
 
     def backward(
         self,
         model: LightningModule,
         closure_loss: torch.Tensor,
-        optimizer: torch.optim.Optimizer,
+        optimizer: 'Optimizer',
         opt_idx: int,
         should_accumulate: bool,
-        *args,
-        **kwargs,
-    ):
+        *args: Any,
+        **kwargs: Any,
+    ) -> torch.Tensor:
         """performs the actual backpropagation
 
         Args:
@@ -94,11 +97,11 @@ def backward(
 
     def configure_apex(
         self,
-        amp: object,
+        amp: Type,
         model: LightningModule,
-        optimizers: List[Optimizer],
+        optimizers: List['Optimizer'],
         amp_level: str,
-    ) -> Tuple[LightningModule, List[Optimizer]]:
+    ) -> Tuple[LightningModule, List['Optimizer']]:
         r"""
         Override to init AMP your own way.
         Must return a model and list of optimizers.
@@ -127,7 +130,7 @@ def configure_apex(self, amp, model, optimizers, amp_level):
         return model, optimizers
 
     @staticmethod
-    def reinit_scheduler_properties(optimizers: list, schedulers: list):
+    def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None:
         """Reinitializes schedulers with correct properties"""
         # Reinitialize optimizer.step properties added by schedulers
         for scheduler in schedulers:
@@ -149,7 +152,12 @@ def reinit_scheduler_properties(optimizers: list, schedulers: list):
                     break
 
     def pre_optimizer_step(
-        self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
+        self,
+        pl_module: LightningModule,
+        optimizer: 'Optimizer',
+        optimizer_idx: int,
+        lambda_closure: Callable,
+        **kwargs: Any,
     ) -> bool:
         """
         always called before the optimizer step.
@@ -160,6 +168,6 @@ def pre_optimizer_step(
         if not pl_module.automatic_optimization:
             pl_module.trainer.call_hook("on_after_backward")
 
-        optimizer.step()
+        optimizer.step(**kwargs)
 
         return False
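The recurring pattern in this file is an `if TYPE_CHECKING:` import combined with quoted forward references such as 'Optimizer': the names exist only for mypy and are never imported at runtime. A self-contained sketch of the pattern; the first_params function is illustrative and not part of the plugin.

from typing import TYPE_CHECKING, Generator

import torch

if TYPE_CHECKING:
    # Evaluated by mypy only; at runtime this import never executes,
    # so there is no import cost and no circular-import risk.
    from torch.optim import Optimizer

def first_params(optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
    # The quoted 'Optimizer' is a forward reference resolved through the
    # TYPE_CHECKING block above, mirroring the annotations added in this commit.
    for group in optimizer.param_groups:
        for param in group['params']:
            yield param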

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 24 additions & 14 deletions
@@ -11,27 +11,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Union
+from typing import Any, Callable, TYPE_CHECKING, Union
 
 import torch
-from torch.optim import Optimizer
 
-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache
 
+if TYPE_CHECKING:
+    from torch.optim import Optimizer
+
+    from pytorch_lightning.core.lightning import LightningModule
+
 warning_cache = WarningCache()
 
 
 class DeepSpeedPrecisionPlugin(PrecisionPlugin):
 
-    def __init__(self, precision):
+    def __init__(self, precision: int) -> None:
         super().__init__()
         self.precision = precision
 
     def pre_optimizer_step(
-        self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
+        self,
+        pl_module: 'LightningModule',
+        optimizer: 'Optimizer',
+        optimizer_idx: int,
+        lambda_closure: Callable,
+        **kwargs: Any,
     ) -> bool:
         deepspeed_engine = pl_module.trainer.model
         # DeepSpeed not support closures.
@@ -46,28 +54,30 @@ def pre_optimizer_step(
 
     def backward(
         self,
-        lightning_module: LightningModule,
+        model: 'LightningModule',
         closure_loss: torch.Tensor,
-        optimizer: torch.optim.Optimizer,
+        optimizer: 'Optimizer',
         opt_idx: int,
         should_accumulate: bool,
-        *args,
-        **kwargs,
-    ):
-        if is_overridden('backward', lightning_module):
+        *args: Any,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        if is_overridden('backward', model):
             warning_cache.warn(
                 "Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
                 "backward logic outside of the LightningModule"
             )
         # todo: hack around for deepspeed engine to call backward
-        deepspeed_engine = lightning_module.trainer.model
-        deepspeed_engine.backward(closure_loss, **kwargs)
+        deepspeed_engine = model.trainer.model
+        deepspeed_engine.backward(closure_loss, *args, **kwargs)
         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()
 
         return closure_loss
 
-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
+    def clip_gradients(
+        self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
+    ) -> None:
         """
         DeepSpeed handles clipping gradients via the training type plugin.
         """

pytorch_lightning/plugins/precision/mixed.py

Lines changed: 8 additions & 4 deletions
@@ -11,13 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import TYPE_CHECKING, Union
+
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
-from pytorch_lightning.utilities import AMPType
+
+if TYPE_CHECKING:
+    from pytorch_lightning.utilities import AMPType
 
 
 class MixedPrecisionPlugin(PrecisionPlugin):
     """Base Class for mixed precision"""
 
-    EPSILON = 1e-5
-    backend: AMPType
-    precision = "mixed"
+    EPSILON: float = 1e-5
+    backend: 'AMPType'
+    precision: Union[str, int] = "mixed"
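The class attributes now carry explicit annotations: EPSILON is a float, backend is declared but left unassigned for subclasses to set, and precision is widened to Union[str, int] so the string value "mixed" type-checks. A standalone sketch of the same idea, using a stand-in enum rather than pytorch_lightning.utilities.AMPType.

from enum import Enum
from typing import Union

class AMPBackend(Enum):
    # hypothetical stand-in for pytorch_lightning.utilities.AMPType
    NATIVE = "native"
    APEX = "apex"

class MixedPrecisionBase:
    EPSILON: float = 1e-5                  # annotated constant
    backend: AMPBackend                    # declared only; subclasses assign it
    precision: Union[str, int] = "mixed"   # str here, int in other plugins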

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 20 additions & 16 deletions
@@ -12,37 +12,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Callable, Generator
+from typing import Any, Callable, Generator, TYPE_CHECKING
 
 import torch
-from torch.optim import LBFGS, Optimizer
+from torch.optim import LBFGS
 
-from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if _NATIVE_AMP_AVAILABLE:
-    from torch.cuda.amp import autocast
-else:
-    autocast = None
+if TYPE_CHECKING:
+    from torch.optim import Optimizer
+
+    from pytorch_lightning.core import LightningModule
 
 
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.backend = AMPType.NATIVE
         self.scaler = torch.cuda.amp.GradScaler()
 
     def backward(
         self,
-        model: LightningModule,
+        model: 'LightningModule',
         closure_loss: torch.Tensor,
-        optimizer: Optimizer,
+        optimizer: 'Optimizer',
         opt_idx: int,
         should_accumulate: bool,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ) -> torch.Tensor:
         """performs the actual backpropagation
 
@@ -65,7 +64,12 @@ def backward(
         return closure_loss
 
     def pre_optimizer_step(
-        self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
+        self,
+        pl_module: 'LightningModule',
+        optimizer: 'Optimizer',
+        optimizer_idx: int,
+        lambda_closure: Callable,
+        **kwargs: Any,
     ) -> bool:
         """always called before the optimizer step.
         Checks that the optimizer is not LBFGS, as this one is not supported by native amp
@@ -83,13 +87,13 @@ def pre_optimizer_step(
 
         return False
 
-    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
+    def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> None:
         """Updates the GradScaler"""
         self.scaler.step(optimizer)
         self.scaler.update()
 
     @contextmanager
-    def train_step_context(self) -> Generator[autocast, None, None]:
+    def train_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
         with torch.cuda.amp.autocast():
             yield
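train_step_context is now annotated as Generator[None, None, None]: a bare yield inside the autocast block produces None, so there is no need to import autocast at module level just to name the yield type. Restated outside the class as a minimal sketch; the autocast_context name and the commented usage are hypothetical.

from contextlib import contextmanager
from typing import Generator

import torch

@contextmanager
def autocast_context() -> Generator[None, None, None]:
    # The generator yields None (bare `yield`), accepts nothing sent in, and
    # returns nothing, hence Generator[None, None, None]; torch.cuda.amp.autocast
    # is entered purely for its side effect of enabling mixed-precision casting.
    with torch.cuda.amp.autocast():
        yield

# usage (hypothetical):
# with autocast_context():
#     loss = model(batch)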
