@@ -88,6 +88,19 @@ def forward(self, x):
88
88
return self .reduce (self .embed (x ))
89
89
90
90
91
class PartialScriptModel(LightningModule):
    """A model mixing a TorchScript-compiled layer with an eager one.

    ``layer1`` is wrapped with ``torch.jit.script`` while ``layer2`` remains a
    plain ``nn.Linear``; ``example_input_array`` lets the model summary run a
    forward pass to infer input/output shapes.
    """

    def __init__(self):
        super().__init__()
        # NOTE: construction order is significant — each module init and the
        # rand() call below consume global RNG state in sequence.
        scripted = torch.jit.script(nn.Linear(5, 3))
        self.layer1 = scripted
        self.layer2 = nn.Linear(3, 2)
        self.example_input_array = torch.rand(2, 5)

    def forward(self, x):
        hidden = self.layer1(x)
        return self.layer2(hidden)
91
104
def test_invalid_weights_summmary ():
92
105
""" Test that invalid value for weights_summary raises an error. """
93
106
with pytest .raises (MisconfigurationException , match = '`mode` can be None, .* got temp' ):
@@ -97,11 +110,8 @@ def test_invalid_weights_summmary():
97
110
Trainer (weights_summary = 'temp' )
98
111
99
112
100
- @pytest .mark .parametrize (['mode' ], [
101
- pytest .param (ModelSummary .MODE_FULL ),
102
- pytest .param (ModelSummary .MODE_TOP ),
103
- ])
104
- def test_empty_model_summary_shapes (mode ):
113
+ @pytest .mark .parametrize ('mode' , [ModelSummary .MODE_FULL , ModelSummary .MODE_TOP ])
114
+ def test_empty_model_summary_shapes (mode : ModelSummary ):
105
115
""" Test that the summary works for models that have no submodules. """
106
116
model = EmptyModule ()
107
117
summary = model .summarize (mode = mode )
@@ -110,10 +120,7 @@ def test_empty_model_summary_shapes(mode):
110
120
assert summary .param_nums == []
111
121
112
122
113
- @pytest .mark .parametrize (['mode' ], [
114
- pytest .param (ModelSummary .MODE_FULL ),
115
- pytest .param (ModelSummary .MODE_TOP ),
116
- ])
123
+ @pytest .mark .parametrize ('mode' , [ModelSummary .MODE_FULL , ModelSummary .MODE_TOP ])
117
124
@pytest .mark .parametrize (['device' ], [
118
125
pytest .param (torch .device ('cpu' )),
119
126
pytest .param (torch .device ('cuda' , 0 )),
@@ -157,10 +164,7 @@ def test_mixed_dtype_model_summary():
157
164
]
158
165
159
166
160
- @pytest .mark .parametrize (['mode' ], [
161
- pytest .param (ModelSummary .MODE_FULL ),
162
- pytest .param (ModelSummary .MODE_TOP ),
163
- ])
167
+ @pytest .mark .parametrize ('mode' , [ModelSummary .MODE_FULL , ModelSummary .MODE_TOP ])
164
168
def test_hooks_removed_after_summarize (mode ):
165
169
""" Test that all hooks were properly removed after summary, even ones that were not run. """
166
170
model = UnorderedModel ()
@@ -171,10 +175,7 @@ def test_hooks_removed_after_summarize(mode):
171
175
assert handle .id not in handle .hooks_dict_ref ()
172
176
173
177
174
- @pytest .mark .parametrize (['mode' ], [
175
- pytest .param (ModelSummary .MODE_FULL ),
176
- pytest .param (ModelSummary .MODE_TOP ),
177
- ])
178
+ @pytest .mark .parametrize ('mode' , [ModelSummary .MODE_FULL , ModelSummary .MODE_TOP ])
178
179
def test_rnn_summary_shapes (mode ):
179
180
""" Test that the model summary works for RNNs. """
180
181
model = ParityModuleRNN ()
@@ -198,10 +199,7 @@ def test_rnn_summary_shapes(mode):
198
199
]
199
200
200
201
201
- @pytest .mark .parametrize (['mode' ], [
202
- pytest .param (ModelSummary .MODE_FULL ),
203
- pytest .param (ModelSummary .MODE_TOP ),
204
- ])
202
+ @pytest .mark .parametrize ('mode' , [ModelSummary .MODE_FULL , ModelSummary .MODE_TOP ])
205
203
def test_summary_parameter_count (mode ):
206
204
""" Test that the summary counts the number of parameters in every submodule. """
207
205
model = UnorderedModel ()
@@ -215,10 +213,7 @@ def test_summary_parameter_count(mode):
215
213
]
216
214
217
215
218
- @pytest .mark .parametrize (['mode' ], [
219
- pytest .param (ModelSummary .MODE_FULL ),
220
- pytest .param (ModelSummary .MODE_TOP ),
221
- ])
216
+ @pytest .mark .parametrize ('mode' , [ModelSummary .MODE_FULL , ModelSummary .MODE_TOP ])
222
217
def test_summary_layer_types (mode ):
223
218
""" Test that the summary displays the layer names correctly. """
224
219
model = UnorderedModel ()
@@ -232,10 +227,16 @@ def test_summary_layer_types(mode):
232
227
]
233
228
234
229
235
- @pytest .mark .parametrize (['mode' ], [
236
- pytest .param (ModelSummary .MODE_FULL ),
237
- pytest .param (ModelSummary .MODE_TOP ),
238
- ])
230
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
def test_summary_with_scripted_modules(mode):
    """ Test that the summary handles models containing scripted submodules. """
    model = PartialScriptModel()
    summary = model.summarize(mode=mode)
    assert summary.layer_types == ["RecursiveScriptModule", "Linear"]
    # the first entry is the jit-scripted ``layer1``, whose shapes are reported
    # as UNKNOWN_SIZE; the eager ``layer2`` reports concrete shapes
    assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]]
    assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]]
