Skip to content

Commit 00620ab

Browse files
[mlir][SCF] Allow canonicalization of zero-trip count scf.forall with empty mapping. (#105793)
The current folding of one-trip-count loops does not kick in when the mapping attribute is present but empty. Enable this folding for an empty mapping as well. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent ceb587a commit 00620ab

File tree

3 files changed

+34
-27
lines changed

3 files changed

+34
-27
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,7 +1700,7 @@ struct ForallOpSingleOrZeroIterationDimsFolder
17001700
LogicalResult matchAndRewrite(ForallOp op,
17011701
PatternRewriter &rewriter) const override {
17021702
// Do not fold dimensions if they are mapped to processing units.
1703-
if (op.getMapping().has_value())
1703+
if (op.getMapping().has_value() && !op.getMapping()->empty())
17041704
return failure();
17051705
Location loc = op.getLoc();
17061706

@@ -1729,18 +1729,19 @@ struct ForallOpSingleOrZeroIterationDimsFolder
17291729
newMixedUpperBounds.push_back(ub);
17301730
newMixedSteps.push_back(step);
17311731
}
1732-
// Exit if none of the loop dimensions perform a single iteration.
1733-
if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1734-
return rewriter.notifyMatchFailure(
1735-
op, "no dimensions have 0 or 1 iterations");
1736-
}
17371732

17381733
// All of the loop dimensions perform a single iteration. Inline loop body.
17391734
if (newMixedLowerBounds.empty()) {
17401735
promote(rewriter, op);
17411736
return success();
17421737
}
17431738

1739+
// Exit if none of the loop dimensions perform a single iteration.
1740+
if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1741+
return rewriter.notifyMatchFailure(
1742+
op, "no dimensions have 0 or 1 iterations");
1743+
}
1744+
17441745
// Replace the loop by a lower-dimensional loop.
17451746
ForallOp newOp;
17461747
newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1635,6 +1635,33 @@ func.func @do_not_inline_distributed_forall_loop(
16351635

16361636
// -----
16371637

1638+
func.func @inline_empty_loop_with_empty_mapping(
1639+
%in: tensor<16xf32>) -> tensor<16xf32> {
1640+
%cst = arith.constant 0.000000e+00 : f32
1641+
%0 = tensor.empty() : tensor<16xf32>
1642+
%1 = scf.forall () in () shared_outs (%out_ = %0) -> (tensor<16xf32>) {
1643+
%slice = tensor.extract_slice %out_[0] [16] [1]
1644+
: tensor<16xf32> to tensor<16xf32>
1645+
%generic = linalg.generic {
1646+
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
1647+
iterator_types = ["parallel"]}
1648+
ins(%slice : tensor<16xf32>) outs(%0 : tensor<16xf32>) {
1649+
^bb0(%b0 : f32, %b1 : f32):
1650+
%2 = arith.addf %b0, %b0 : f32
1651+
linalg.yield %2 : f32
1652+
} -> tensor<16xf32>
1653+
scf.forall.in_parallel {
1654+
tensor.parallel_insert_slice %generic into %out_[0] [16] [1]
1655+
: tensor<16xf32> into tensor<16xf32>
1656+
}
1657+
}{ mapping = [] }
1658+
return %1 : tensor<16xf32>
1659+
}
1660+
// CHECK-LABEL: func @inline_empty_loop_with_empty_mapping
1661+
// CHECK-NOT: scf.forall
1662+
1663+
// -----
1664+
16381665
func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
16391666
%c8 = arith.constant 8 : index
16401667
%c0 = arith.constant 0 : index

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2076,27 +2076,6 @@ func.func @canonicalize_parallel_insert_slice_indices(
20762076

20772077
// -----
20782078

2079-
// CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
2080-
// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
2081-
// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
2082-
func.func @dont_fold_parallel_insert_slice(
2083-
%arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
2084-
{
2085-
%c0 = arith.constant 0 : index
2086-
%c1 = arith.constant 1 : index
2087-
// CHECK: scf.forall () in () shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<1x5xf32>) {
2088-
// CHECK-NEXT: scf.forall.in_parallel {
2089-
// CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[o]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
2090-
%2 = scf.forall () in () shared_outs(%o = %arg1) -> (tensor<1x5xf32>) {
2091-
scf.forall.in_parallel {
2092-
tensor.parallel_insert_slice %arg0 into %o[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
2093-
}
2094-
}
2095-
return %2 : tensor<1x5xf32>
2096-
}
2097-
2098-
// -----
2099-
21002079
// CHECK-LABEL: func.func @fold_insert_slice_after_extract_slice
21012080
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>)
21022081
func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {

0 commit comments

Comments
 (0)