Skip to content

[CIR] Upstream CmpOp #133159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return createAdd(loc, lhs, rhs, OverflowBehavior::NoUnsignedWrap);
}

cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind,
mlir::Value lhs, mlir::Value rhs) {
return create<cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
}

//
// Block handling helpers
// ----------------------
Expand Down
41 changes: 41 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,47 @@ def ForOp : LoopOpBase<"for"> {
}];
}

//===----------------------------------------------------------------------===//
// CmpOp
//===----------------------------------------------------------------------===//

def CmpOpKind_LT : I32EnumAttrCase<"lt", 1>;
def CmpOpKind_LE : I32EnumAttrCase<"le", 2>;
def CmpOpKind_GT : I32EnumAttrCase<"gt", 3>;
def CmpOpKind_GE : I32EnumAttrCase<"ge", 4>;
def CmpOpKind_EQ : I32EnumAttrCase<"eq", 5>;
def CmpOpKind_NE : I32EnumAttrCase<"ne", 6>;

def CmpOpKind : I32EnumAttr<
"CmpOpKind",
"compare operation kind",
[CmpOpKind_LT, CmpOpKind_LE, CmpOpKind_GT,
CmpOpKind_GE, CmpOpKind_EQ, CmpOpKind_NE]> {
let cppNamespace = "::cir";
}

def CmpOp : CIR_Op<"cmp", [Pure, SameTypeOperands]> {

let summary = "Compare values two values and produce a boolean result";
let description = [{
`cir.cmp` compares two input operands of the same type and produces a
`cir.bool` result. The kinds of comparison available are:
[lt,gt,ge,eq,ne]

```mlir
%7 = cir.cmp(gt, %1, %2) : i32, !cir.bool
```
}];

let results = (outs CIR_BoolType:$result);
let arguments = (ins Arg<CmpOpKind, "cmp kind">:$kind,
CIR_AnyType:$lhs, CIR_AnyType:$rhs);

let assemblyFormat = [{
`(` $kind `,` $lhs `,` $rhs `)` `:` type($lhs) `,` type($result) attr-dict
}];
}

//===----------------------------------------------------------------------===//
// BinOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 0 additions & 1 deletion clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ struct MissingFeatures {
static bool opGlobalViewAttr() { return false; }
static bool lowerModeOptLevel() { return false; }
static bool opTBAA() { return false; }
static bool opCmp() { return false; }
static bool objCLifetime() { return false; }
static bool emitNullabilityCheck() { return false; }
static bool astVarDeclInterface() { return false; }
Expand Down
79 changes: 79 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,85 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
HANDLEBINOP(Xor)
HANDLEBINOP(Or)
#undef HANDLEBINOP

mlir::Value emitCmp(const BinaryOperator *e) {
const mlir::Location loc = cgf.getLoc(e->getExprLoc());
mlir::Value result;
QualType lhsTy = e->getLHS()->getType();
QualType rhsTy = e->getRHS()->getType();

auto clangCmpToCIRCmp =
[](clang::BinaryOperatorKind clangCmp) -> cir::CmpOpKind {
switch (clangCmp) {
case BO_LT:
return cir::CmpOpKind::lt;
case BO_GT:
return cir::CmpOpKind::gt;
case BO_LE:
return cir::CmpOpKind::le;
case BO_GE:
return cir::CmpOpKind::ge;
case BO_EQ:
return cir::CmpOpKind::eq;
case BO_NE:
return cir::CmpOpKind::ne;
default:
llvm_unreachable("unsupported comparison kind for cir.cmp");
}
};

if (lhsTy->getAs<MemberPointerType>()) {
assert(!cir::MissingFeatures::dataMemberType());
assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE);
mlir::Value lhs = cgf.emitScalarExpr(e->getLHS());
mlir::Value rhs = cgf.emitScalarExpr(e->getRHS());
cir::CmpOpKind kind = clangCmpToCIRCmp(e->getOpcode());
result = builder.createCompare(loc, kind, lhs, rhs);
} else if (!lhsTy->isAnyComplexType() && !rhsTy->isAnyComplexType()) {
BinOpInfo boInfo = emitBinOps(e);
mlir::Value lhs = boInfo.lhs;
mlir::Value rhs = boInfo.rhs;

if (lhsTy->isVectorType()) {
assert(!cir::MissingFeatures::vectorType());
cgf.cgm.errorNYI(loc, "vector comparisons");
result = builder.getBool(false, loc);
} else if (boInfo.isFixedPointOp()) {
assert(!cir::MissingFeatures::fixedPointType());
cgf.cgm.errorNYI(loc, "fixed point comparisons");
result = builder.getBool(false, loc);
} else {
// integers and pointers
if (cgf.cgm.getCodeGenOpts().StrictVTablePointers &&
mlir::isa<cir::PointerType>(lhs.getType()) &&
mlir::isa<cir::PointerType>(rhs.getType())) {
cgf.cgm.errorNYI(loc, "strict vtable pointer comparisons");
}

cir::CmpOpKind kind = clangCmpToCIRCmp(e->getOpcode());
result = builder.createCompare(loc, kind, lhs, rhs);
}
} else {
// Complex Comparison: can only be an equality comparison.
assert(!cir::MissingFeatures::complexType());
cgf.cgm.errorNYI(loc, "complex comparison");
result = builder.getBool(false, loc);
}

