Skip to content

Commit 80571d5

Browse files
committed
[CIR] Upstream ComplexRealPtrOp for ComplexType
1 parent 6464066 commit 80571d5

File tree

8 files changed

+150
-3
lines changed

8 files changed

+150
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,4 +2385,33 @@ def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
23852385
let hasFolder = 1;
23862386
}
23872387

2388+
//===----------------------------------------------------------------------===//
2389+
// ComplexRealPtrOp
2390+
//===----------------------------------------------------------------------===//
2391+
2392+
def ComplexRealPtrOp : CIR_Op<"complex.real_ptr", [Pure]> {
2393+
let summary = "Derive a pointer to the real part of a complex value";
2394+
let description = [{
2395+
`cir.complex.real_ptr` operation takes a pointer operand that points to a
2396+
complex value of type `!cir.complex` and yields a pointer to the real part
2397+
of the operand.
2398+
2399+
Example:
2400+
2401+
```mlir
2402+
%1 = cir.complex.real_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
2403+
```
2404+
}];
2405+
2406+
let results = (outs CIR_PtrToIntOrFloatType:$result);
2407+
let arguments = (ins CIR_PtrToComplexType:$operand);
2408+
2409+
let assemblyFormat = [{
2410+
$operand `:`
2411+
qualified(type($operand)) `->` qualified(type($result)) attr-dict
2412+
}];
2413+
2414+
let hasVerifier = 1;
2415+
}
2416+
23882417
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType],
159159
let cppFunctionName = "isAnyIntegerOrFloatingPointType";
160160
}
161161

162+
//===----------------------------------------------------------------------===//
163+
// Complex Type predicates
164+
//===----------------------------------------------------------------------===//
165+
166+
def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">;
167+
162168
//===----------------------------------------------------------------------===//
163169
// Pointer Type predicates
164170
//===----------------------------------------------------------------------===//
@@ -180,6 +186,19 @@ class CIR_PtrToPtrTo<code type, string summary>
180186
: CIR_ConfinedType<CIR_AnyPtrType, [CIR_IsPtrToPtrToPred<type>],
181187
"pointer to pointer to " # summary>;
182188

189+
// Pointee type constraint bases
190+
class CIR_PointeePred<Pred pred> : SubstLeaves<"$_self",
191+
"::mlir::cast<::cir::PointerType>($_self).getPointee()", pred>;
192+
193+
class CIR_PtrToAnyOf<list<Type> types, string summary = "">
194+
: CIR_ConfinedType<CIR_AnyPtrType,
195+
[Or<!foreach(type, types, CIR_PointeePred<type.predicate>)>],
196+
!if(!empty(summary),
197+
"pointer to " # CIR_TypeSummaries<types>.value,
198+
summary)>;
199+
200+
class CIR_PtrToType<Type type> : CIR_PtrToAnyOf<[type]>;
201+
183202
// Void pointer type constraints
184203
def CIR_VoidPtrType
185204
: CIR_PtrTo<"::cir::VoidType", "void type">,
@@ -192,6 +211,11 @@ def CIR_PtrToVoidPtrType
192211
"$_builder.getType<" # cppType # ">("
193212
"cir::VoidType::get($_builder.getContext())))">;
194213

214+
// Pointer to type constraints
215+
def CIR_PtrToIntOrFloatType : CIR_PtrToType<CIR_AnyIntOrFloatType>;
216+
217+
def CIR_PtrToComplexType : CIR_PtrToType<CIR_AnyComplexType>;
218+
195219
//===----------------------------------------------------------------------===//
196220
// Vector Type predicates
197221
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,19 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
366366
return create<cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
367367
}
368368

369+
/// Create a cir.complex.real_ptr operation that derives a pointer to the real
370+
/// part of the complex value pointed to by the specified pointer value.
371+
mlir::Value createRealPtr(mlir::Location loc, mlir::Value value) {
372+
auto srcPtrTy = mlir::cast<cir::PointerType>(value.getType());
373+
auto srcComplexTy = mlir::cast<cir::ComplexType>(srcPtrTy.getPointee());
374+
return create<cir::ComplexRealPtrOp>(
375+
loc, getPointerTo(srcComplexTy.getElementType()), value);
376+
}
377+
378+
Address createRealPtr(mlir::Location loc, Address addr) {
379+
return Address{createRealPtr(loc, addr.getPointer()), addr.getAlignment()};
380+
}
381+
369382
/// Create a cir.ptr_stride operation to get access to an array element.
370383
/// \p idx is the index of the element to access, \p shouldDecay is true if
371384
/// the result should decay to a pointer to the element type.

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,29 @@ LValue CIRGenFunction::emitUnaryOpLValue(const UnaryOperator *e) {
541541
}
542542
case UO_Real:
543543
case UO_Imag: {
544-
cgm.errorNYI(e->getSourceRange(), "UnaryOp real/imag");
545-
return LValue();
544+
if (op == UO_Imag) {
545+
cgm.errorNYI(e->getSourceRange(), "UnaryOp real/imag");
546+
return LValue();
547+
}
548+
549+
LValue lv = emitLValue(e->getSubExpr());
550+
assert(lv.isSimple() && "real/imag on non-ordinary l-value");
551+
552+
// __real is valid on scalars. This is a faster way of testing that.
553+
// __imag can only produce an rvalue on scalars.
554+
if (e->getOpcode() == UO_Real &&
555+
!mlir::isa<cir::ComplexType>(lv.getAddress().getElementType())) {
556+
assert(e->getSubExpr()->getType()->isArithmeticType());
557+
return lv;
558+
}
559+
560+
QualType exprTy = getContext().getCanonicalType(e->getSubExpr()->getType());
561+
QualType elemTy = exprTy->castAs<clang::ComplexType>()->getElementType();
562+
mlir::Location loc = getLoc(e->getExprLoc());
563+
Address component = builder.createRealPtr(loc, lv.getAddress());
564+
LValue elemLV = makeAddrLValue(component, elemTy);
565+
elemLV.getQuals().addQualifiers(lv.getQuals());
566+
return elemLV;
546567
}
547568
case UO_PreInc:
548569
case UO_PreDec: {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,6 +1775,25 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
17751775
return cir::ConstComplexAttr::get(realAttr, imagAttr);
17761776
}
17771777

