Skip to content

Commit 50aedad

Browse files
awaelchlilexierule
authored andcommitted
Handle torch.jit scripted modules in layer summary (#6511)
(cherry picked from commit 02fa32b)
1 parent 9fc733f commit 50aedad

File tree

3 files changed

+46
-50
lines changed

3 files changed

+46
-50
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
143143
- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968))
144144

145145

146+
- Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511))
147+
148+
146149
## [1.2.3] - 2021-03-09
147150

148151
### Fixed

pytorch_lightning/core/memory.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import shutil
1717
import subprocess
1818
from collections import OrderedDict
19-
from typing import Any, Dict, List, Tuple, Union
19+
from typing import Any, Dict, List, Optional, Tuple, Union
2020

2121
import numpy as np
2222
import torch
@@ -71,14 +71,15 @@ def __init__(self, module: nn.Module):
7171
def __del__(self):
7272
self.detach_hook()
7373

74-
def _register_hook(self) -> RemovableHandle:
74+
def _register_hook(self) -> Optional[RemovableHandle]:
7575
"""
7676
Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
7777
If the hook is called, it will remove itself from the from the module, meaning that
7878
recursive models will only record their input- and output shapes once.
79+
Registering hooks on :class:`~torch.jit.ScriptModule` is not supported.
7980
8081
Return:
81-
A handle for the installed hook.
82+
A handle for the installed hook, or ``None`` if registering the hook is not possible.
8283
"""
8384

8485
def hook(module, inp, out):
@@ -88,7 +89,10 @@ def hook(module, inp, out):
8889
self._out_size = parse_batch_shape(out)
8990
self._hook_handle.remove()
9091

91-
return self._module.register_forward_hook(hook)
92+
handle = None
93+
if not isinstance(self._module, torch.jit.ScriptModule):
94+
handle = self._module.register_forward_hook(hook)
95+
return handle
9296

