
Commit f032135

guilhermeleobas authored and pytorchmergebot committed
Add batching rule for torch.scatter_reduce (pytorch#135547)
Fixes pytorch#134797

Pull Request resolved: pytorch#135547
Approved by: https://github.com/zou3519
1 parent 525bec8 commit f032135

File tree

3 files changed: 24 additions & 34 deletions


aten/src/ATen/functorch/BatchRulesScatterOps.cpp

Lines changed: 24 additions & 0 deletions
@@ -779,6 +779,28 @@ std::tuple<Tensor, std::optional<int64_t>> scatter_reduce_batch_rule(
       self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce);
 }
 
+std::tuple<Tensor, std::optional<int64_t>> scatter_reduce_two_batch_rule(
+    const Tensor& self, std::optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, std::optional<int64_t> index_bdim,
+    const Tensor& src, std::optional<int64_t> src_bdim,
+    const c10::string_view reduce,
+    bool include_self) {
+  return scatter_batch_rule(ATEN_FN2(scatter_reduce, two),
+                            self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce, include_self);
+}
+
+std::tuple<Tensor, std::optional<int64_t>> scatter_reduce__two_batch_rule(
+    const Tensor& self, std::optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, std::optional<int64_t> index_bdim,
+    const Tensor& src, std::optional<int64_t> src_bdim,
+    const c10::string_view reduce,
+    bool include_self) {
+  return scatter_batch_rule(ATEN_FN2(scatter_reduce_, two),
+                            self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce, include_self);
+}
+
 std::tuple<Tensor, std::optional<int64_t>> scatter_value_reduce_batch_rule(
     const Tensor& self, std::optional<int64_t> self_bdim,
     int64_t dim,
@@ -1250,6 +1272,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT(scatter_add, scatter_add_batch_rule);
   VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule);
   VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule);
+  VMAP_SUPPORT2(scatter_reduce, two, scatter_reduce_two_batch_rule);
+  VMAP_SUPPORT2(scatter_reduce_, two, scatter_reduce__two_batch_rule);
   // as_strided_scatter does not work with the for-loop fallback today,
   // because as_strided_scatter will return an output that matches
   // the strides/storage_offset of its input.
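What this registration enables, as a minimal sketch (not part of the commit; the shapes, indices, and reduce mode below are illustrative assumptions): vmap over torch.scatter_reduce should now dispatch to scatter_reduce_two_batch_rule instead of the slow for-loop vmap fallback.

# Minimal sketch, assumed example: vmap over torch.scatter_reduce
# now uses the registered batching rule rather than the fallback.
import torch
from torch.func import vmap

B, N = 4, 8
base = torch.zeros(B, N)
index = torch.randint(0, N, (B, N))
src = torch.randn(B, N)

def reduce_one(b, i, s):
    # Calls the aten::scatter_reduce.two overload.
    return torch.scatter_reduce(b, 0, i, s, reduce="sum", include_self=True)

batched = vmap(reduce_one)(base, index, src)

# Reference: an explicit loop over the batch dimension gives the same result.
looped = torch.stack([reduce_one(b, i, s) for b, i, s in zip(base, index, src)])
assert torch.allclose(batched, looped)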

test/functorch/test_ops.py

Lines changed: 0 additions & 25 deletions
@@ -1409,18 +1409,6 @@ def test_vmapjvpall(self, device, dtype, op):
             xfail("nn.functional.soft_margin_loss", ""),
             xfail("nn.functional.max_unpool1d", "grad"),
             xfail("nn.functional.embedding", ""),
-            xfail(
-                "scatter_reduce", "sum"
-            ),  # aten::scatter_reduce.two hit the vmap fallback
-            xfail(
-                "scatter_reduce", "mean"
-            ),  # aten::scatter_reduce.two hit the vmap fallback
-            xfail(
-                "scatter_reduce", "amin"
-            ),  # aten::scatter_reduce.two hit the vmap fallback
-            xfail(
-                "scatter_reduce", "amax"
-            ),  # aten::scatter_reduce.two hit the vmap fallback
             xfail("nn.functional.glu"),
             xfail("nn.functional.bilinear"),  # trilinear doesn't have batching rule
             xfail("linalg.lu", ""),
@@ -1492,18 +1480,6 @@ def test():
             xfail("nanquantile"),
             xfail("ormqr"),
             xfail("put"),
-            xfail(
-                "scatter_reduce", "sum"
-            ),  # aten::scatter_reduce.two hit the vmap fallback
-            xfail(
-                "scatter_reduce", "mean"
-            ),  # aten::scatter_reduce.two hit the vmap fallback
-            xfail(
-                "scatter_reduce", "amin"
-            ),  # aten::scatter_reduce.two hit the vmap fallback
-            xfail(
-                "scatter_reduce", "amax"
-            ),  # aten::scatter_reduce.two hit the vmap fallback
             xfail("quantile"),
             xfail("renorm"),
             xfail("take"),
@@ -1530,7 +1506,6 @@ def test():
             xfail("nn.functional.multi_margin_loss", ""),
             xfail("nn.functional.multilabel_margin_loss", ""),
             xfail("nn.functional.pdist", ""),
-            xfail("scatter_reduce", "prod"),
             xfail("nn.functional.max_unpool1d", ""),
             xfail("nn.functional.max_unpool3d", ""),
             xfail("nn.functional.max_unpool3d", "grad"),

test/functorch/test_vmap.py

Lines changed: 0 additions & 9 deletions
@@ -4427,10 +4427,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         # TODO: implement batching rule
         xfail("_batch_norm_with_update"),
         xfail("histogram"),
-        xfail("scatter_reduce", "sum"),
-        xfail("scatter_reduce", "mean"),
-        xfail("scatter_reduce", "amax"),
-        xfail("scatter_reduce", "amin"),
         # `index_put` OpInfo in pytorch/pytorch has
         # masked index as input which is not supported
         xfail("index_put", ""),
@@ -4502,20 +4498,15 @@ def test_vmap_exhaustive(self, device, dtype, op):
         ),  # Batching rule not implemented for aten::narrow.Tensor
         xfail("nn.functional.triplet_margin_loss", ""),
         xfail("nn.functional.pdist", ""),
-        xfail("scatter_reduce", "sum"),
-        xfail("scatter_reduce", "amax"),
         xfail("nn.functional.max_unpool1d", "grad"),
         xfail("nn.functional.multi_margin_loss", ""),
-        xfail("scatter_reduce", "prod"),
         xfail("nn.functional.multilabel_margin_loss", ""),
-        xfail("scatter_reduce", "amin"),
         xfail("nn.functional.max_unpool3d", "grad"),
         xfail("nn.functional.max_unpool2d", ""),
         xfail("nn.functional.max_unpool2d", "grad"),
         xfail("nn.functional.margin_ranking_loss", ""),
         xfail("nn.functional.max_unpool1d", ""),
         xfail("nn.functional.soft_margin_loss", ""),
-        xfail("scatter_reduce", "mean"),
         xfail("nn.functional.max_unpool3d", ""),
         xfail("linalg.ldl_solve", "", device_type="cpu"),
         xfail("chalf", ""),
