
Add MyPy typing to accelerators #6148

Merged: 5 commits, Feb 25, 2021

87 changes: 52 additions & 35 deletions pytorch_lightning/accelerators/accelerator.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
@@ -24,6 +24,14 @@
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.enums import AMPType, LightningEnum

if TYPE_CHECKING:
from torch.cuda.amp import GradScaler

from pytorch_lightning.trainer.trainer import Trainer


_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]


class Accelerator(object):
"""
@@ -54,11 +62,11 @@ def __init__(
self.precision_plugin = precision_plugin
self.training_type_plugin = training_type_plugin

self.optimizers = None
self.lr_schedulers = None
self.optimizer_frequencies = None
self.optimizers: Sequence = []
self.lr_schedulers: Sequence = []
self.optimizer_frequencies: Sequence = []

def setup(self, trainer, model: LightningModule) -> None:
def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
"""
Connects the plugins to the training process, creates optimizers

@@ -70,13 +78,13 @@ def setup(self, trainer, model: LightningModule) -> None:
self.setup_optimizers(trainer)
self.connect_precision_plugin(self.precision_plugin)

def start_training(self, trainer):
def start_training(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_training(trainer)

def start_testing(self, trainer):
def start_testing(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_testing(trainer)

def start_predicting(self, trainer):
def start_predicting(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_predicting(trainer)

def pre_dispatch(self) -> None:
@@ -113,7 +121,7 @@ def lightning_module(self) -> LightningModule:
def root_device(self) -> torch.device:
return self.training_type_plugin.root_device

def teardown(self):
def teardown(self) -> None:
"""This method is called to teardown the training process.
It is the right place to release memory and free other ressources.
"""
@@ -134,11 +142,14 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) ->

return move_data_to_device(batch, device)

def on_train_start(self):
def on_train_start(self) -> None:
"""Hook to do something upon the training start"""
pass

def training_step(self, args):
def training_step(
self,
args: List[Union[Any, int]],
) -> _STEP_OUTPUT_TYPE:
"""The actual training step.

Args:
@@ -156,10 +167,10 @@ def training_step(self, args):
with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
return self.training_type_plugin.training_step(*args)

def post_training_step(self):
def post_training_step(self) -> None:
self.training_type_plugin.post_training_step()

def validation_step(self, args):
def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
"""The actual validation step.

Args:
@@ -177,7 +188,7 @@ def validation_step(self, args):
with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
return self.training_type_plugin.validation_step(*args)

def test_step(self, args):
def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
"""The actual test step.

Args:
@@ -195,7 +206,7 @@ def test_step(self, args):
with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
return self.training_type_plugin.test_step(*args)

def predict(self, args):
def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
"""The actual predict step.

Args:
@@ -213,23 +224,29 @@ def predict(self, args):
with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
return self.training_type_plugin.predict(*args)

def training_step_end(self, output):
def training_step_end(
self, output: _STEP_OUTPUT_TYPE
) -> _STEP_OUTPUT_TYPE:
"""A hook to do something at the end of the training step

Args:
output: the output of the training step
"""
return self.training_type_plugin.training_step_end(output)

def test_step_end(self, output):
def test_step_end(
self, output: _STEP_OUTPUT_TYPE
) -> _STEP_OUTPUT_TYPE:
"""A hook to do something at the end of the test step

Args:
output: the output of the test step
"""
return self.training_type_plugin.test_step_end(output)

def validation_step_end(self, output):
def validation_step_end(
self, output: _STEP_OUTPUT_TYPE
) -> _STEP_OUTPUT_TYPE:
"""A hook to do something at the end of the validation step

Args:
@@ -243,8 +260,8 @@ def backward(
optimizer: Optimizer,
optimizer_idx: int,
should_accumulate: bool,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
"""Forwards backward-calls to the precision plugin.

@@ -262,7 +279,7 @@

return output

def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
"""performs the actual optimizer step.

Args:
@@ -279,7 +296,9 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal
self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
) -> None:
self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)

def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
@@ -292,7 +311,7 @@ def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> N

self.precision_plugin.clip_gradients(optimizer, clip_val)

def on_train_epoch_end(self, outputs) -> None:
def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
"""Hook to do something on the end of an training epoch

Args:
@@ -304,7 +323,7 @@ def on_train_end(self) -> None:
"""Hook to do something at the end of the training"""
pass

def setup_optimizers(self, trainer):
def setup_optimizers(self, trainer: 'Trainer') -> None:
"""creates optimizers and schedulers

Args:
@@ -327,7 +346,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: Lightn
"""
plugin.connect(model)

def connect_precision_plugin(self, plugin: PrecisionPlugin):
def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
"""Attaches the precision plugin to the accelerator"""
model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
self.model = model
@@ -351,26 +370,22 @@ def precision(self) -> int:
return self.precision_plugin.precision

@property
def scaler(self):
if hasattr(self.precision_plugin, "scaler"):
return self.precision_plugin.scaler
def scaler(self) -> Optional['GradScaler']:

return None
return getattr(self.precision_plugin, 'scaler', None)

@property
def rpc_enabled(self) -> bool:
return self.training_type_plugin.rpc_enabled

def optimizer_state(self, optimizer: Optimizer) -> dict:
def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]:
"""
Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
plugins.
"""
if self.training_type_plugin and hasattr(self.training_type_plugin, "optimizer_state"):
return self.training_type_plugin.optimizer_state(optimizer)
return optimizer.state_dict()
return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)

def on_save(self, checkpoint):
def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
return checkpoint

def barrier(self, name: Optional[str] = None) -> None:
@@ -385,7 +400,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
"""
return self.training_type_plugin.broadcast(obj, src)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
def all_gather(
self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
) -> torch.Tensor:
"""
Function to gather a tensor from several distributed processes.

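For reference, a minimal sketch of the pattern this file adopts, not the PR's code itself: imports needed only for annotations go under `if TYPE_CHECKING:`, signatures refer to them as string forward references, mutable state defaults to empty sequences instead of None, and `_STEP_OUTPUT_TYPE` names the step-output union. The `ExampleAccelerator` class below is a simplified stand-in, not the real `Accelerator`.

from typing import Dict, Optional, Sequence, TYPE_CHECKING, Union

import torch

if TYPE_CHECKING:
    # These imports run only under the type checker, so they add no runtime
    # cost and cannot create circular imports.
    from torch.cuda.amp import GradScaler

    from pytorch_lightning.trainer.trainer import Trainer

# Shared alias for step outputs: a tensor, a dict of tensors, or None.
_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]


class ExampleAccelerator:

    def __init__(self) -> None:
        # Empty sequences instead of None keep downstream code free of
        # Optional[...] checks.
        self.optimizers: Sequence = []
        self.lr_schedulers: Sequence = []

    def setup(self, trainer: 'Trainer', model: torch.nn.Module) -> None:
        # 'Trainer' is a string forward reference; only the type checker
        # resolves it, so the guarded import above is sufficient.
        ...

    @property
    def scaler(self) -> Optional['GradScaler']:
        # getattr with a default collapses the hasattr/early-return pattern
        # used before this PR (the real code looks on the precision plugin).
        return getattr(self, '_scaler', None)
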
8 changes: 7 additions & 1 deletion pytorch_lightning/accelerators/cpu.py
@@ -1,11 +1,17 @@
from typing import TYPE_CHECKING

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TYPE_CHECKING:
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer


class CPUAccelerator(Accelerator):

def setup(self, trainer, model):
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
if isinstance(self.precision_plugin, MixedPrecisionPlugin):
raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")

13 changes: 9 additions & 4 deletions pytorch_lightning/accelerators/gpu.py
@@ -1,38 +1,43 @@
import logging
import os
from typing import TYPE_CHECKING

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TYPE_CHECKING:
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer

_log = logging.getLogger(__name__)


class GPUAccelerator(Accelerator):

def setup(self, trainer, model):
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
if "cuda" not in str(self.root_device):
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
self.set_nvidia_flags()
torch.cuda.set_device(self.root_device)
return super().setup(trainer, model)

def on_train_start(self):
def on_train_start(self) -> None:
# clear cache before training
# use context because of:
# https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()

def on_train_end(self):
def on_train_end(self) -> None:
# clean up memory
self.model.cpu()
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()

@staticmethod
def set_nvidia_flags():
def set_nvidia_flags() -> None:
# set the correct cuda visible devices (using pci order)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
16 changes: 12 additions & 4 deletions pytorch_lightning/accelerators/tpu.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TYPE_CHECKING

import torch
from torch.optim import Optimizer
@@ -13,10 +13,14 @@
if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm

if TYPE_CHECKING:
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer


class TPUAccelerator(Accelerator):

def setup(self, trainer, model):
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
if isinstance(self.precision_plugin, MixedPrecisionPlugin):
raise MisconfigurationException(
"amp + tpu is not supported. "
@@ -27,10 +31,14 @@ def setup(self, trainer, model):
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
return super().setup(trainer, model)

def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
) -> None:
xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
def all_gather(
self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
) -> torch.Tensor:
"""
Function to gather a tensor from several distributed processes
Args:
@@ -96,7 +96,7 @@ def model(self, new_model: Module) -> None:
self._model = new_model

@property
def lightning_module(self) -> Optional[LightningModule]:
def lightning_module(self) -> LightningModule:
"""Returns the pure LightningModule without potential wrappers"""
return unwrap_lightning_module(self._model)

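A small, hypothetical illustration of why the `lightning_module` return type is narrowed from `Optional[LightningModule]` to `LightningModule` above; the names below (`Module`, `OldPlugin`, `NewPlugin`, `run_eval`) are placeholders for the sketch, not Lightning APIs:

from typing import Optional


class Module:
    def eval(self) -> None:
        ...


class OldPlugin:
    @property
    def lightning_module(self) -> Optional[Module]:
        # With the Optional annotation, mypy requires callers to rule out
        # None before calling methods on the result.
        return getattr(self, "_module", None)


class NewPlugin:
    @property
    def lightning_module(self) -> Module:
        # Always returns a module, so call sites can use it directly.
        return Module()


def run_eval(plugin: NewPlugin) -> None:
    plugin.lightning_module.eval()  # no "is not None" narrowing needed
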
4 changes: 0 additions & 4 deletions setup.cfg
@@ -128,10 +128,6 @@ warn_redundant_casts = True
warn_unused_configs = True
warn_unused_ignores = True

# todo: add proper typing to this module...
[mypy-pytorch_lightning.accelerators.*]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-pytorch_lightning.callbacks.*]
ignore_errors = True
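With the `[mypy-pytorch_lightning.accelerators.*]` ignore block removed, the accelerators package falls under the strict settings defined earlier in setup.cfg. Presumably running something like `mypy pytorch_lightning/accelerators` from the repository root (mypy reads setup.cfg by default when no mypy.ini is present) should now surface any remaining annotation gaps in these files.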
3 changes: 1 addition & 2 deletions tests/accelerators/test_cpu.py
@@ -14,8 +14,7 @@ def test_unsupported_precision_plugins():
trainer = Mock()
model = Mock()
accelerator = CPUAccelerator(
training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
precision_plugin=MixedPrecisionPlugin()
training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
)
with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
accelerator.setup(trainer=trainer, model=model)