27
27
from tests .helpers import BoringModel
28
28
29
29
30
class AMPTestModel(BoringModel):
    """BoringModel variant that checks AMP is actually active while training.

    The AMP tests in this file fit this model with ``Trainer(precision=16)``;
    each training step asserts that autocast is enabled and that the forward
    pass produced half-precision output, so a silently-disabled AMP setup
    fails loudly instead of passing vacuously.
    """

    def training_step(self, batch, batch_idx):
        # Under precision=16 the training step should run inside an autocast
        # region, so the forward output must come back as float16.
        assert torch.is_autocast_enabled()
        output = self(batch)
        assert output.dtype == torch.float16
        loss = self.loss(batch, output)
        return {"loss": loss}
30
40
@pytest .mark .skip (reason = 'dp + amp not supported currently' ) # TODO
31
41
@pytest .mark .skipif (not torch .cuda .is_available (), reason = "test requires GPU machine" )
32
42
def test_amp_single_gpu_dp (tmpdir ):
@@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
41
51
precision = 16 ,
42
52
)
43
53
44
- model = BoringModel ()
54
+ model = AMPTestModel ()
45
55
# tutils.run_model_test(trainer_options, model)
46
56
trainer .fit (model )
47
57
@@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
60
70
precision = 16 ,
61
71
)
62
72
63
- model = BoringModel ()
73
+ model = AMPTestModel ()
64
74
# tutils.run_model_test(trainer_options, model)
65
75
trainer .fit (model )
66
-
67
76
assert trainer .state == TrainerState .FINISHED , f"Training failed with { trainer .state } "
68
77
69
78
@@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
81
90
precision = 16 ,
82
91
)
83
92
84
- model = BoringModel ()
93
+ model = AMPTestModel ()
85
94
# tutils.run_model_test(trainer_options, model)
86
95
trainer .fit (model )
87
96
@@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
100
109
precision = 16 ,
101
110
)
102
111
103
- model = BoringModel ()
112
+ model = AMPTestModel ()
104
113
# tutils.run_model_test(trainer_options, model)
105
114
trainer .fit (model )
106
-
107
115
assert trainer .state == TrainerState .FINISHED , f"Training failed with { trainer .state } "
108
116
109
117
@@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
122
130
# simulate setting slurm flags
123
131
tutils .set_random_master_port ()
124
132
125
- model = BoringModel ()
133
+ model = AMPTestModel ()
126
134
127
135
# exp file to get meta
128
136
logger = tutils .get_default_logger (tmpdir )
0 commit comments