Skip to content

Commit 0c81e6d

Browse files
authored
[mlir][llvm] Add icmp folder (llvm#65343)
This revision adds a simple icmp folder that performs the following folds to the LLVM dialect icmp op: - cmpi(eq/ne, x, x) -> true/false - cmpi(eq/ne, alloca, null) -> false/true - cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
1 parent 71b5f57 commit 0c81e6d

File tree

3 files changed

+99
-7
lines changed

3 files changed

+99
-7
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> {
139139
// Set the $predicate index to -1 to indicate there is no matching operand
140140
// and decrement the following indices.
141141
list<int> llvmArgIndices = [-1, 0, 1];
142+
let hasFolder = 1;
142143
}
143144

144145
// Other floating-point operations.
@@ -1561,6 +1562,17 @@ def LLVM_ConstantOp
15611562
}]>
15621563
];
15631564

1565+
let extraClassDeclaration = [{
1566+
/// Whether the constant op can be constructed with a particular value and
1567+
/// type.
1568+
static bool isBuildableWith(Attribute value, Type type);
1569+
1570+
/// Build the constant op with `value` and `type` if possible, otherwise
1571+
/// returns null.
1572+
static ConstantOp materialize(OpBuilder &builder, Attribute value,
1573+
Type type, Location loc);
1574+
}];
1575+
15641576
let hasFolder = 1;
15651577
let hasVerifier = 1;
15661578
}

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ static Type getI1SameShape(Type type) {
9898
}
9999

100100
//===----------------------------------------------------------------------===//
101-
// Printing, parsing and builder for LLVM::CmpOp.
101+
// Printing, parsing, folding and builder for LLVM::CmpOp.
102102
//===----------------------------------------------------------------------===//
103103

104104
void ICmpOp::print(OpAsmPrinter &p) {
@@ -175,6 +175,42 @@ ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
175175
return parseCmpOp<FCmpPredicate>(parser, result);
176176
}
177177

178+
/// Returns a scalar or vector boolean attribute of the given type.
179+
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
180+
auto boolAttr = BoolAttr::get(ctx, value);
181+
ShapedType shapedType = dyn_cast<ShapedType>(type);
182+
if (!shapedType)
183+
return boolAttr;
184+
return DenseElementsAttr::get(shapedType, boolAttr);
185+
}
186+
187+
OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
188+
if (getPredicate() != ICmpPredicate::eq &&
189+
getPredicate() != ICmpPredicate::ne)
190+
return {};
191+
192+
// cmpi(eq/ne, x, x) -> true/false
193+
if (getLhs() == getRhs())
194+
return getBoolAttribute(getType(), getContext(),
195+
getPredicate() == ICmpPredicate::eq);
196+
197+
// cmpi(eq/ne, alloca, null) -> false/true
198+
if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<NullOp>())
199+
return getBoolAttribute(getType(), getContext(),
200+
getPredicate() == ICmpPredicate::ne);
201+
202+
// cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
203+
if (getLhs().getDefiningOp<NullOp>() && getRhs().getDefiningOp<AllocaOp>()) {
204+
Value lhs = getLhs();
205+
Value rhs = getRhs();
206+
getLhsMutable().assign(rhs);
207+
getRhsMutable().assign(lhs);
208+
return getResult();
209+
}
210+
211+
return {};
212+
}
213+
178214
//===----------------------------------------------------------------------===//
179215
// Printing, parsing and verification for LLVM::AllocaOp.
180216
//===----------------------------------------------------------------------===//
@@ -2443,7 +2479,7 @@ Region *LLVMFuncOp::getCallableRegion() {
24432479
}
24442480

24452481
//===----------------------------------------------------------------------===//
2446-
// Verification for LLVM::ConstantOp.
2482+
// ConstantOp.
24472483
//===----------------------------------------------------------------------===//
24482484

24492485
LogicalResult LLVM::ConstantOp::verify() {
@@ -2503,6 +2539,25 @@ LogicalResult LLVM::ConstantOp::verify() {
25032539
return success();
25042540
}
25052541

2542+
bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
2543+
// The value's type must be the same as the provided type.
2544+
auto typedAttr = dyn_cast<TypedAttr>(value);
2545+
if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
2546+
return false;
2547+
// The value's type must be an LLVM compatible type.
2548+
if (!isCompatibleType(type))
2549+
return false;
2550+
// TODO: Add support for additional attributes kinds once needed.
2551+
return isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
2552+
}
2553+
2554+
ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
2555+
Type type, Location loc) {
2556+
if (isBuildableWith(value, type))
2557+
return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value));
2558+
return nullptr;
2559+
}
2560+
25062561
// Constant op constant-folds to its value.
25072562
OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
25082563

@@ -3097,11 +3152,7 @@ LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
30973152

30983153
Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
30993154
Type type, Location loc) {
3100-
// TODO: Accept more possible attributes. So far, only IntegerAttr may come
3101-
// up.
3102-
if (!isa<IntegerAttr>(value))
3103-
return nullptr;
3104-
return builder.create<LLVM::ConstantOp>(loc, type, value);
3155+
return LLVM::ConstantOp::materialize(builder, value, type, loc);
31053156
}
31063157

31073158
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/canonicalize.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,34 @@
11
// RUN: mlir-opt --pass-pipeline='builtin.module(llvm.func(canonicalize{test-convergence}))' %s -split-input-file | FileCheck %s
22

3+
// CHECK-LABEL: @fold_icmp_eq
4+
llvm.func @fold_icmp_eq(%arg0 : i32) -> i1 {
5+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(true) : i1
6+
%0 = llvm.icmp "eq" %arg0, %arg0 : i32
7+
// CHECK: llvm.return %[[C0]]
8+
llvm.return %0 : i1
9+
}
10+
11+
// CHECK-LABEL: @fold_icmp_ne
12+
llvm.func @fold_icmp_ne(%arg0 : vector<2xi32>) -> vector<2xi1> {
13+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(dense<false> : vector<2xi1>) : vector<2xi1>
14+
%0 = llvm.icmp "ne" %arg0, %arg0 : vector<2xi32>
15+
// CHECK: llvm.return %[[C0]]
16+
llvm.return %0 : vector<2xi1>
17+
}
18+
19+
// CHECK-LABEL: @fold_icmp_alloca
20+
llvm.func @fold_icmp_alloca() -> i1 {
21+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(true) : i1
22+
%c0 = llvm.mlir.null : !llvm.ptr
23+
%c1 = arith.constant 1 : i64
24+
%0 = llvm.alloca %c1 x i32 : (i64) -> !llvm.ptr
25+
%1 = llvm.icmp "ne" %c0, %0 : !llvm.ptr
26+
// CHECK: llvm.return %[[C0]]
27+
llvm.return %1 : i1
28+
}
29+
30+
// -----
31+
332
// CHECK-LABEL: fold_extractvalue
433
llvm.func @fold_extractvalue() -> i32 {
534
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32

0 commit comments

Comments
 (0)