Skip to content

Commit 441f879

Browse files
mmhaxlaukoandykaylor
authored
[CIR] Upstream CmpOp (#133159)
This patch adds support for comparison operators with ClangIR, both integral and floating point. --------- Co-authored-by: Morris Hafner <[email protected]> Co-authored-by: Henrich Lauko <[email protected]> Co-authored-by: Andy Kaylor <[email protected]>
1 parent 2713998 commit 441f879

File tree

9 files changed

+1064
-9
lines changed

9 files changed

+1064
-9
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
335335
return createAdd(loc, lhs, rhs, OverflowBehavior::NoUnsignedWrap);
336336
}
337337

338+
cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind,
339+
mlir::Value lhs, mlir::Value rhs) {
340+
return create<cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
341+
}
342+
338343
//
339344
// Block handling helpers
340345
// ----------------------

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,47 @@ def ForOp : LoopOpBase<"for"> {
10011001
}];
10021002
}
10031003

1004+
//===----------------------------------------------------------------------===//
1005+
// CmpOp
1006+
//===----------------------------------------------------------------------===//
1007+
1008+
def CmpOpKind_LT : I32EnumAttrCase<"lt", 1>;
1009+
def CmpOpKind_LE : I32EnumAttrCase<"le", 2>;
1010+
def CmpOpKind_GT : I32EnumAttrCase<"gt", 3>;
1011+
def CmpOpKind_GE : I32EnumAttrCase<"ge", 4>;
1012+
def CmpOpKind_EQ : I32EnumAttrCase<"eq", 5>;
1013+
def CmpOpKind_NE : I32EnumAttrCase<"ne", 6>;
1014+
1015+
def CmpOpKind : I32EnumAttr<
1016+
"CmpOpKind",
1017+
"compare operation kind",
1018+
[CmpOpKind_LT, CmpOpKind_LE, CmpOpKind_GT,
1019+
CmpOpKind_GE, CmpOpKind_EQ, CmpOpKind_NE]> {
1020+
let cppNamespace = "::cir";
1021+
}
1022+
1023+
def CmpOp : CIR_Op<"cmp", [Pure, SameTypeOperands]> {
1024+
1025+
let summary = "Compare values two values and produce a boolean result";
1026+
let description = [{
1027+
`cir.cmp` compares two input operands of the same type and produces a
1028+
`cir.bool` result. The kinds of comparison available are:
1029+
[lt,gt,ge,eq,ne]
1030+
1031+
```mlir
1032+
%7 = cir.cmp(gt, %1, %2) : i32, !cir.bool
1033+
```
1034+
}];
1035+
1036+
let results = (outs CIR_BoolType:$result);
1037+
let arguments = (ins Arg<CmpOpKind, "cmp kind">:$kind,
1038+
CIR_AnyType:$lhs, CIR_AnyType:$rhs);
1039+
1040+
let assemblyFormat = [{
1041+
`(` $kind `,` $lhs `,` $rhs `)` `:` type($lhs) `,` type($result) attr-dict
1042+
}];
1043+
}
1044+
10041045
//===----------------------------------------------------------------------===//
10051046
// BinOp
10061047
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ struct MissingFeatures {
8888
static bool opGlobalViewAttr() { return false; }
8989
static bool lowerModeOptLevel() { return false; }
9090
static bool opTBAA() { return false; }
91-
static bool opCmp() { return false; }
9291
static bool objCLifetime() { return false; }
9392
static bool emitNullabilityCheck() { return false; }
9493
static bool astVarDeclInterface() { return false; }

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,85 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
707707
HANDLEBINOP(Xor)
708708
HANDLEBINOP(Or)
709709
#undef HANDLEBINOP
710+
711+
mlir::Value emitCmp(const BinaryOperator *e) {
712+
const mlir::Location loc = cgf.getLoc(e->getExprLoc());
713+
mlir::Value result;
714+
QualType lhsTy = e->getLHS()->getType();
715+
QualType rhsTy = e->getRHS()->getType();
716+
717+
auto clangCmpToCIRCmp =
718+
[](clang::BinaryOperatorKind clangCmp) -> cir::CmpOpKind {
719+
switch (clangCmp) {
720+
case BO_LT:
721+
return cir::CmpOpKind::lt;
722+
case BO_GT:
723+
return cir::CmpOpKind::gt;
724+
case BO_LE:
725+
return cir::CmpOpKind::le;
726+
case BO_GE:
727+
return cir::CmpOpKind::ge;
728+
case BO_EQ:
729+
return cir::CmpOpKind::eq;
730+
case BO_NE:
731+
return cir::CmpOpKind::ne;
732+
default:
733+
llvm_unreachable("unsupported comparison kind for cir.cmp");
734+
}
735+
};
736+
737+
if (lhsTy->getAs<MemberPointerType>()) {
738+
assert(!cir::MissingFeatures::dataMemberType());
739+
assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE);
740+
mlir::Value lhs = cgf.emitScalarExpr(e->getLHS());
741+
mlir::Value rhs = cgf.emitScalarExpr(e->getRHS());
742+
cir::CmpOpKind kind = clangCmpToCIRCmp(e->getOpcode());
743+
result = builder.createCompare(loc, kind, lhs, rhs);
744+
} else if (!lhsTy->isAnyComplexType() && !rhsTy->isAnyComplexType()) {
745+
BinOpInfo boInfo = emitBinOps(e);
746+
mlir::Value lhs = boInfo.lhs;
747+
mlir::Value rhs = boInfo.rhs;
748+
749+
if (lhsTy->isVectorType()) {
750+
assert(!cir::MissingFeatures::vectorType());
751+
cgf.cgm.errorNYI(loc, "vector comparisons");
752+
result = builder.getBool(false, loc);
753+
} else if (boInfo.isFixedPointOp()) {
754+
assert(!cir::MissingFeatures::fixedPointType());
755+
cgf.cgm.errorNYI(loc, "fixed point comparisons");
756+
result = builder.getBool(false, loc);
757+
} else {
758+
// integers and pointers
759+
if (cgf.cgm.getCodeGenOpts().StrictVTablePointers &&
760+
mlir::isa<cir::PointerType>(lhs.getType()) &&
761+
mlir::isa<cir::PointerType>(rhs.getType())) {
762+
cgf.cgm.errorNYI(loc, "strict vtable pointer comparisons");
763+
}
764+
765+
cir::CmpOpKind kind = clangCmpToCIRCmp(e->getOpcode());
766+
result = builder.createCompare(loc, kind, lhs, rhs);
767+
}
768+
} else {
769+
// Complex Comparison: can only be an equality comparison.
770+
assert(!cir::MissingFeatures::complexType());
771+
cgf.cgm.errorNYI(loc, "complex comparison");
772+
result = builder.getBool(false, loc);
773+
}
774+
775+
return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(),
776+
e->getExprLoc());
777+
}
778+
779+
// Comparisons.
780+
#define VISITCOMP(CODE) \
781+
mlir::Value VisitBin##CODE(const BinaryOperator *E) { return emitCmp(E); }
782+
VISITCOMP(LT)
783+
VISITCOMP(GT)
784+
VISITCOMP(LE)
785+
VISITCOMP(GE)
786+
VISITCOMP(EQ)
787+
VISITCOMP(NE)
788+
#undef VISITCOMP
710789
};
711790

