@@ -10,7 +10,11 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
 
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -53,8 +57,10 @@ def backend(self):
     @with_comms
     @skip_if_lt_x_gpu(4)
     @with_temp_dir
-    def test_load_sharded_optimizer_state_dict(self) -> None:
+    @parametrize("pass_planner", [True, False])
+    def test_load_sharded_optimizer_state_dict(self, pass_planner) -> None:
         CHECKPOINT_DIR = self.temp_dir
+        planner = DCP.DefaultLoadPlanner() if pass_planner else None
 
         model = self._create_model()
         model = FSDP(model)
@@ -105,6 +111,7 @@ def test_load_sharded_optimizer_state_dict(self) -> None:
             model_state_dict=state_dict["model"],
             optimizer_key="optim",
             storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
+            planner=planner,
         )
         flattened_osd = FSDP.optim_state_dict_to_load(
             model_2, optim_2, optim_state["optim"]
@@ -126,5 +133,6 @@ def test_load_sharded_optimizer_state_dict(self) -> None:
             self.assertEqual(state, state2)
 
 
+instantiate_parametrized_tests(FsdpOptimStateCheckpoint)
 if __name__ == "__main__":
     run_tests()
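For context, a minimal sketch (not part of this PR) of how the parametrization utilities used above fit together: @parametrize generates one test variant per parameter value, and instantiate_parametrized_tests materializes those variants on the test class so run_tests can discover them. The ExampleTest class and test_flag_values name below are hypothetical, for illustration only.

from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleTest(TestCase):  # hypothetical class, not from the PR
    @parametrize("pass_planner", [True, False])
    def test_flag_values(self, pass_planner) -> None:
        # One test variant is generated per parameter value
        # (roughly test_flag_values_pass_planner_True / ..._False).
        self.assertIn(pass_planner, (True, False))


# Without this call the parametrized variants are not created on the class.
instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    run_tests()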