Commit 3b72bcc

amogkam, carmocca, s-rog, and kaushikb11 authored
Automatically set sync_batchnorm for training_type_plugin (#6536)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
Co-authored-by: Kaushik Bokka <[email protected]>
1 parent 5780796 commit 3b72bcc

2 files changed: +46 −0 lines changed


pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 5 additions & 0 deletions
@@ -430,6 +430,11 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
         if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None:
             training_type.num_nodes = self.num_nodes
 
+        # Automatically set sync_batchnorm if None.
+        # Useful for custom plugins.
+        if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None:
+            training_type.sync_batchnorm = self.sync_batchnorm
+
         return training_type
 
     def select_accelerator(self) -> Accelerator:
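
For authors of custom plugins, the resolution pattern in this hunk generalizes: any attribute a plugin initializes to None is filled in from the corresponding Trainer-level setting. A minimal standalone sketch of the idea (the _Plugin and _Connector classes here are hypothetical stand-ins for illustration, not Lightning's actual classes):

    class _Plugin:
        """Stand-in for a custom training type plugin that defers settings."""

        def __init__(self):
            self.num_nodes = None        # filled in during resolution
            self.sync_batchnorm = None   # filled in during resolution


    class _Connector:
        """Hypothetical sketch of the resolution step shown in the diff above."""

        def __init__(self, num_nodes, sync_batchnorm):
            self.num_nodes = num_nodes
            self.sync_batchnorm = sync_batchnorm

        def resolve(self, plugin):
            # Only overwrite attributes the plugin explicitly left as None,
            # so values the user set on the plugin itself are preserved.
            for name in ("num_nodes", "sync_batchnorm"):
                if hasattr(plugin, name) and getattr(plugin, name) is None:
                    setattr(plugin, name, getattr(self, name))
            return plugin


    resolved = _Connector(num_nodes=2, sync_batchnorm=True).resolve(_Plugin())
    assert resolved.sync_batchnorm is True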

tests/plugins/test_custom_plugin.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import DDPPlugin
+from tests.helpers import BoringModel
+from tests.helpers.runif import RunIf
+
+
+class CustomParallelPlugin(DDPPlugin):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # Set to None so it will be overwritten by the accelerator connector.
+        self.sync_batchnorm = None
+
+
+@RunIf(skip_windows=True)
+def test_sync_batchnorm_set(tmpdir):
+    """Tests if sync_batchnorm is automatically set for custom plugin."""
+    model = BoringModel()
+    plugin = CustomParallelPlugin()
+    assert plugin.sync_batchnorm is None
+    trainer = Trainer(
+        max_epochs=1,
+        plugins=[plugin],
+        default_root_dir=tmpdir,
+        sync_batchnorm=True,
+    )
+    trainer.fit(model)
+    assert plugin.sync_batchnorm is True
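
For context, the resolved sync_batchnorm flag ultimately controls whether batch-norm statistics are synchronized across distributed processes. In plain PyTorch this is done by converting BatchNorm layers before the model is wrapped in DDP; a hedged sketch of that conversion using the standard torch API (not Lightning's exact code path):

    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(32, 32),
        torch.nn.BatchNorm1d(32),
    )

    sync_batchnorm = True  # the value the accelerator connector resolved
    if sync_batchnorm:
        # Replace BatchNorm*d layers with SyncBatchNorm so running statistics
        # are all-reduced across the DDP process group during training.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

Because CustomParallelPlugin sets sync_batchnorm to None in its __init__, the connector fills it from Trainer(sync_batchnorm=True), which is exactly what the test's final assertion verifies.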
