Skip to content

Commit 9d91b07

Browse files
authored
[CIR] Implement EqualOp for ComplexType (#145769)
This change adds support for equal operation for ComplexType #141365
1 parent 07e3c85 commit 9d91b07

File tree

5 files changed

+167
-13
lines changed

5 files changed

+167
-13
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2455,6 +2455,31 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
24552455
let hasFolder = 1;
24562456
}
24572457

2458+
//===----------------------------------------------------------------------===//
2459+
// ComplexEqualOp
2460+
//===----------------------------------------------------------------------===//
2461+
2462+
def ComplexEqualOp : CIR_Op<"complex.eq", [Pure, SameTypeOperands]> {
2463+
2464+
let summary = "Computes whether two complex values are equal";
2465+
let description = [{
2466+
The `complex.equal` op takes two complex numbers and returns whether
2467+
they are equal.
2468+
2469+
```mlir
2470+
%r = cir.complex.eq %a, %b : !cir.complex<!cir.float>
2471+
```
2472+
}];
2473+
2474+
let results = (outs CIR_BoolType:$result);
2475+
let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs);
2476+
2477+
let assemblyFormat = [{
2478+
$lhs `,` $rhs
2479+
`:` qualified(type($lhs)) attr-dict
2480+
}];
2481+
}
2482+
24582483
//===----------------------------------------------------------------------===//
24592484
// Assume Operations
24602485
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -894,9 +894,17 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
894894
}
895895
} else {
896896
// Complex Comparison: can only be an equality comparison.
897-
assert(!cir::MissingFeatures::complexType());
898-
cgf.cgm.errorNYI(loc, "complex comparison");
899-
result = builder.getBool(false, loc);
897+
assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE);
898+
899+
BinOpInfo boInfo = emitBinOps(e);
900+
if (e->getOpcode() == BO_EQ) {
901+
result =
902+
builder.create<cir::ComplexEqualOp>(loc, boInfo.lhs, boInfo.rhs);
903+
} else {
904+
assert(!cir::MissingFeatures::complexType());
905+
cgf.cgm.errorNYI(loc, "complex not equal");
906+
result = builder.getBool(false, loc);
907+
}
900908
}
901909

902910
return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(),

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,29 +1900,30 @@ void ConvertCIRToLLVMPass::runOnOperation() {
19001900
CIRToLLVMBrOpLowering,
19011901
CIRToLLVMCallOpLowering,
19021902
CIRToLLVMCmpOpLowering,
1903+
CIRToLLVMComplexCreateOpLowering,
1904+
CIRToLLVMComplexEqualOpLowering,
1905+
CIRToLLVMComplexImagOpLowering,
1906+
CIRToLLVMComplexRealOpLowering,
19031907
CIRToLLVMConstantOpLowering,
19041908
CIRToLLVMExpectOpLowering,
19051909
CIRToLLVMFuncOpLowering,
19061910
CIRToLLVMGetGlobalOpLowering,
19071911
CIRToLLVMGetMemberOpLowering,
19081912
CIRToLLVMSelectOpLowering,
1909-
CIRToLLVMSwitchFlatOpLowering,
19101913
CIRToLLVMShiftOpLowering,
1911-
CIRToLLVMStackSaveOpLowering,
19121914
CIRToLLVMStackRestoreOpLowering,
1915+
CIRToLLVMStackSaveOpLowering,
1916+
CIRToLLVMSwitchFlatOpLowering,
19131917
CIRToLLVMTrapOpLowering,
19141918
CIRToLLVMUnaryOpLowering,
1919+
CIRToLLVMVecCmpOpLowering,
19151920
CIRToLLVMVecCreateOpLowering,
19161921
CIRToLLVMVecExtractOpLowering,
19171922
CIRToLLVMVecInsertOpLowering,
1918-
CIRToLLVMVecCmpOpLowering,
1919-
CIRToLLVMVecSplatOpLowering,
1920-
CIRToLLVMVecShuffleOpLowering,
19211923
CIRToLLVMVecShuffleDynamicOpLowering,
1922-
CIRToLLVMVecTernaryOpLowering,
1923-
CIRToLLVMComplexCreateOpLowering,
1924-
CIRToLLVMComplexRealOpLowering,
1925-
CIRToLLVMComplexImagOpLowering
1924+
CIRToLLVMVecShuffleOpLowering,
1925+
CIRToLLVMVecSplatOpLowering,
1926+
CIRToLLVMVecTernaryOpLowering
19261927
// clang-format on
19271928
>(converter, patterns.getContext());
19281929

@@ -2244,6 +2245,43 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
22442245
return mlir::success();
22452246
}
22462247