712791
LValue ScalarExprEmitter::emitCompoundAssignLValue(

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2222
#include "mlir/IR/BuiltinDialect.h"
2323
#include "mlir/IR/BuiltinOps.h"
24+
#include "mlir/IR/Types.h"
2425
#include "mlir/Pass/Pass.h"
2526
#include "mlir/Pass/PassManager.h"
2627
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
@@ -1193,6 +1194,86 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
11931194
return mlir::LogicalResult::success();
11941195
}
11951196

1197+
/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind.
1198+
static mlir::LLVM::ICmpPredicate
1199+
convertCmpKindToICmpPredicate(cir::CmpOpKind kind, bool isSigned) {
1200+
using CIR = cir::CmpOpKind;
1201+
using LLVMICmp = mlir::LLVM::ICmpPredicate;
1202+
switch (kind) {
1203+
case CIR::eq:
1204+
return LLVMICmp::eq;
1205+
case CIR::ne:
1206+
return LLVMICmp::ne;
1207+
case CIR::lt:
1208+
return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
1209+
case CIR::le:
1210+
return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
1211+
case CIR::gt:
1212+
return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
1213+
case CIR::ge:
1214+
return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
1215+
}
1216+
llvm_unreachable("Unknown CmpOpKind");
1217+
}
1218+
1219+
/// Convert from a CIR comparison kind to an LLVM IR floating-point comparison
1220+
/// kind.
1221+
static mlir::LLVM::FCmpPredicate
1222+
convertCmpKindToFCmpPredicate(cir::CmpOpKind kind) {
1223+
using CIR = cir::CmpOpKind;
1224+
using LLVMFCmp = mlir::LLVM::FCmpPredicate;
1225+
switch (kind) {
1226+
case CIR::eq:
1227+
return LLVMFCmp::oeq;
1228+
case CIR::ne:
1229+
return LLVMFCmp::une;
1230+
case CIR::lt:
1231+
return LLVMFCmp::olt;
1232+
case CIR::le:
1233+
return LLVMFCmp::ole;
1234+
case CIR::gt:
1235+
return LLVMFCmp::ogt;
1236+
case CIR::ge:
1237+
return LLVMFCmp::oge;
1238+
}
1239+
llvm_unreachable("Unknown CmpOpKind");
1240+
}
1241+
1242+
mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
1243+
cir::CmpOp cmpOp, OpAdaptor adaptor,
1244+
mlir::ConversionPatternRewriter &rewriter) const {
1245+
mlir::Type type = cmpOp.getLhs().getType();
1246+
1247+
assert(!cir::MissingFeatures::dataMemberType());
1248+
assert(!cir::MissingFeatures::methodType());
1249+
1250+
// Lower to LLVM comparison op.
1251+
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
1252+
bool isSigned = mlir::isa<cir::IntType>(type)
1253+
? mlir::cast<cir::IntType>(type).isSigned()
1254+
: mlir::cast<mlir::IntegerType>(type).isSigned();
1255+
mlir::LLVM::ICmpPredicate kind =
1256+
convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned);
1257+
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
1258+
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1259+
} else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
1260+
mlir::LLVM::ICmpPredicate kind =
1261+
convertCmpKindToICmpPredicate(cmpOp.getKind(),
1262+
/* isSigned=*/false);
1263+
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
1264+
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1265+
} else if (mlir::isa<cir::CIRFPTypeInterface>(type)) {
1266+
mlir::LLVM::FCmpPredicate kind =
1267+
convertCmpKindToFCmpPredicate(cmpOp.getKind());
1268+
rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(
1269+
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1270+
} else {
1271+
return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
1272+
}
1273+
1274+
return mlir::success();
1275+
}
1276+
11961277
static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
11971278
mlir::DataLayout &dataLayout) {
11981279
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
@@ -1334,6 +1415,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
13341415
CIRToLLVMBinOpLowering,
13351416
CIRToLLVMBrCondOpLowering,
13361417
CIRToLLVMBrOpLowering,
1418+
CIRToLLVMCmpOpLowering,
13371419
CIRToLLVMConstantOpLowering,
13381420
CIRToLLVMFuncOpLowering,
13391421
CIRToLLVMTrapOpLowering,

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,19 @@ class CIRToLLVMBinOpLowering : public mlir::OpConversionPattern<cir::BinOp> {
186186
mlir::ConversionPatternRewriter &) const override;
187187
};
188188

