Skip to content

Commit 6464066

Browse files
authored
[CIR] Upstream CreateOp for ComplexType with folder (#143192)
This change adds support for the create op for ComplexType with folder and support for empty init list #141365
1 parent 1bc0b08 commit 6464066

File tree

16 files changed

+415
-12
lines changed

16 files changed

+415
-12
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,9 @@ def ConstComplexAttr : CIR_Attr<"ConstComplex", "const_complex",
307307
);
308308

309309
let builders = [
310-
AttrBuilderWithInferredContext<(ins "cir::ComplexType":$type,
311-
"mlir::TypedAttr":$real,
310+
AttrBuilderWithInferredContext<(ins "mlir::TypedAttr":$real,
312311
"mlir::TypedAttr":$imag), [{
312+
auto type = cir::ComplexType::get(real.getType());
313313
return $_get(type.getContext(), type, real, imag);
314314
}]>,
315315
];

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2353,4 +2353,36 @@ def BaseClassAddrOp : CIR_Op<"base_class_addr"> {
23532353
}];
23542354
}
23552355

2356+
//===----------------------------------------------------------------------===//
2357+
// ComplexCreateOp
2358+
//===----------------------------------------------------------------------===//
2359+
2360+
def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
2361+
let summary = "Create a complex value from its real and imaginary parts";
2362+
let description = [{
2363+
The `cir.complex.create` operation takes two operands that represent the
2364+
real and imaginary part of a complex number, and yields the complex number.
2365+
2366+
```mlir
2367+
%0 = cir.const #cir.fp<1.000000e+00> : !cir.double
2368+
%1 = cir.const #cir.fp<2.000000e+00> : !cir.double
2369+
%2 = cir.complex.create %0, %1 : !cir.double -> !cir.complex<!cir.double>
2370+
```
2371+
}];
2372+
2373+
let results = (outs CIR_ComplexType:$result);
2374+
let arguments = (ins
2375+
CIR_AnyIntOrFloatType:$real,
2376+
CIR_AnyIntOrFloatType:$imag
2377+
);
2378+
2379+
let assemblyFormat = [{
2380+
$real `,` $imag
2381+
`:` qualified(type($real)) `->` qualified(type($result)) attr-dict
2382+
}];
2383+
2384+
let hasVerifier = 1;
2385+
let hasFolder = 1;
2386+
}
2387+
23562388
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,8 @@ def CIRRecordType : Type<
600600

601601
def CIR_AnyType : AnyTypeOf<[
602602
CIR_VoidType, CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_IntType,
603-
CIR_AnyFloatType, CIR_PointerType, CIR_FuncType, CIR_RecordType
603+
CIR_AnyFloatType, CIR_PointerType, CIR_FuncType, CIR_RecordType,
604+
CIR_ComplexType
604605
]>;
605606

606607
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,6 @@ struct MissingFeatures {
251251
// Future CIR operations
252252
static bool awaitOp() { return false; }
253253
static bool callOp() { return false; }
254-
static bool complexCreateOp() { return false; }
255254
static bool complexImagOp() { return false; }
256255
static bool complexRealOp() { return false; }
257256
static bool ifOp() { return false; }

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,12 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
360360
return CIRBaseBuilderTy::createStore(loc, val, dst.getPointer(), align);
361361
}
362362

363+
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
364+
mlir::Value imag) {
365+
auto resultComplexTy = cir::ComplexType::get(real.getType());
366+
return create<cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
367+
}
368+
363369
/// Create a cir.ptr_stride operation to get access to an array element.
364370
/// \p idx is the index of the element to access, \p shouldDecay is true if
365371
/// the result should decay to a pointer to the element type.