2248+
mlir::LogicalResult CIRToLLVMComplexEqualOpLowering::matchAndRewrite(
2249+
cir::ComplexEqualOp op, OpAdaptor adaptor,
2250+
mlir::ConversionPatternRewriter &rewriter) const {
2251+
mlir::Value lhs = adaptor.getLhs();
2252+
mlir::Value rhs = adaptor.getRhs();
2253+
2254+
auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
2255+
mlir::Type complexElemTy =
2256+
getTypeConverter()->convertType(complexType.getElementType());
2257+
2258+
mlir::Location loc = op.getLoc();
2259+
auto lhsReal =
2260+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
2261+
auto lhsImag =
2262+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
2263+
auto rhsReal =
2264+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
2265+
auto rhsImag =
2266+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
2267+
2268+
if (complexElemTy.isInteger()) {
2269+
auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
2270+
loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal);
2271+
auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
2272+
loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag);
2273+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
2274+
return mlir::success();
2275+
}
2276+
2277+
auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
2278+
loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal);
2279+
auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
2280+
loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag);
2281+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
2282+
return mlir::success();
2283+
}
2284+
22472285
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
22482286
return std::make_unique<ConvertCIRToLLVMPass>();
22492287
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,16 @@ class CIRToLLVMComplexImagOpLowering
463463
mlir::ConversionPatternRewriter &) const override;
464464
};
465465

466+
class CIRToLLVMComplexEqualOpLowering
467+
: public mlir::OpConversionPattern<cir::ComplexEqualOp> {
468+
public:
469+
using mlir::OpConversionPattern<cir::ComplexEqualOp>::OpConversionPattern;
470+
471+
mlir::LogicalResult
472+
matchAndRewrite(cir::ComplexEqualOp op, OpAdaptor,
473+
mlir::ConversionPatternRewriter &) const override;
474+
};
475+
466476
} // namespace direct
467477
} // namespace cir
468478

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,4 +368,77 @@ int foo17(int _Complex a, int _Complex b) {
368368
// OGCG: %[[B_REAL:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
369369
// OGCG: %[[TMP_B:.*]] = load i32, ptr %[[B_REAL]], align 4
370370
// OGCG: %[[ADD:.*]] = add nsw i32 %[[TMP_A]], %[[TMP_B]]
371-
// OGCG: ret i32 %[[ADD]]
371+
// OGCG: ret i32 %[[ADD]]
372+
373+
bool foo18(int _Complex a, int _Complex b) {
374+
return a == b;
375+
}
376+
377+
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
378+
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
379+
// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!s32i>
380+
381+
// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
382+
// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
383+
// LLVM: %[[A_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 0
384+
// LLVM: %[[A_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 1
385+
// LLVM: %[[B_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 0
386+
// LLVM: %[[B_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 1
387+
// LLVM: %[[CMP_REAL:.*]] = icmp eq i32 %[[A_REAL]], %[[B_REAL]]
388+
// LLVM: %[[CMP_IMAG:.*]] = icmp eq i32 %[[A_IMAG]], %[[B_IMAG]]
389+
// LLVM: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
390+
391+
// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4
392+
// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4
393+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 0
394+
// OGCG: %[[A_REAL:.*]] = load i32, ptr %[[A_REAL_PTR]], align 4
395+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 1
396+
// OGCG: %[[A_IMAG:.*]] = load i32, ptr %[[A_IMAG_PTR]], align 4
397+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
398+
// OGCG: %[[B_REAL:.*]] = load i32, ptr %[[B_REAL_PTR]], align 4
399+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1
400+
// OGCG: %[[B_IMAG:.*]] = load i32, ptr %[[B_IMAG_PTR]], align 4
401+
// OGCG: %[[CMP_REAL:.*]] = icmp eq i32 %[[A_REAL]], %[[B_REAL]]
402+
// OGCG: %[[CMP_IMAG:.*]] = icmp eq i32 %[[A_IMAG]], %[[B_IMAG]]
403+
// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
404+
405+
bool foo19(double _Complex a, double _Complex b) {
406+
return a == b;
407+
}
408+
409+
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
410+
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
411+
// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!cir.double>
412+
413+
// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8
414+
// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8
415+
// LLVM: %[[A_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 0
416+
// LLVM: %[[A_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 1
417+
// LLVM: %[[B_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 0
418+
// LLVM: %[[B_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 1
419+
// LLVM: %[[CMP_REAL:.*]] = fcmp oeq double %[[A_REAL]], %[[B_REAL]]
420+
// LLVM: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]]
421+
// LLVM: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
422+
423+
// OGCG: %[[COMPLEX_A:.*]] = alloca { double, double }, align 8
424+
// OGCG: %[[COMPLEX_B:.*]] = alloca { double, double }, align 8
425+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
426+
// OGCG: store double {{.*}}, ptr %[[A_REAL_PTR]], align 8
427+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
428+
// OGCG: store double {{.*}}, ptr %[[A_IMAG_PTR]], align 8
429+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
430+
// OGCG: store double {{.*}}, ptr %[[B_REAL_PTR]], align 8
431+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
432+
// OGCG: store double {{.*}}, ptr %[[B_IMAG_PTR]], align 8
433+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
434+
// OGCG: %[[A_REAL:.*]] = load double, ptr %[[A_REAL_PTR]], align 8
435+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
436+
// OGCG: %[[A_IMAG:.*]] = load double, ptr %[[A_IMAG_PTR]], align 8
437+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
438+
// OGCG: %[[B_REAL:.*]] = load double, ptr %[[B_REAL_PTR]], align 8
439+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
440+
// OGCG: %[[B_IMAG:.*]] = load double, ptr %[[B_IMAG_PTR]], align 8
441+
// OGCG: %[[CMP_REAL:.*]] = fcmp oeq double %[[A_REAL]], %[[B_REAL]]
442+
// OGCG: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]]
443+
// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
444+

0 commit comments

Comments
 (0)