return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(),
e->getExprLoc());
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a bit of a mess how locations are emitted here:

cgf.getLoc(e->getExprLoc()) vs. cgf.getLoc(boInfo.loc) vs. cgf.getLoc(e->getSourceRange())

tbf I am not sure what is correct way, or whether clangir even has some rule of thumb how AST locations are translated to mlir locations. Just the inconsistency caught my eye. @bcardosolopes @andykaylor ?

Copy link
Contributor Author

@mmha mmha Mar 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW classic codegen uses getExprLoc and doesnt emit a range.

In any case I get the Location at the beginning of the function now and store it in a variable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whether clangir even has some rule of thumb how AST locations are translated to mlir locations

we try to capture what makes more sense w.r.t source code and good diagnostic experience, but I wouldn't claim we have done a diligent process, so I can't attest for the quality. I fixed many bad source locations when working on the lifetime checker, so it's somewhat reliable, but I haven't tested much away from that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that we need to look at context to decide which source location/range makes sense. I've seen problems where we were getting source locations from operands rather than from the expression where the operand is used, which was wrong in that case.


// Comparisons.
#define VISITCOMP(CODE) \
mlir::Value VisitBin##CODE(const BinaryOperator *E) { return emitCmp(E); }
VISITCOMP(LT)
VISITCOMP(GT)
VISITCOMP(LE)
VISITCOMP(GE)
VISITCOMP(EQ)
VISITCOMP(NE)
#undef VISITCOMP
};

LValue ScalarExprEmitter::emitCompoundAssignLValue(
Expand Down
82 changes: 82 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
Expand Down Expand Up @@ -1193,6 +1194,86 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
return mlir::LogicalResult::success();
}

/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind.
static mlir::LLVM::ICmpPredicate
convertCmpKindToICmpPredicate(cir::CmpOpKind kind, bool isSigned) {
using CIR = cir::CmpOpKind;
using LLVMICmp = mlir::LLVM::ICmpPredicate;
switch (kind) {
case CIR::eq:
return LLVMICmp::eq;
case CIR::ne:
return LLVMICmp::ne;
case CIR::lt:
return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
case CIR::le:
return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
case CIR::gt:
return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
case CIR::ge:
return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
}
llvm_unreachable("Unknown CmpOpKind");
}

/// Convert from a CIR comparison kind to an LLVM IR floating-point comparison
/// kind.
static mlir::LLVM::FCmpPredicate
convertCmpKindToFCmpPredicate(cir::CmpOpKind kind) {
using CIR = cir::CmpOpKind;
using LLVMFCmp = mlir::LLVM::FCmpPredicate;
switch (kind) {
case CIR::eq:
return LLVMFCmp::oeq;
case CIR::ne:
return LLVMFCmp::une;
case CIR::lt:
return LLVMFCmp::olt;
case CIR::le:
return LLVMFCmp::ole;
case CIR::gt:
return LLVMFCmp::ogt;
case CIR::ge:
return LLVMFCmp::oge;
}
llvm_unreachable("Unknown CmpOpKind");
}

mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
cir::CmpOp cmpOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Type type = cmpOp.getLhs().getType();

assert(!cir::MissingFeatures::dataMemberType());
assert(!cir::MissingFeatures::methodType());