189+
class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
190+
public:
191+
CIRToLLVMCmpOpLowering(const mlir::TypeConverter &typeConverter,
192+
mlir::MLIRContext *context)
193+
: OpConversionPattern(typeConverter, context) {
194+
setHasBoundedRewriteRecursion();
195+
}
196+
197+
mlir::LogicalResult
198+
matchAndRewrite(cir::CmpOp op, OpAdaptor,
199+
mlir::ConversionPatternRewriter &) const override;
200+
};
201+
189202
class CIRToLLVMBrOpLowering : public mlir::OpConversionPattern<cir::BrOp> {
190203
public:
191204
using mlir::OpConversionPattern<cir::BrOp>::OpConversionPattern;

clang/test/CIR/CodeGen/cast.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -DCIR_ONLY %s -o %t.cir
1+
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
22
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
33
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
44
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
@@ -57,16 +57,16 @@ int cStyleCasts_0(unsigned x1, int x2, float x3, short x4, double x5) {
5757
// CIR: %{{[0-9]+}} = cir.cast(bool_to_int, %{{[0-9]+}} : !cir.bool), !s32i
5858
// LLVM: %{{[0-9]+}} = zext i1 %{{[0-9]+}} to i32
5959

60-
#ifdef CIR_ONLY
6160
bool b2 = x2; // int to bool
6261
// CIR: %{{[0-9]+}} = cir.cast(int_to_bool, %{{[0-9]+}} : !s32i), !cir.bool
63-
#endif
62+
// LLVM: %[[INTTOBOOL:[0-9]+]] = icmp ne i32 %{{[0-9]+}}, 0
63+
// LLVM: zext i1 %[[INTTOBOOL]] to i8
6464

65-
#ifdef CIR_ONLY
6665
void *p;
67-
bool b3 = p; // ptr to bool
66+
bool b3 = p; // ptr to bool
6867
// CIR: %{{[0-9]+}} = cir.cast(ptr_to_bool, %{{[0-9]+}} : !cir.ptr<!void>), !cir.bool
69-
#endif
68+
// LLVM: %[[PTRTOBOOL:[0-9]+]] = icmp ne ptr %{{[0-9]+}}, null
69+
// LLVM: zext i1 %[[PTRTOBOOL]] to i8
7070

7171
float f;
7272
bool b4 = f; // float to bool
@@ -77,7 +77,6 @@ int cStyleCasts_0(unsigned x1, int x2, float x3, short x4, double x5) {
7777
return 0;
7878
}
7979

80-
#ifdef CIR_ONLY
8180
bool cptr(void *d) {
8281
bool x = d;
8382
return x;
@@ -88,7 +87,15 @@ bool cptr(void *d) {
8887

8988
// CIR: %[[DVAL:[0-9]+]] = cir.load %[[DPTR]] : !cir.ptr<!cir.ptr<!void>>, !cir.ptr<!void>
9089
// CIR: %{{[0-9]+}} = cir.cast(ptr_to_bool, %[[DVAL]] : !cir.ptr<!void>), !cir.bool
91-
#endif
90+
91+
// LLVM-LABEL: define i1 @cptr(ptr %0)
92+
// LLVM: %[[ARG_STORAGE:.*]] = alloca ptr, i64 1
93+
// LLVM: %[[RETVAL:.*]] = alloca i8, i64 1
94+
// LLVM: %[[X_STORAGE:.*]] = alloca i8, i64 1
95+
// LLVM: store ptr %0, ptr %[[ARG_STORAGE]]
96+
// LLVM: %[[LOADED_PTR:.*]] = load ptr, ptr %[[ARG_STORAGE]]
97+
// LLVM: %[[NULL_CHECK:.*]] = icmp ne ptr %[[LOADED_PTR]], null
98+
// LLVM: ret i1
9299

93100
void should_not_cast() {
94101
unsigned x1;

0 commit comments

Comments
 (0)