clang/lib/CIR/CodeGen/CIRGenDecl.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,13 @@ void CIRGenFunction::emitExprAsInit(const Expr *init, const ValueDecl *d,
499499
emitScalarInit(init, getLoc(d->getSourceRange()), lvalue);
500500
return;
501501
case cir::TEK_Complex: {
502-
cgm.errorNYI(init->getSourceRange(), "emitExprAsInit: complex type");
502+
mlir::Value complex = emitComplexExpr(init);
503+
if (capturedByInit)
504+
cgm.errorNYI(init->getSourceRange(),
505+
"emitExprAsInit: complex type captured by init");
506+
mlir::Location loc = getLoc(init->getExprLoc());
507+
emitStoreOfComplex(loc, complex, lvalue,
508+
/*isInit*/ true);
503509
return;
504510
}
505511
case cir::TEK_Aggregate:
@@ -593,8 +599,8 @@ void CIRGenFunction::emitDecl(const Decl &d) {
593599
// None of these decls require codegen support.
594600
return;
595601

596-
case Decl::Enum: // enum X;
597-
case Decl::Record: // struct/union/class X;
602+
case Decl::Enum: // enum X;
603+
case Decl::Record: // struct/union/class X;
598604
case Decl::CXXRecord: // struct/union/class X; [C++]
599605
case Decl::NamespaceAlias:
600606
case Decl::Using: // using X; [C++]

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,3 +1690,14 @@ mlir::Value CIRGenFunction::emitScalarConstant(
16901690
}
16911691
return builder.getConstant(getLoc(e->getSourceRange()), constant.getValue());
16921692
}
1693+
1694+
/// An LValue is a candidate for having its loads and stores be made atomic if
1695+
/// we are operating under /volatile:ms *and* the LValue itself is volatile and
1696+
/// performing such an operation can be performed without a libcall.
1697+
bool CIRGenFunction::isLValueSuitableForInlineAtomic(LValue lv) {
1698+
if (!cgm.getLangOpts().MSVolatile)
1699+
return false;
1700+
1701+
cgm.errorNYI("LValueSuitableForInlineAtomic LangOpts MSVolatile");
1702+
return false;
1703+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#include "CIRGenBuilder.h"
2+
#include "CIRGenFunction.h"
3+
4+
#include "clang/AST/StmtVisitor.h"
5+
6+
using namespace clang;
7+
using namespace clang::CIRGen;
8+
9+
namespace {
10+
class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
11+
CIRGenFunction &cgf;
12+
CIRGenBuilderTy &builder;
13+
14+
public:
15+
explicit ComplexExprEmitter(CIRGenFunction &cgf)
16+
: cgf(cgf), builder(cgf.getBuilder()) {}
17+
18+
/// Store the specified real/imag parts into the
19+
/// specified value pointer.
20+
void emitStoreOfComplex(mlir::Location loc, mlir::Value val, LValue lv,
21+
bool isInit);
22+
23+
mlir::Value VisitInitListExpr(InitListExpr *e);
24+
};
25+
26+
} // namespace
27+
28+
static const ComplexType *getComplexType(QualType type) {
29+
type = type.getCanonicalType();
30+
if (const ComplexType *comp = dyn_cast<ComplexType>(type))
31+
return comp;
32+
return cast<ComplexType>(cast<AtomicType>(type)->getValueType());
33+
}
34+
35+
void ComplexExprEmitter::emitStoreOfComplex(mlir::Location loc, mlir::Value val,
36+
LValue lv, bool isInit) {
37+
if (lv.getType()->isAtomicType() ||
38+
(!isInit && cgf.isLValueSuitableForInlineAtomic(lv))) {
39+
cgf.cgm.errorNYI("StoreOfComplex with Atomic LV");
40+
return;
41+
}
42+
43+
const Address destAddr = lv.getAddress();
44+
builder.createStore(loc, val, destAddr);
45+
}
46+
47+
mlir::Value ComplexExprEmitter::VisitInitListExpr(InitListExpr *e) {
48+
mlir::Location loc = cgf.getLoc(e->getExprLoc());
49+
if (e->getNumInits() == 2) {
50+
mlir::Value real = cgf.emitScalarExpr(e->getInit(0));
51+
mlir::Value imag = cgf.emitScalarExpr(e->getInit(1));
52+
return builder.createComplexCreate(loc, real, imag);
53+
}
54+
55+
if (e->getNumInits() == 1) {
56+
cgf.cgm.errorNYI("Create Complex with InitList with size 1");
57+
return {};
58+
}
59+
60+
assert(e->getNumInits() == 0 && "Unexpected number of inits");
61+
QualType complexElemTy =
62+
e->getType()->castAs<clang::ComplexType>()->getElementType();
63+
mlir::Type complexElemLLVMTy = cgf.convertType(complexElemTy);
64+
mlir::TypedAttr defaultValue = builder.getZeroInitAttr(complexElemLLVMTy);
65+
auto complexAttr = cir::ConstComplexAttr::get(defaultValue, defaultValue);
66+
return builder.create<cir::ConstantOp>(loc, complexAttr);
67+
}
68+
69+
mlir::Value CIRGenFunction::emitComplexExpr(const Expr *e) {
70+
assert(e && getComplexType(e->getType()) &&
71+
"Invalid complex expression to emit");
72+
73+
return ComplexExprEmitter(*this).Visit(const_cast<Expr *>(e));
74+
}
75+
76+
void CIRGenFunction::emitStoreOfComplex(mlir::Location loc, mlir::Value v,
77+
LValue dest, bool isInit) {
78+
ComplexExprEmitter(*this).emitStoreOfComplex(loc, v, dest, isInit);
79+
}

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,8 @@ class CIRGenFunction : public CIRGenTypeCache {
338338
PrototypeWrapper(const clang::ObjCMethodDecl *md) : p(md) {}
339339
};
340340

341+
bool isLValueSuitableForInlineAtomic(LValue lv);
342+
341343
/// An abstract representation of regular/ObjC call/message targets.
342344
class AbstractCallee {
343345
/// The function declaration of the callee.
@@ -860,6 +862,10 @@ class CIRGenFunction : public CIRGenTypeCache {
860862

861863
mlir::LogicalResult emitForStmt(const clang::ForStmt &s);
862864

865+
/// Emit the computation of the specified expression of complex type,
866+
/// returning the result.
867+
mlir::Value emitComplexExpr(const Expr *e);
868+
863869
void emitCompoundStmt(const clang::CompoundStmt &s);
864870

865871
void emitCompoundStmtWithoutScope(const clang::CompoundStmt &s);
@@ -961,6 +967,9 @@ class CIRGenFunction : public CIRGenTypeCache {
961967

962968
void emitStaticVarDecl(const VarDecl &d, cir::GlobalLinkageKind linkage);
963969

970+
void emitStoreOfComplex(mlir::Location loc, mlir::Value v, LValue dest,
971+
bool isInit);
972+
964973
void emitStoreOfScalar(mlir::Value value, Address addr, bool isVolatile,
965974
clang::QualType ty, bool isInit = false,
966975
bool isNontemporal = false);

clang/lib/CIR/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_clang_library(clangCIR
1919
CIRGenDeclOpenACC.cpp
2020
CIRGenExpr.cpp
2121
CIRGenExprAggregate.cpp
22+
CIRGenExprComplex.cpp
2223
CIRGenExprConstant.cpp
2324
CIRGenExprScalar.cpp
2425
CIRGenFunction.cpp

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,6 +1748,33 @@ OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
17481748
vecTy, mlir::ArrayAttr::get(getContext(), elements));
17491749
}
17501750

1751+
//===----------------------------------------------------------------------===//
1752+
// ComplexCreateOp
1753+
//===----------------------------------------------------------------------===//
1754+
1755+
LogicalResult cir::ComplexCreateOp::verify() {
1756+
if (getType().getElementType() != getReal().getType()) {
1757+
emitOpError()
1758+
<< "operand type of cir.complex.create does not match its result type";
1759+
return failure();
1760+
}
1761+
1762+
return success();
1763+
}
1764+
1765+
OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
1766+
mlir::Attribute real = adaptor.getReal();
1767+
mlir::Attribute imag = adaptor.getImag();
1768+
if (!real || !imag)
1769+
return {};
1770+
1771+
// When both of real and imag are constants, we can fold the operation into an
1772+
// `#cir.const_complex` operation.
1773+
auto realAttr = mlir::cast<mlir::TypedAttr>(real);
1774+
auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
1775+
return cir::ConstComplexAttr::get(realAttr, imagAttr);
1776+
}
1777+
17511778
//===----------------------------------------------------------------------===//
17521779
// TableGen'd op method definitions
17531780
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,16 +134,15 @@ void CIRCanonicalizePass::runOnOperation() {
134134
getOperation()->walk([&](Operation *op) {
135135
assert(!cir::MissingFeatures::switchOp());
136136
assert(!cir::MissingFeatures::tryOp());
137-
assert(!cir::MissingFeatures::complexCreateOp());
138137
assert(!cir::MissingFeatures::complexRealOp());
139138
assert(!cir::MissingFeatures::complexImagOp());
140139
assert(!cir::MissingFeatures::callOp());
141140

142141
// Many operations are here to perform a manual `fold` in
143142
// applyOpPatternsGreedily.
144143
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
145-
VecCreateOp, VecExtractOp, VecShuffleOp, VecShuffleDynamicOp,
146-
VecTernaryOp>(op))
144+
ComplexCreateOp, VecCreateOp, VecExtractOp, VecShuffleOp,
145+
VecShuffleDynamicOp, VecTernaryOp>(op))
147146
ops.push_back(op);
148147
});
149148

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,32 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
905905
rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter,
906906
getTypeConverter()));
907907
return mlir::success();
908-
} else {
908+
} else if (auto complexTy = mlir::dyn_cast<cir::ComplexType>(op.getType())) {
909+
auto complexAttr = mlir::cast<cir::ConstComplexAttr>(op.getValue());
910+
mlir::Type complexElemTy = complexTy.getElementType();
911+
mlir::Type complexElemLLVMTy = typeConverter->convertType(complexElemTy);
912+
913+
mlir::Attribute components[2];
914+
if (mlir::isa<cir::IntType>(complexElemTy)) {
915+
components[0] = rewriter.getIntegerAttr(
916+
complexElemLLVMTy,
917+
mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
918+
components[1] = rewriter.getIntegerAttr(
919+
complexElemLLVMTy,
920+
mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
921+
} else {
922+
components[0] = rewriter.getFloatAttr(
923+
complexElemLLVMTy,
924+
mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
925+
components[1] = rewriter.getFloatAttr(
926+
complexElemLLVMTy,
927+
mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
928+
}
929+
930+
attr = rewriter.getArrayAttr(components);
931+
}
932+
933+
else {
909934
return op.emitError() << "unsupported constant type " << op.getType();
910935
}
911936

@@ -1810,7 +1835,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
18101835
CIRToLLVMVecSplatOpLowering,
18111836
CIRToLLVMVecShuffleOpLowering,
18121837
CIRToLLVMVecShuffleDynamicOpLowering,
1813-
CIRToLLVMVecTernaryOpLowering
1838+
CIRToLLVMVecTernaryOpLowering,
1839+
CIRToLLVMComplexCreateOpLowering
18141840
// clang-format on
18151841
>(converter, patterns.getContext());
18161842

@@ -2096,6 +2122,24 @@ mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
20962122
return mlir::success();
20972123
}
20982124

2125+
mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite(
2126+
cir::ComplexCreateOp op, OpAdaptor adaptor,
2127+
mlir::ConversionPatternRewriter &rewriter) const {
2128+
mlir::Type complexLLVMTy =
2129+
getTypeConverter()->convertType(op.getResult().getType());
2130+
auto initialComplex =
2131+
rewriter.create<mlir::LLVM::UndefOp>(op->getLoc(), complexLLVMTy);
2132+
2133+
auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
2134+
op->getLoc(), initialComplex, adaptor.getReal(), 0);
2135+
2136+
auto complex = rewriter.create<mlir::LLVM::InsertValueOp>(
2137+
op->getLoc(), realComplex, adaptor.getImag(), 1);
2138+
2139+
rewriter.replaceOp(op, complex);
2140+
return mlir::success();
2141+
}
2142+
20992143
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
21002144
return std::make_unique<ConvertCIRToLLVMPass>();
21012145
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,16 @@ class CIRToLLVMVecTernaryOpLowering
408408
mlir::ConversionPatternRewriter &) const override;
409409
};
410410

411+
class CIRToLLVMComplexCreateOpLowering
412+
: public mlir::OpConversionPattern<cir::ComplexCreateOp> {
413+
public:
414+
using mlir::OpConversionPattern<cir::ComplexCreateOp>::OpConversionPattern;
415+
416+
mlir::LogicalResult
417+
matchAndRewrite(cir::ComplexCreateOp op, OpAdaptor,
418+
mlir::ConversionPatternRewriter &) const override;
419+
};
420+
411421
} // namespace direct
412422
} // namespace cir
413423

0 commit comments

Comments
 (0)