// Lower to LLVM comparison op.
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
bool isSigned = mlir::isa<cir::IntType>(type)
? mlir::cast<cir::IntType>(type).isSigned()
: mlir::cast<mlir::IntegerType>(type).isSigned();
mlir::LLVM::ICmpPredicate kind =
convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned);
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
} else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
mlir::LLVM::ICmpPredicate kind =
convertCmpKindToICmpPredicate(cmpOp.getKind(),
/* isSigned=*/false);
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
} else if (mlir::isa<cir::CIRFPTypeInterface>(type)) {
mlir::LLVM::FCmpPredicate kind =
convertCmpKindToFCmpPredicate(cmpOp.getKind());
rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
} else {
return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
}

return mlir::success();
}

static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
mlir::DataLayout &dataLayout) {
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
Expand Down Expand Up @@ -1334,6 +1415,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMBinOpLowering,
CIRToLLVMBrCondOpLowering,
CIRToLLVMBrOpLowering,
CIRToLLVMCmpOpLowering,
CIRToLLVMConstantOpLowering,
CIRToLLVMFuncOpLowering,
CIRToLLVMTrapOpLowering,
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,19 @@ class CIRToLLVMBinOpLowering : public mlir::OpConversionPattern<cir::BinOp> {
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
public:
CIRToLLVMCmpOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context)
: OpConversionPattern(typeConverter, context) {
setHasBoundedRewriteRecursion();
}

mlir::LogicalResult
matchAndRewrite(cir::CmpOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBrOpLowering : public mlir::OpConversionPattern<cir::BrOp> {
public:
using mlir::OpConversionPattern<cir::BrOp>::OpConversionPattern;
Expand Down
23 changes: 15 additions & 8 deletions clang/test/CIR/CodeGen/cast.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -DCIR_ONLY %s -o %t.cir
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
Expand Down Expand Up @@ -57,16 +57,16 @@ int cStyleCasts_0(unsigned x1, int x2, float x3, short x4, double x5) {
// CIR: %{{[0-9]+}} = cir.cast(bool_to_int, %{{[0-9]+}} : !cir.bool), !s32i
// LLVM: %{{[0-9]+}} = zext i1 %{{[0-9]+}} to i32

#ifdef CIR_ONLY
bool b2 = x2; // int to bool
// CIR: %{{[0-9]+}} = cir.cast(int_to_bool, %{{[0-9]+}} : !s32i), !cir.bool
#endif
// LLVM: %[[INTTOBOOL:[0-9]+]] = icmp ne i32 %{{[0-9]+}}, 0
// LLVM: zext i1 %[[INTTOBOOL]] to i8

#ifdef CIR_ONLY
void *p;
bool b3 = p; // ptr to bool
bool b3 = p; // ptr to bool
// CIR: %{{[0-9]+}} = cir.cast(ptr_to_bool, %{{[0-9]+}} : !cir.ptr<!void>), !cir.bool
#endif
// LLVM: %[[PTRTOBOOL:[0-9]+]] = icmp ne ptr %{{[0-9]+}}, null
// LLVM: zext i1 %[[PTRTOBOOL]] to i8

float f;
bool b4 = f; // float to bool
Expand All @@ -77,7 +77,6 @@ int cStyleCasts_0(unsigned x1, int x2, float x3, short x4, double x5) {
return 0;
}

#ifdef CIR_ONLY
bool cptr(void *d) {
bool x = d;
return x;
Expand All @@ -88,7 +87,15 @@ bool cptr(void *d) {

// CIR: %[[DVAL:[0-9]+]] = cir.load %[[DPTR]] : !cir.ptr<!cir.ptr<!void>>, !cir.ptr<!void>
// CIR: %{{[0-9]+}} = cir.cast(ptr_to_bool, %[[DVAL]] : !cir.ptr<!void>), !cir.bool
#endif

// LLVM-LABEL: define i1 @cptr(ptr %0)
// LLVM: %[[ARG_STORAGE:.*]] = alloca ptr, i64 1
// LLVM: %[[RETVAL:.*]] = alloca i8, i64 1
// LLVM: %[[X_STORAGE:.*]] = alloca i8, i64 1
// LLVM: store ptr %0, ptr %[[ARG_STORAGE]]
// LLVM: %[[LOADED_PTR:.*]] = load ptr, ptr %[[ARG_STORAGE]]
// LLVM: %[[NULL_CHECK:.*]] = icmp ne ptr %[[LOADED_PTR]], null
// LLVM: ret i1

void should_not_cast() {
unsigned x1;
Expand Down
Loading