238
+
239
+ @pytest .mark .parametrize ('mode' , [ModelSummary .MODE_FULL , ModelSummary .MODE_TOP ])
239
240
@pytest .mark .parametrize (['example_input' , 'expected_size' ], [
240
241
pytest .param ([], UNKNOWN_SIZE ),
241
242
pytest .param ((1 , 2 , 3 ), [UNKNOWN_SIZE ] * 3 ),
@@ -269,21 +270,15 @@ def forward(self, *args, **kwargs):
269
270
assert summary .in_sizes == [expected_size ]
270
271
271
272
272
- @pytest .mark .parametrize (['mode' ], [
273
- pytest .param (ModelSummary .MODE_FULL ),
274
- pytest .param (ModelSummary .MODE_TOP ),
275
- ])
273
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
def test_model_size(mode):
    """ Test model size is calculated correctly. """
    net = PreCalculatedModel()
    size_from_summary = net.summarize(mode=mode).model_size
    assert net.pre_calculated_model_size == size_from_summary
281
279
282
280
283
- @pytest .mark .parametrize (['mode' ], [
284
- pytest .param (ModelSummary .MODE_FULL ),
285
- pytest .param (ModelSummary .MODE_TOP ),
286
- ])
281
+ @pytest .mark .parametrize ('mode' , [ModelSummary .MODE_FULL , ModelSummary .MODE_TOP ])
287
282
def test_empty_model_size (mode ):
288
283
""" Test empty model size is zero. """
289
284
model = EmptyModule ()
@@ -293,23 +288,17 @@ def test_empty_model_size(mode):
293
288
294
289
@pytest .mark .skipif (not torch .cuda .is_available (), reason = "Test requires GPU." )
295
290
@pytest .mark .skipif (not _NATIVE_AMP_AVAILABLE , reason = "test requires native AMP." )
296
- @pytest .mark .parametrize (
297
- 'precision' , [
298
- pytest .param (16 , marks = pytest .mark .skip (reason = "no longer valid, because 16 can mean mixed precision" )),
299
- pytest .param (32 ),
300
- ]
301
- )
302
- def test_model_size_precision (monkeypatch , tmpdir , precision ):
291
+ def test_model_size_precision (tmpdir ):
303
292
""" Test model size for half and full precision. """
304
- model = PreCalculatedModel (precision )
293
+ model = PreCalculatedModel ()
305
294
306
295
# fit model
307
296
trainer = Trainer (
308
297
default_root_dir = tmpdir ,
309
298
gpus = 1 ,
310
299
max_steps = 1 ,
311
300
max_epochs = 1 ,
312
- precision = precision ,
301
+ precision = 32 ,
313
302
)
314
303
trainer .fit (model )
315
304
summary = model .summarize ()
0 commit comments