@@ -621,7 +621,7 @@
             if not skip_ddp_prefix:
                 fqn_obj_names.append(curr_obj_name)
         elif isinstance(curr_obj, FSDP):
-            if obj_names[i + 1] == FLAT_PARAM:
+            if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM:
                 prefix = ".".join(fqn_obj_names)
                 flat_param = getattr(curr_obj, FLAT_PARAM)
                 if prefix:
@@ -660,7 +660,7 @@
         Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
     ] = {}
     all_fqns = set()
-    for name, param in model.named_parameters():
+    for name, param in chain(model.named_parameters(), model.named_buffers()):
         fqns = _get_fqns(model, name)
         fqn_param_mapping[param] = fqns
         for fqn in fqns:
@@ -859,7 +859,7 @@
     if not info.handle_model or not state_dict:
         return _IncompatibleKeys({}, {})

-    for key, _ in model.named_parameters():
+    for key, _ in chain(model.named_parameters(), model.named_buffers()):
         fqns = _get_fqns(model, key)
         fqns_with_ddp_prefix = _get_fqns(model, key, skip_ddp_prefix=False)
         for fqn, fqn_with_ddp_prefix in zip(fqns, fqns_with_ddp_prefix):
@@ -1142,25 +1142,25 @@
     optimizer parameter IDs to the canonical FQNs.

     Example:
+        >>> # xdoctest: +SKIP
+        >>> import torch
+        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        >>> from torch.nn.parallel import DistributedDataParallel as DDP
+        >>> from torch.distributed.checkpoint.state_dict import get_state_dict

-        import torch
-        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-        from torch.nn.parallel import DistributedDataParallel as DDP
-        from torch.distributed.checkpoint.state_dict import get_state_dict
-
-        fsdp_model = FSDP(copy.deepcopy(model))
-        fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
-        ddp_model = DDP(copy.deepcopy(model))
-        ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
+        >>> fsdp_model = FSDP(copy.deepcopy(model))
+        >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
+        >>> ddp_model = DDP(copy.deepcopy(model))
+        >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)


-        ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
-        fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
+        >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
+        >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)

-        # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
-        # the asserts will fail.
-        assert ddp_state_dict == fsdp_state_dict
-        assert ddp_optim_state == fsdp_optim_state_dict
+        >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
+        >>> # the asserts will fail.
+        >>> assert ddp_state_dict == fsdp_state_dict
+        >>> assert ddp_optim_state == fsdp_optim_state_dict


     Args:
@@ -1175,6 +1175,8 @@

     Returns:
         ``Tuple`` that contain model state_dict and optimizer state_dict.
+
+    :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]
     """

     with gc_context():