Skip to content

Commit 516884f

Browse files
committed
[MLIR] Fix FloorDivSIOpConverter that was failing for index type after the arithmetic op refactor
ConstantOp should be used instead of ConstantIntOp to be able to support index type. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D112191
1 parent 95935e8 commit 516884f

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,12 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
8181
Type type = signedFloorDivIOp.getType();
8282
Value a = signedFloorDivIOp.lhs();
8383
Value b = signedFloorDivIOp.rhs();
84-
Value plusOne = rewriter.create<arith::ConstantIntOp>(loc, 1, type);
85-
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, type);
86-
Value minusOne = rewriter.create<arith::ConstantIntOp>(loc, -1, type);
84+
Value plusOne = rewriter.create<arith::ConstantOp>(
85+
loc, rewriter.getIntegerAttr(type, 1));
86+
Value zero = rewriter.create<arith::ConstantOp>(
87+
loc, rewriter.getIntegerAttr(type, 0));
88+
Value minusOne = rewriter.create<arith::ConstantOp>(
89+
loc, rewriter.getIntegerAttr(type, -1));
8790
// Compute x = (b<0) ? 1 : -1.
8891
Value compare =
8992
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);

mlir/test/Dialect/Arithmetic/expand-ops.mlir

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,36 @@ func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
3030

3131
// -----
3232

33+
// Test ceil divide with index type
34+
// CHECK-LABEL: func @ceildivi_index
35+
// CHECK-SAME: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
36+
func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
37+
%res = arith.ceildivsi %arg0, %arg1 : index
38+
return %res : index
39+
40+
// CHECK: [[ONE:%.+]] = arith.constant 1 : index
41+
// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
42+
// CHECK: [[MINONE:%.+]] = arith.constant -1 : index
43+
// CHECK: [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
44+
// CHECK: [[X:%.+]] = select [[CMP1]], [[MINONE]], [[ONE]] : index
45+
// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
46+
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
47+
// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
48+
// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
49+
// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
50+
// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
51+
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
52+
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
53+
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
54+
// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
55+
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
56+
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
57+
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
58+
// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
59+
}
60+
61+
// -----
62+
3363
// Test floor divide with signed integer
3464
// CHECK-LABEL: func @floordivi
3565
// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
@@ -54,3 +84,30 @@ func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
5484
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
5585
// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
5686
}
87+
88+
// -----
89+
90+
// Test floor divide with index type
91+
// CHECK-LABEL: func @floordivi_index
92+
// CHECK-SAME: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
93+
func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
94+
%res = arith.floordivsi %arg0, %arg1 : index
95+
return %res : index
96+
// CHECK: [[ONE:%.+]] = arith.constant 1 : index
97+
// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
98+
// CHECK: [[MIN1:%.+]] = arith.constant -1 : index
99+
// CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
100+
// CHECK: [[X:%.+]] = select [[CMP1]], [[ONE]], [[MIN1]] : index
101+
// CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : index
102+
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
103+
// CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : index
104+
// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
105+
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
106+
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
107+
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
108+
// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
109+
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
110+
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
111+
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
112+
// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : index
113+
}

0 commit comments

Comments
 (0)