from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
+from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel


+def test_deepspeed_lightning_module(tmpdir):
+    """
+    Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly.
+    """
+
+    model = BoringModel()
+    module = LightningDeepSpeedModule(model, precision=16)
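+    # dtype/device changes on the wrapper should propagate to the wrapped LightningModule, as asserted below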
+
+    module.half()
+    assert module.dtype == torch.half
+    assert model.dtype == torch.half
+
+    module.to(torch.double)
+    assert module.dtype == torch.double
+    assert model.dtype == torch.double
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+def test_deepspeed_lightning_module_precision(tmpdir):
+    """
+    Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision is 16.
+    """
+
+    model = BoringModel()
+    module = LightningDeepSpeedModule(model, precision=16)
+
+    module.cuda().half()
+    assert module.dtype == torch.half
+    assert model.dtype == torch.half
+
+    x = torch.randn((1, 32), dtype=torch.float).cuda()
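+    # with precision=16 the wrapper is expected to cast the float32 input to half before calling forward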
+    out = module(x)
+
+    assert out.dtype == torch.half
+
+    module.to(torch.double)
+    assert module.dtype == torch.double
+    assert model.dtype == torch.double
+
+
@pytest.fixture
def deepspeed_config():
    return {
@@ -34,6 +75,11 @@ def deepspeed_config():
    }


+@pytest.fixture
+def deepspeed_zero_config(deepspeed_config):
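+    # extend the base DeepSpeed config with ZeRO stage 2 and allow untested optimizers under ZeRO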
+    return {**deepspeed_config, 'zero_allow_untested_optimizer': True, 'zero_optimization': {'stage': 2}}
+
+
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
def test_deepspeed_plugin_string(tmpdir):
    """
@@ -179,12 +225,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
            return loss.backward()

    model = TestModel()
-    trainer = Trainer(
-        fast_dev_run=True,
-        default_root_dir=tmpdir,
-        plugins=DeepSpeedPlugin(zero_optimization=False),
-        gpus=1,
-    )
+    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, plugins=DeepSpeedPlugin(), gpus=1, precision=16)
    with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'):
        trainer.fit(model)

@@ -203,17 +244,21 @@ def test_deepspeed_run_configure_optimizers(tmpdir):
    class TestModel(BoringModel):

        def on_train_start(self) -> None:
-            assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
+            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
+
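+            # ZeRO stage 2 wraps the configured SGD optimizer; the original instance is exposed via `.optimizer`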
+            assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
+            assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD)
            assert self.trainer.lr_schedulers == []  # DeepSpeed manages LR scheduler internally
            # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
            assert isinstance(self.trainer.model.lr_scheduler, torch.optim.lr_scheduler.StepLR)

    model = TestModel()
    trainer = Trainer(
-        plugins=DeepSpeedPlugin(zero_optimization=False),
+        plugins=DeepSpeedPlugin(),  # ZeRO stays enabled, so DeepSpeed wraps the optimizer
        default_root_dir=tmpdir,
        gpus=1,
        fast_dev_run=True,
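+        # 16-bit precision enables DeepSpeed's FP16 ZeRO optimizer asserted in on_train_start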
+        precision=16
    )

    trainer.fit(model)
@@ -226,7 +271,7 @@ def on_train_start(self) -> None:
@pytest.mark.skipif(
    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
-def test_deepspeed_config(tmpdir, deepspeed_config):
+def test_deepspeed_config(tmpdir, deepspeed_zero_config):
    """
    Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
    and saves the model weights to load correctly.
@@ -235,18 +280,22 @@ def test_deepspeed_config(tmpdir, deepspeed_config):
    class TestModel(BoringModel):

        def on_train_start(self) -> None:
-            import deepspeed
-            assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
+            from deepspeed.runtime.lr_schedules import WarmupLR
+            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
+
+            assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
+            assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD)
            assert self.trainer.lr_schedulers == []  # DeepSpeed manages LR scheduler internally
-            assert isinstance(self.trainer.model.optimizer, torch.optim.SGD)
-            assert isinstance(self.trainer.model.lr_scheduler, deepspeed.runtime.lr_schedules.WarmupLR)
+            # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
+            assert isinstance(self.trainer.model.lr_scheduler, WarmupLR)

    model = TestModel()
    trainer = Trainer(
-        plugins=[DeepSpeedPlugin(config=deepspeed_config)],
+        plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)],
        default_root_dir=tmpdir,
        gpus=1,
        fast_dev_run=True,
+        precision=16
    )

    trainer.fit(model)
@@ -267,7 +316,7 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config):
    """
    model = BoringModel()
    trainer = Trainer(
-        plugins=[DeepSpeedPlugin(zero_optimization=False)],
+        plugins=[DeepSpeedPlugin()],
        default_root_dir=tmpdir,
        gpus=2,
        fast_dev_run=True,
@@ -285,8 +334,9 @@ def _assert_save_model_is_equal(model, tmpdir, trainer):
    # carry out the check only on rank 0
    if trainer.global_rank == 0:
        saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
-        saved_model = saved_model.float()
-        model = model.float().cpu()
+        if model.dtype == torch.half:
+            saved_model = saved_model.half()  # model is loaded in float32 by default, move it to float16
+        model = model.cpu()
        # Assert model parameters are identical after loading
        for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()):
            assert torch.equal(orig_param, trained_model_param)