Skip to content

Commit 2b2671a

Browse files
wanchaolpytorchmergebot
authored andcommitted
[dtensor] fix foreach_norm when ord is 2 (pytorch#130753)
as titled, fixed a case when passing ord as 2 (default value), the op dispatching does not receive the default value case We simply check if the args schema receiving a `ord` field or not Pull Request resolved: pytorch#130753 Approved by: https://github.com/awgu
1 parent a29052a commit 2b2671a

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

test/distributed/_tensor/test_math_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,25 @@ def test_shard0_svd(self):
437437
self.assertEqual(len(comm_counts), 1)
438438
self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 1)
439439

440+
@with_comms
441+
def test_foreach_norm(self):
442+
device_mesh = self.build_device_mesh()
443+
444+
grad0 = torch.randn(12, 8)
445+
grad1 = torch.randn(8, 8)
446+
447+
sharded_grad0 = distribute_tensor(grad0, device_mesh, [Shard(0)])
448+
sharded_grad1 = distribute_tensor(grad1, device_mesh, [Shard(0)])
449+
450+
# non-sharded op
451+
out = torch.ops.aten._foreach_norm([grad0, grad1], 2)
452+
453+
# sharded op
454+
sharded_out = torch.ops.aten._foreach_norm([sharded_grad0, sharded_grad1], 2)
455+
456+
for o, so in zip(out, sharded_out):
457+
self.assertEqual(so.full_tensor(), o)
458+
440459

441460
if __name__ == "__main__":
442461
run_tests()

torch/distributed/_tensor/ops/math_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def foreach_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrateg
397397
args_schema = op_schema.args_schema
398398
input_tuple_strategy = args_schema[0]
399399
assert isinstance(input_tuple_strategy, TupleStrategy)
400-
norm_type = args_schema[1]
400+
norm_type = args_schema[1] if len(args_schema) > 1 else 2
401401
assert isinstance(norm_type, (int, float, str)), f"{norm_type}"
402402
output_tuple_strategy_childs: List[OpStrategy] = []
403403
for op_strategy in input_tuple_strategy.childs:

0 commit comments

Comments
 (0)