Commit 33106b7

[DCP] Add test for planner option for load_sharded_optimizer_state_dict (pytorch#112930)
Add test for a user-submitted PR: pytorch#112259. Cherry-pick of pytorch#112891 into the `release/2.1` branch.
1 parent 4b4c012 · commit 33106b7
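For context, here is a minimal sketch (not part of the commit) of the call this test exercises: passing an explicit load planner to load_sharded_optimizer_state_dict. The wrapper function load_optim_state, its parameters, and the assumption that a DCP checkpoint already exists under checkpoint_dir are illustrative; only the planner= keyword reflects the option being tested.

import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

def load_optim_state(state_dict, checkpoint_dir, pass_planner=True):
    # Mirrors the parametrized test: both an explicit DefaultLoadPlanner and
    # planner=None (fall back to the default) should load the optimizer state.
    planner = DCP.DefaultLoadPlanner() if pass_planner else None
    return load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],   # sharded model state dict from the checkpoint
        optimizer_key="optim",                  # key under which the optimizer state was saved
        storage_reader=DCP.FileSystemReader(checkpoint_dir),
        planner=planner,                        # planner option exercised by the test
    )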

1 file changed (+10 −2 lines)
test/distributed/checkpoint/test_fsdp_optim_state.py

Lines changed: 10 additions & 2 deletions
@@ -10,7 +10,11 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
 
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -53,8 +57,10 @@ def backend(self):
     @with_comms
     @skip_if_lt_x_gpu(4)
     @with_temp_dir
-    def test_load_sharded_optimizer_state_dict(self) -> None:
+    @parametrize("pass_planner", [True, False])
+    def test_load_sharded_optimizer_state_dict(self, pass_planner) -> None:
         CHECKPOINT_DIR = self.temp_dir
+        planner = DCP.DefaultLoadPlanner() if pass_planner else None
 
         model = self._create_model()
         model = FSDP(model)
@@ -105,6 +111,7 @@ def test_load_sharded_optimizer_state_dict(self) -> None:
             model_state_dict=state_dict["model"],
             optimizer_key="optim",
             storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
+            planner=planner,
         )
         flattened_osd = FSDP.optim_state_dict_to_load(
             model_2, optim_2, optim_state["optim"]
@@ -126,5 +133,6 @@ def test_load_sharded_optimizer_state_dict(self) -> None:
         self.assertEqual(state, state2)
 
 
+instantiate_parametrized_tests(FsdpOptimStateCheckpoint)
 if __name__ == "__main__":
     run_tests()
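As a side note, a standalone sketch of the parametrization pattern the diff adopts, assuming the usual torch.testing._internal helpers; the ExampleTest class and test_option method are hypothetical names, not part of this change.

from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class ExampleTest(TestCase):
    @parametrize("pass_planner", [True, False])
    def test_option(self, pass_planner):
        # Runs once per value in the @parametrize list.
        self.assertIn(pass_planner, (True, False))

# Expands each @parametrize-decorated method into one concrete test per
# parameter value, which is why the diff calls it on FsdpOptimStateCheckpoint.
instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    run_tests()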
