
Commit 936998d

wz337 authored and facebook-github-bot committed
Remove composable API's fully_shard from torchrec test (#2549)
Summary:
Pull Request resolved: #2549

The fully_shard name is now used by FSDP2 (torch.distributed._composable.fsdp.fully_shard), and the composable API's fully_shard (torch.distributed._composable.fully_shard) is being deprecated. Therefore, we want to remove torch.distributed._composable.fully_shard from torchrec as well.

Deprecation message from PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fully_shard.py#L41-L48

Reviewed By: fegin, iamzainhuda

Differential Revision: D65702643

fbshipit-source-id: 22494e5a6a201a32c86f571c7604383a2e3b3a02
1 parent bcd1af9 commit 936998d
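
For context, a minimal sketch of the replacement this commit makes: the deprecated composable call fully_shard(module, policy=...) becomes a plain FSDP(module, auto_wrap_policy=...) wrapper. The toy module and helper below are hypothetical stand-ins for the test's model; the FSDP keyword arguments match the diff, and torch.distributed must already be initialized for the wrapper to do anything.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

class ToyDense(nn.Module):  # hypothetical stand-in for the test model's dense part
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

    def forward(self, x):
        return self.layers(x)

# Before (deprecated composable API, removed by this commit):
#   from torch.distributed._composable import fully_shard
#   fully_shard(m, policy=ModuleWrapPolicy({nn.Linear}), device_id=device.index)
#
# After (FSDP wrapper, as in the diff) -- note the keyword is
# auto_wrap_policy rather than policy:
def shard_dense(m: nn.Module, device_index: int) -> FSDP:
    # Requires dist.init_process_group(...) to have run on each rank.
    return FSDP(
        m,
        auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
        device_id=device_index,
    )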

File tree

1 file changed: +7 -5 lines changed


torchrec/distributed/composable/tests/test_fsdp.py

Lines changed: 7 additions & 5 deletions
@@ -14,7 +14,8 @@
 
 import torch
 from torch import nn
-from torch.distributed._composable import fully_shard
+
+# from torch.distributed._composable import fully_shard
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor
 

@@ -24,6 +25,7 @@
     load_state_dict,
     save_state_dict,
 )
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.distributed.optim import (
@@ -91,15 +93,15 @@ def _run( # noqa
             device=ctx.device,
             plan=row_wise(),
         )
-        m.dense = fully_shard(
+        m.dense = FSDP(  # pyre-ignore
             m.dense,
+            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
             device_id=ctx.device.index,
-            policy=ModuleWrapPolicy({nn.Linear}),
         )
-        m.over = fully_shard(
+        m.over = FSDP(
             m.over,
+            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
             device_id=ctx.device.index,
-            policy=ModuleWrapPolicy({nn.Linear}),
         )
 
         dense_opt = KeyedOptimizerWrapper(
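
For reference, the API that now owns the fully_shard name is FSDP2's, which shards a module in place instead of wrapping it. A minimal sketch, assuming a recent PyTorch where torch.distributed._composable.fsdp.fully_shard is available and a process group is already initialized; the two-layer model is hypothetical:

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2

# Assumes dist.init_process_group(...) has already run on each rank.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# Shard each submodule, then the root. FSDP2's fully_shard mutates the
# module in place (its parameters become DTensor shards) and returns
# the same module, rather than returning a new FSDP wrapper object.
for layer in model:
    fully_shard(layer)
fully_shard(model)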
