Commit 4b4c012

Authored by wz337 and b-chu
Enable planner to be used for loading sharded optimizer state dict (pytorch#112520)
Cherry-pick of [pytorch#112259](pytorch#112259), requested by MosaicML.

Comment from users:

> Without this, we can't do training resumption, because the model gets loaded without the optimizer.

This creates a more consistent interface for saving and loading sharded state dicts. A planner can be specified when saving a sharded optimizer state dict, but there was previously no planner support for loading one. This change does not affect the default behavior of the function.

Co-authored-by: Brian <[email protected]>
1 parent 47ac502 commit 4b4c012

File tree

1 file changed: +3 −1 lines

torch/distributed/checkpoint/optimizer.py

Lines changed: 3 additions & 1 deletion

@@ -33,6 +33,7 @@
     DefaultLoadPlanner,
 )
 from torch.distributed._shard.api import _shard_tensor
+from torch.distributed.checkpoint.planner import LoadPlanner

 from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
 from torch.distributed.checkpoint.utils import (
@@ -212,6 +213,7 @@ def load_sharded_optimizer_state_dict(
     model_state_dict: STATE_DICT_TYPE,
     optimizer_key: str,
     storage_reader: dist_cp.StorageReader,
+    planner: Optional[LoadPlanner] = None,
 ) -> STATE_DICT_TYPE:
     """
     Loads a state_dict in conjunction with FSDP sharded optimizer state.
@@ -337,7 +339,7 @@ def load_sharded_optimizer_state_dict(
         state_dict=state_dict,
         storage_reader=storage_reader,
         # FIXME the type of planner is wrong in load_state_dict
-        planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else None,
+        planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner,
     )

     state_dict = unflatten_state_dict(state_dict, metadata.planner_data)
