Skip to content

Commit 8915715

Browse files
[mlir][SCF] Add folding for IndexSwitchOp
1 parent 18669b1 commit 8915715

File tree

4 files changed

+70
-7
lines changed

4 files changed

+70
-7
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
11261126
Block &getCaseBlock(unsigned idx);
11271127
}];
11281128

1129+
let hasFolder = 1;
11291130
let hasVerifier = 1;
11301131
}
11311132

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4166,6 +4166,35 @@ void IndexSwitchOp::getRegionInvocationBounds(
41664166
bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
41674167
}
41684168

4169+
LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
4170+
SmallVectorImpl<OpFoldResult> &results) {
4171+
std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
4172+
if (!maybeCst.has_value())
4173+
return failure();
4174+
int64_t cst = *maybeCst;
4175+
int64_t caseIdx, e = getNumCases();
4176+
for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4177+
if (cst == getCases()[caseIdx])
4178+
break;
4179+
}
4180+
4181+
Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
4182+
: getDefaultRegion();
4183+
Block &source = r.front();
4184+
results.assign(source.getTerminator()->getOperands().begin(),
4185+
source.getTerminator()->getOperands().end());
4186+
4187+
Block *pDestination = (*this)->getBlock();
4188+
if (!pDestination)
4189+
return failure();
4190+
Block::iterator insertionPoint = (*this)->getIterator();
4191+
pDestination->getOperations().splice(insertionPoint, source.getOperations(),
4192+
source.getOperations().begin(),
4193+
std::prev(source.getOperations().end()));
4194+
4195+
return success();
4196+
}
4197+
41694198
//===----------------------------------------------------------------------===//
41704199
// TableGen'd op method definitions
41714200
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,3 +1756,36 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
17561756
// CHECK: parallel_insert_slice
17571757
// CHECK-SAME: : tensor<1xi32> into tensor<2xi32>
17581758
// CHECK: tensor.cast
1759+
1760+
// -----
1761+
1762+
func.func @index_switch_fold() -> (f32, f32) {
1763+
%switch_cst = arith.constant 1: index
1764+
%0 = scf.index_switch %switch_cst -> f32
1765+
case 1 {
1766+
%y = arith.constant 1.0 : f32
1767+
scf.yield %y : f32
1768+
}
1769+
default {
1770+
%y = arith.constant 42.0 : f32
1771+
scf.yield %y : f32
1772+
}
1773+
1774+
%switch_cst_2 = arith.constant 2: index
1775+
%1 = scf.index_switch %switch_cst_2 -> f32
1776+
case 0 {
1777+
%y = arith.constant 0.0 : f32
1778+
scf.yield %y : f32
1779+
}
1780+
default {
1781+
%y = arith.constant 42.0 : f32
1782+
scf.yield %y : f32
1783+
}
1784+
1785+
return %0, %1 : f32, f32
1786+
}
1787+
1788+
// CHECK-LABEL: func.func @index_switch_fold()
1789+
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1.000000e+00 : f32
1790+
// CHECK-NEXT: %[[c42:.*]] = arith.constant 4.200000e+01 : f32
1791+
// CHECK-NEXT: return %[[c1]], %[[c42]] : f32, f32

mlir/test/Dialect/SCF/for-loop-canonicalization.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -394,14 +394,18 @@ func.func @regression_multiplication_with_sym(%A : memref<i64>) {
394394

395395
// -----
396396

397+
397398
// Make sure min is transformed into zero.
398399

399-
// CHECK: %[[ZERO:.+]] = arith.constant 0 : index
400-
// CHECK: scf.index_switch %[[ZERO]] -> i1
400+
// CHECK-LABEL: func.func @func1()
401+
// CHECK: %[[ZERO:.+]] = arith.constant 0 : index
402+
// CHECK: call @foo(%[[ZERO]]) : (index) -> ()
401403

402404
#map6 = affine_map<(d0, d1, d2) -> (d0 floordiv 64)>
403405
#map29 = affine_map<(d0, d1, d2) -> (d2 * 64 - 2, 5, (d1 mod 4) floordiv 8)>
404406
module {
407+
func.func private @foo(%0 : index) -> ()
408+
405409
func.func @func1() {
406410
%true = arith.constant true
407411
%c0 = arith.constant 0 : index
@@ -412,11 +416,7 @@ module {
412416
%alloc_249 = memref.alloc() : memref<7xf32>
413417
%135 = affine.apply #map6(%c15, %c0, %c14)
414418
%163 = affine.min #map29(%c5, %135, %c11)
415-
%196 = scf.index_switch %163 -> i1
416-
default {
417-
memref.assume_alignment %alloc_249, 1 : memref<7xf32>
418-
scf.yield %true : i1
419-
}
419+
func.call @foo(%163) : (index) -> ()
420420
return
421421
}
422422
}

0 commit comments

Comments
 (0)