Skip to content

Commit 1282556

Browse files
[SYCL-MLIR] Fix SelectI1Ext canonicalize pattern (#8290)
For the code below, `foo(true)` should return `1`, and `foo(false)` should return `2`. ``` func.func @foo(%arg0: i1) -> i32 { %c1_i32 = arith.constant 1 : i32 %c2_i32 = arith.constant 2 : i32 %0 = arith.select %arg0, %c1_i32, %c2_i32 : i32 return %0 : i32 } ``` Without this patch, the `SelectI1Ext` canonicalization pattern (incorrectly) would change it to the following code: ``` func.func @foo(%arg0: i1) -> i32 { %0 = arith.extui %arg0 : i1 to i32 return %0 : i32 } ``` The canonicalization pattern is only legal if the 2 operands of the select operation are representable by a `i1` type (so only for the constant 0 and 1). Signed-off-by: Tsang, Whitney <[email protected]>
1 parent c50573a commit 1282556

File tree

3 files changed

+56
-26
lines changed

3 files changed

+56
-26
lines changed

polygeist/lib/polygeist/Ops.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2045,33 +2045,33 @@ class SelectI1Ext final : public OpRewritePattern<arith::SelectOp> {
20452045

20462046
LogicalResult matchAndRewrite(arith::SelectOp op,
20472047
PatternRewriter &rewriter) const override {
2048-
auto ty = op.getType().dyn_cast<IntegerType>();
2049-
if (!ty)
2050-
return failure();
2051-
if (ty.getWidth() == 1)
2052-
return failure();
2053-
IntegerAttr lhs, rhs;
2054-
Value lhs_v = nullptr, rhs_v = nullptr;
2055-
if (auto ext = op.getTrueValue().getDefiningOp<arith::ExtUIOp>()) {
2056-
lhs_v = ext.getIn();
2057-
if (lhs_v.getType().cast<IntegerType>().getWidth() != 1)
2058-
return failure();
2059-
} else if (matchPattern(op.getTrueValue(), m_Constant(&lhs))) {
2060-
} else
2048+
// Cannot extui i1 to i1, or i1 to f32
2049+
if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
20612050
return failure();
20622051

2063-
if (auto ext = op.getFalseValue().getDefiningOp<arith::ExtUIOp>()) {
2064-
rhs_v = ext.getIn();
2065-
if (rhs_v.getType().cast<IntegerType>().getWidth() != 1)
2066-
return failure();
2067-
} else if (matchPattern(op.getFalseValue(), m_Constant(&rhs))) {
2068-
} else
2069-
return failure();
2052+
// Determines whether the given value fits into a boolean type.
2053+
auto getI1 = [&op, &rewriter](Value val) -> Value {
2054+
constexpr int typeWidth = 1;
2055+
if (matchPattern(val, m_Op<arith::ExtUIOp>())) {
2056+
Value result = val.getDefiningOp()->getOperand(0);
2057+
if (result.getType().cast<IntegerType>().isInteger(typeWidth))
2058+
return result;
2059+
}
2060+
2061+
IntegerAttr intAttr;
2062+
if (matchPattern(val, m_Constant(&intAttr))) {
2063+
if (intAttr.getInt() == 0 || intAttr.getInt() == 1)
2064+
return rewriter.create<ConstantIntOp>(op.getLoc(), intAttr.getInt(),
2065+
typeWidth);
2066+
}
20702067

2071-
if (!lhs_v)
2072-
lhs_v = rewriter.create<ConstantIntOp>(op.getLoc(), lhs.getInt(), 1);
2073-
if (!rhs_v)
2074-
rhs_v = rewriter.create<ConstantIntOp>(op.getLoc(), rhs.getInt(), 1);
2068+
return nullptr;
2069+
};
2070+
2071+
Value lhs_v = getI1(op.getTrueValue());
2072+
Value rhs_v = getI1(op.getFalseValue());
2073+
if (!lhs_v || !rhs_v)
2074+
return failure();
20752075

20762076
rewriter.replaceOpWithNewOp<ExtUIOp>(
20772077
op, op.getType(),
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: polygeist-opt --canonicalize %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @test1(%arg0: i1) -> i32 {
4+
// CHECK: %0 = arith.select %arg0, %c1_i32, %c2_i32 : i32
5+
// CHECK-NEXT: return %0 : i32
6+
// CHECK-NEXT: }
7+
8+
func.func @test1(%arg0: i1) -> i32 {
9+
%c1_i32 = arith.constant 1 : i32
10+
%c2_i32 = arith.constant 2 : i32
11+
%0 = arith.select %arg0, %c1_i32, %c2_i32 : i32
12+
return %0 : i32
13+
}
14+
15+
// CHECK-LABEL: func.func @test2(%arg0: i1) -> i32 {
16+
// CHECK: %0 = arith.xori %arg0, %true : i1
17+
// CHECK-NEXT: %1 = arith.extui %0 : i1 to i32
18+
// CHECK-NEXT: return %1 : i32
19+
// CHECK-NEXT: }
20+
21+
func.func @test2(%arg0: i1) -> i32 {
22+
%true = arith.constant true
23+
%false = arith.constant false
24+
%0 = arith.extui %true : i1 to i32
25+
%1 = arith.extui %false : i1 to i32
26+
%2 = arith.select %arg0, %1, %0 : i32
27+
return %2 : i32
28+
}

polygeist/tools/cgeist/Test/Verification/sycl/structvec.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ SYCL_EXTERNAL structvec test_store(structvec sv, int idx, char el) {
7979
// CHECK-NEXT: }
8080

8181
// CHECK-LABEL: func.func @_ZN9structvecC1ESt16initializer_listIcE(%arg0: !llvm.ptr<struct<(vector<2xi8>)>, 4> {llvm.align = 2 : i64, llvm.dereferenceable_or_null = 2 : i64, llvm.noundef}, %arg1: !llvm.ptr<!llvm.struct<(memref<?xi8, 4>, i64)>> {llvm.align = 8 : i64, llvm.byval = !llvm.struct<(memref<?xi8, 4>, i64)>, llvm.noundef})
82-
// CHECK-NEXT: %c0_i8 = arith.constant 0 : i8
82+
// CHECK-DAG: %c-1_i32 = arith.constant -1 : i32
83+
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
84+
// CHECK-DAG: %c0_i8 = arith.constant 0 : i8
8385
// CHECK-NEXT: %0 = llvm.addrspacecast %arg1 : !llvm.ptr<!llvm.struct<(memref<?xi8, 4>, i64)>> to !llvm.ptr<!llvm.struct<(memref<?xi8, 4>, i64)>, 4>
8486
// CHECK-NEXT: %1 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr<struct<(vector<2xi8>)>, 4>) -> !llvm.ptr<vector<2xi8>, 4>
8587
// CHECK-NEXT: affine.for %arg2 = 0 to 2 {
@@ -88,7 +90,7 @@ SYCL_EXTERNAL structvec test_store(structvec sv, int idx, char el) {
8890
// CHECK-NEXT: %4 = arith.index_castui %2 : i32 to index
8991
// CHECK-NEXT: %5 = memref.load %3[%4] : memref<?xi8, 4>
9092
// CHECK-NEXT: %6 = arith.cmpi ne, %5, %c0_i8 : i8
91-
// CHECK-NEXT: %7 = arith.extui %6 : i1 to i32
93+
// CHECK-NEXT: %7 = arith.select %6, %c-1_i32, %c0_i32 : i32
9294
// CHECK-NEXT: %8 = arith.trunci %7 : i32 to i8
9395
// CHECK-NEXT: %9 = llvm.load %1 : !llvm.ptr<vector<2xi8>, 4>
9496
// CHECK-NEXT: %10 = vector.insertelement %8, %9[%2 : i32] : vector<2xi8>

0 commit comments

Comments
 (0)