9397
def detach_hook(self):
9498
"""

tests/core/test_memory.py

Lines changed: 35 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ def forward(self, x):
8888
return self.reduce(self.embed(x))
8989

9090

91+
class PartialScriptModel(LightningModule):
92+
""" A model which contains scripted layers. """
93+
94+
def __init__(self):
95+
super().__init__()
96+
self.layer1 = torch.jit.script(nn.Linear(5, 3))
97+
self.layer2 = nn.Linear(3, 2)
98+
self.example_input_array = torch.rand(2, 5)
99+
100+
def forward(self, x):
101+
return self.layer2(self.layer1(x))
102+
103+
91104
def test_invalid_weights_summmary():
92105
""" Test that invalid value for weights_summary raises an error. """
93106
with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
@@ -97,11 +110,8 @@ def test_invalid_weights_summmary():
97110
Trainer(weights_summary='temp')
98111

99112

100-
@pytest.mark.parametrize(['mode'], [
101-
pytest.param(ModelSummary.MODE_FULL),
102-
pytest.param(ModelSummary.MODE_TOP),
103-
])
104-
def test_empty_model_summary_shapes(mode):
113+
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
114+
def test_empty_model_summary_shapes(mode: ModelSummary):
105115
""" Test that the summary works for models that have no submodules. """
106116
model = EmptyModule()
107117
summary = model.summarize(mode=mode)
@@ -110,10 +120,7 @@ def test_empty_model_summary_shapes(mode):
110120
assert summary.param_nums == []
111121

112122

113-
@pytest.mark.parametrize(['mode'], [
114-
pytest.param(ModelSummary.MODE_FULL),
115-
pytest.param(ModelSummary.MODE_TOP),
116-
])
123+
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
117124
@pytest.mark.parametrize(['device'], [
118125
pytest.param(torch.device('cpu')),
119126
pytest.param(torch.device('cuda', 0)),
@@ -157,10 +164,7 @@ def test_mixed_dtype_model_summary():
157164
]
158165

159166

160-
@pytest.mark.parametrize(['mode'], [
161-
pytest.param(ModelSummary.MODE_FULL),
162-
pytest.param(ModelSummary.MODE_TOP),
163-
])
167+
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
164168
def test_hooks_removed_after_summarize(mode):
165169
""" Test that all hooks were properly removed after summary, even ones that were not run. """
166170
model = UnorderedModel()
@@ -171,10 +175,7 @@ def test_hooks_removed_after_summarize(mode):
171175
assert handle.id not in handle.hooks_dict_ref()
172176

173177

174-
@pytest.mark.parametrize(['mode'], [
175-
pytest.param(ModelSummary.MODE_FULL),
176-
pytest.param(ModelSummary.MODE_TOP),
177-
])
178+
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
178179
def test_rnn_summary_shapes(mode):
179180
""" Test that the model summary works for RNNs. """
180181
model = ParityModuleRNN()
@@ -198,10 +199,7 @@ def test_rnn_summary_shapes(mode):
198199
]
199200

200201

201-
@pytest.mark.parametrize(['mode'], [
202-
pytest.param(ModelSummary.MODE_FULL),
203-
pytest.param(ModelSummary.MODE_TOP),
204-
])
202+
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
205203
def test_summary_parameter_count(mode):
206204
""" Test that the summary counts the number of parameters in every submodule. """
207205
model = UnorderedModel()
@@ -215,10 +213,7 @@ def test_summary_parameter_count(mode):
215213
]
216214

217215

218-
@pytest.mark.parametrize(['mode'], [
219-
pytest.param(ModelSummary.MODE_FULL),
220-
pytest.param(ModelSummary.MODE_TOP),
221-
])
216+
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
222217
def test_summary_layer_types(mode):
223218
""" Test that the summary displays the layer names correctly. """
224219
model = UnorderedModel()
@@ -232,10 +227,16 @@ def test_summary_layer_types(mode):
232227
]
233228

234229

235-
@pytest.mark.parametrize(['mode'], [
236-
pytest.param(ModelSummary.MODE_FULL),
237-
pytest.param(ModelSummary.MODE_TOP),
238-
])
230+
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
231+
def test_summary_with_scripted_modules(mode):
232+
model = PartialScriptModel()
233+
summary = model.summarize(mode=mode)
234+
assert summary.layer_types == ["RecursiveScriptModule", "Linear"]
235+
assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]]
236+
assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]]
237+
238+
239+
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
239240
@pytest.mark.parametrize(['example_input', 'expected_size'], [
240241
pytest.param([], UNKNOWN_SIZE),
241242
pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3),
@@ -269,21 +270,15 @@ def forward(self, *args, **kwargs):
269270
assert summary.in_sizes == [expected_size]
270271

271272

272-
@pytest.mark.parametrize(['mode'], [
273-
pytest.param(ModelSummary.MODE_FULL),
274-
pytest.param(ModelSummary.MODE_TOP),
275-
])
273+
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
276274
def test_model_size(mode):
277275
""" Test model size is calculated correctly. """
278276
model = PreCalculatedModel()
279277
summary = model.summarize(mode=mode)
280278
assert model.pre_calculated_model_size == summary.model_size
281279

282280

283-
@pytest.mark.parametrize(['mode'], [
284-
pytest.param(ModelSummary.MODE_FULL),
285-
pytest.param(ModelSummary.MODE_TOP),
286-
])
281+
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
287282
def test_empty_model_size(mode):
288283
""" Test empty model size is zero. """
289284
model = EmptyModule()
@@ -293,23 +288,17 @@ def test_empty_model_size(mode):
293288

294289
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
295290
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
296-
@pytest.mark.parametrize(
297-
'precision', [
298-
pytest.param(16, marks=pytest.mark.skip(reason="no longer valid, because 16 can mean mixed precision")),
299-
pytest.param(32),
300-
]
301-
)
302-
def test_model_size_precision(monkeypatch, tmpdir, precision):
291+
def test_model_size_precision(tmpdir):
303292
""" Test model size for half and full precision. """
304-
model = PreCalculatedModel(precision)
293+
model = PreCalculatedModel()
305294

306295
# fit model
307296
trainer = Trainer(
308297
default_root_dir=tmpdir,
309298
gpus=1,
310299
max_steps=1,
311300
max_epochs=1,
312-
precision=precision,
301+
precision=32,
313302
)
314303
trainer.fit(model)
315304
summary = model.summarize()

0 commit comments

Comments
 (0)