@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Iterable, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
 
 import torch
 from torch.optim import Optimizer
@@ -24,6 +24,14 @@
 from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from pytorch_lightning.utilities.enums import AMPType, LightningEnum
 
+if TYPE_CHECKING:
+    from torch.cuda.amp import GradScaler
+
+    from pytorch_lightning.trainer.trainer import Trainer
+
+
+_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
+
 
 class Accelerator(object):
     """
@@ -54,11 +62,11 @@ def __init__(
         self.precision_plugin = precision_plugin
         self.training_type_plugin = training_type_plugin
 
-        self.optimizers = None
-        self.lr_schedulers = None
-        self.optimizer_frequencies = None
+        self.optimizers: Sequence = []
+        self.lr_schedulers: Sequence = []
+        self.optimizer_frequencies: Sequence = []
 
-    def setup(self, trainer, model: LightningModule) -> None:
+    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
         """
         Connects the plugins to the training process, creates optimizers
 
@@ -70,13 +78,13 @@ def setup(self, trainer, model: LightningModule) -> None:
         self.setup_optimizers(trainer)
         self.connect_precision_plugin(self.precision_plugin)
 
-    def start_training(self, trainer):
+    def start_training(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_training(trainer)
 
-    def start_testing(self, trainer):
+    def start_testing(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_testing(trainer)
 
-    def start_predicting(self, trainer):
+    def start_predicting(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_predicting(trainer)
 
     def pre_dispatch(self) -> None:
@@ -113,7 +121,7 @@ def lightning_module(self) -> LightningModule:
     def root_device(self) -> torch.device:
         return self.training_type_plugin.root_device
 
-    def teardown(self):
+    def teardown(self) -> None:
         """This method is called to teardown the training process.
         It is the right place to release memory and free other ressources.
         """
@@ -134,11 +142,14 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
 
         return move_data_to_device(batch, device)
 
-    def on_train_start(self):
+    def on_train_start(self) -> None:
         """Hook to do something upon the training start"""
         pass
 
-    def training_step(self, args):
+    def training_step(
+        self,
+        args: List[Union[Any, int]],
+    ) -> _STEP_OUTPUT_TYPE:
         """The actual training step.
 
         Args:
@@ -156,10 +167,10 @@ def training_step(self, args):
         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
             return self.training_type_plugin.training_step(*args)
 
-    def post_training_step(self):
+    def post_training_step(self) -> None:
         self.training_type_plugin.post_training_step()
 
-    def validation_step(self, args):
+    def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
         """The actual validation step.
 
         Args:
@@ -177,7 +188,7 @@ def validation_step(self, args):
         with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
             return self.training_type_plugin.validation_step(*args)
 
-    def test_step(self, args):
+    def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
         """The actual test step.
 
         Args:
@@ -195,7 +206,7 @@ def test_step(self, args):
         with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
             return self.training_type_plugin.test_step(*args)
 
-    def predict(self, args):
+    def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
         """The actual predict step.
 
         Args:
@@ -213,23 +224,29 @@ def predict(self, args):
         with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
             return self.training_type_plugin.predict(*args)
 
-    def training_step_end(self, output):
+    def training_step_end(
+        self, output: _STEP_OUTPUT_TYPE
+    ) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the training step
 
         Args:
             output: the output of the training step
         """
         return self.training_type_plugin.training_step_end(output)
 
-    def test_step_end(self, output):
+    def test_step_end(
+        self, output: _STEP_OUTPUT_TYPE
+    ) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the test step
 
         Args:
             output: the output of the test step
         """
         return self.training_type_plugin.test_step_end(output)
 
-    def validation_step_end(self, output):
+    def validation_step_end(
+        self, output: _STEP_OUTPUT_TYPE
+    ) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the validation step
 
         Args:
@@ -243,8 +260,8 @@ def backward(
         optimizer: Optimizer,
         optimizer_idx: int,
         should_accumulate: bool,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ) -> torch.Tensor:
         """Forwards backward-calls to the precision plugin.
 
@@ -262,7 +279,7 @@ def backward(
 
         return output
 
-    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
+    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
         """performs the actual optimizer step.
 
         Args:
@@ -279,7 +296,9 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)
 
-    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+    def run_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
+    ) -> None:
         self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
 
     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
@@ -292,7 +311,7 @@ def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
 
         self.precision_plugin.clip_gradients(optimizer, clip_val)
 
-    def on_train_epoch_end(self, outputs) -> None:
+    def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
         """Hook to do something on the end of an training epoch
 
         Args:
@@ -304,7 +323,7 @@ def on_train_end(self) -> None:
         """Hook to do something at the end of the training"""
         pass
 
-    def setup_optimizers(self, trainer):
+    def setup_optimizers(self, trainer: 'Trainer') -> None:
         """creates optimizers and schedulers
 
         Args:
@@ -327,7 +346,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
         """
         plugin.connect(model)
 
-    def connect_precision_plugin(self, plugin: PrecisionPlugin):
+    def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
         """Attaches the precision plugin to the accelerator"""
         model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
         self.model = model
@@ -351,26 +370,22 @@ def precision(self) -> int:
         return self.precision_plugin.precision
 
     @property
-    def scaler(self):
-        if hasattr(self.precision_plugin, "scaler"):
-            return self.precision_plugin.scaler
+    def scaler(self) -> Optional['GradScaler']:
 
-        return None
+        return getattr(self.precision_plugin, 'scaler', None)
 
     @property
     def rpc_enabled(self) -> bool:
         return self.training_type_plugin.rpc_enabled
 
-    def optimizer_state(self, optimizer: Optimizer) -> dict:
+    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]:
         """
         Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
         plugins.
         """
-        if self.training_type_plugin and hasattr(self.training_type_plugin, "optimizer_state"):
-            return self.training_type_plugin.optimizer_state(optimizer)
-        return optimizer.state_dict()
+        return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)
 
-    def on_save(self, checkpoint):
+    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
         return checkpoint
 
     def barrier(self, name: Optional[str] = None) -> None:
@@ -385,7 +400,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         """
         return self.training_type_plugin.broadcast(obj, src)
 
-    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+    def all_gather(
+        self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
+    ) -> torch.Tensor:
         """
         Function to gather a tensor from several distributed processes.
 
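The `if TYPE_CHECKING:` block added at the top of the file makes `GradScaler` and `Trainer` visible to static type checkers without importing them at runtime; a runtime import of `Trainer` would be circular, which is also why the new signatures quote the name. A minimal sketch of the pattern; `start_training` below is a hypothetical stand-in for the accelerator method, not Lightning code:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Executed by static type checkers (e.g. mypy) only, never at runtime,
    # so the circular import between the Trainer and the accelerator is avoided.
    from pytorch_lightning.trainer.trainer import Trainer


def start_training(trainer: 'Trainer') -> None:
    # 'Trainer' is a forward reference (a string), resolved lazily by the
    # type checker; this module therefore imports fine at runtime even
    # though Trainer was never actually imported.
    print(type(trainer).__name__)
```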
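Two hunks replace `hasattr` branching with a single `getattr` call that carries a default: the `scaler` property falls back to `None`, and `optimizer_state` falls back to the optimizer's own `state_dict`. A self-contained sketch of the equivalence; `DummyPlugin` and `DummyOptimizer` are hypothetical stand-ins, not Lightning classes:

```python
class DummyOptimizer:
    def state_dict(self) -> dict:
        return {"step": 0}


class DummyPlugin:
    # Deliberately defines no `optimizer_state`, so the fallback is used.
    pass


plugin, optimizer = DummyPlugin(), DummyOptimizer()

# Before the diff: an explicit hasattr branch.
if hasattr(plugin, "optimizer_state"):
    state = plugin.optimizer_state(optimizer)
else:
    state = optimizer.state_dict()

# After the diff: getattr with a fallback callable does the same in one line.
state = getattr(plugin, "optimizer_state", lambda x: x.state_dict())(optimizer)
assert state == {"step": 0}
```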
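The new `_STEP_OUTPUT_TYPE` alias records that a step may return a bare loss tensor, a dict of tensors, or `None`; the `*_step` and `*_step_end` signatures above all share it. An illustrative use, assuming a toy step function rather than Lightning code:

```python
from typing import Dict, Union

import torch

_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]


def toy_training_step(batch: torch.Tensor) -> _STEP_OUTPUT_TYPE:
    loss = batch.float().mean()
    # All three return shapes satisfy the alias:
    # `return loss`, `return None`, or a tensor dict as below.
    return {"loss": loss}


out = toy_training_step(torch.ones(4))
assert isinstance(out, dict) and out["loss"].item() == 1.0
```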