1778+
//===----------------------------------------------------------------------===//
1779+
// ComplexRealPtrOp
1780+
//===----------------------------------------------------------------------===//
1781+
1782+
LogicalResult cir::ComplexRealPtrOp::verify() {
1783+
mlir::Type resultPointeeTy = getType().getPointee();
1784+
cir::PointerType operandPtrTy = getOperand().getType();
1785+
auto operandPointeeTy =
1786+
mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
1787+
1788+
if (resultPointeeTy != operandPointeeTy.getElementType()) {
1789+
emitOpError()
1790+
<< "cir.complex.real_ptr result type does not match operand type";
1791+
return failure();
1792+
}
1793+
1794+
return success();
1795+
}
1796+
17781797
//===----------------------------------------------------------------------===//
17791798
// TableGen'd op method definitions
17801799
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1836,7 +1836,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
18361836
CIRToLLVMVecShuffleOpLowering,
18371837
CIRToLLVMVecShuffleDynamicOpLowering,
18381838
CIRToLLVMVecTernaryOpLowering,
1839-
CIRToLLVMComplexCreateOpLowering
1839+
CIRToLLVMComplexCreateOpLowering,
1840+
CIRToLLVMComplexRealPtrOpLowering
18401841
// clang-format on
18411842
>(converter, patterns.getContext());
18421843

@@ -2140,6 +2141,23 @@ mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite(
21402141
return mlir::success();
21412142
}
21422143

2144+
mlir::LogicalResult CIRToLLVMComplexRealPtrOpLowering::matchAndRewrite(
2145+
cir::ComplexRealPtrOp op, OpAdaptor adaptor,
2146+
mlir::ConversionPatternRewriter &rewriter) const {
2147+
cir::PointerType operandTy = op.getOperand().getType();
2148+
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2149+
mlir::Type elementLLVMTy =
2150+
getTypeConverter()->convertType(operandTy.getPointee());
2151+
2152+
mlir::LLVM::GEPArg gepIndices[2] = {{0}, {0}};
2153+
mlir::LLVM::GEPNoWrapFlags inboundsNuw =
2154+
mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw;
2155+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
2156+
op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
2157+
inboundsNuw);
2158+
return mlir::success();
2159+
}
2160+
21432161
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
21442162
return std::make_unique<ConvertCIRToLLVMPass>();
21452163
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,16 @@ class CIRToLLVMComplexCreateOpLowering
418418
mlir::ConversionPatternRewriter &) const override;
419419
};
420420

421+
class CIRToLLVMComplexRealPtrOpLowering
422+
: public mlir::OpConversionPattern<cir::ComplexRealPtrOp> {
423+
public:
424+
using mlir::OpConversionPattern<cir::ComplexRealPtrOp>::OpConversionPattern;
425+
426+
mlir::LogicalResult
427+
matchAndRewrite(cir::ComplexRealPtrOp op, OpAdaptor,
428+
mlir::ConversionPatternRewriter &) const override;
429+
};
430+
421431
} // namespace direct
422432
} // namespace cir
423433

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,16 @@ void foo7() {
176176
// OGCG: store float %[[TMP_A]], ptr %[[C_REAL_PTR]], align 4
177177
// OGCG: store float 2.000000e+00, ptr %[[C_IMAG_PTR]], align 4
178178

179+
void foo10() {
180+
double _Complex c;
181+
double *realPtr = &__real__ c;
182+
}
183+
184+
// CIR: %[[COMPLEX:.*]] = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["c"]
185+
// CIR: %[[REAL_PTR:.*]] = cir.complex.real_ptr %[[COMPLEX]] : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
186+
187+
// LLVM: %[[COMPLEX:.*]] = alloca { double, double }, i64 1, align 8
188+
// LLVM: %[[REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX]], i32 0, i32 0
189+
190+
// OGCG: %[[COMPLEX:.*]] = alloca { double, double }, align 8
191+
// OGCG: %[[REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX]], i32 0, i32 0

0 commit comments

Comments
 (0)