Skip to content

Commit 7b5132d

Browse files
clementval, jeanPerier, schweitzpgi
committed
[fir] Add complex operations conversion from FIR to LLVM IR
This patch adds conversion for primitive operations on complex types. - fir.addc - fir.subc - fir.mulc - fir.divc - fir.negc This also adds the type conversion for the !fir.complex<KIND> type. This patch is part of the upstreaming effort from the fir-dev branch. This patch was updated to avoid a failure on the Windows buildbot. Flang codegen does not support Windows targets, so we force the test to use a known target instead. Reviewed By: kiranchandramohan, rovka Differential Revision: https://reviews.llvm.org/D113434 Co-authored-by: Jean Perier <[email protected]> Co-authored-by: Eric Schweitz <[email protected]>
1 parent 116dc70 commit 7b5132d

File tree

7 files changed

+415
-8
lines changed

7 files changed

+415
-8
lines changed

flang/include/flang/Optimizer/CodeGen/CGPasses.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ def FIRToLLVMLowering : Pass<"fir-to-llvm-ir", "mlir::ModuleOp"> {
2424
}];
2525
let constructor = "::fir::createFIRToLLVMPass()";
2626
let dependentDialects = ["mlir::LLVM::LLVMDialect"];
27+
let options = [
28+
Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
29+
"Override module's target triple.">
30+
];
2731
}
2832

2933
def CodeGenRewrite : Pass<"cg-rewrite"> {

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 182 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "PassDetail.h"
1515
#include "flang/Optimizer/Dialect/FIROps.h"
1616
#include "flang/Optimizer/Dialect/FIRType.h"
17+
#include "flang/Optimizer/Support/FIRContext.h"
1718
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
1819
#include "mlir/Conversion/LLVMCommon/Pattern.h"
1920
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -487,6 +488,175 @@ struct InsertOnRangeOpConversion
487488
return success();
488489
}
489490
};
491+
492+
/// Return the element type of a complex type.
///
/// \param complex either an `mlir::ComplexType` or a `fir::ComplexType`.
/// \return the element type reported by the complex type.
///
/// The `mlir::ComplexType` case is tested with `dyn_cast`; anything else is
/// unconditionally `cast` to `fir::ComplexType`, so callers are expected to
/// pass one of these two types only.
static mlir::Type getComplexEleTy(mlir::Type complex) {
493+
if (auto cc = complex.dyn_cast<mlir::ComplexType>())
494+
return cc.getElementType();
495+
return complex.cast<fir::ComplexType>().getElementType();
496+
}
497+
498+
//
499+
// Primitive operations on Complex types
500+
//
501+
502+
/// Generate inline code for complex addition/subtraction.
///
/// Both operands are split into their two components (aggregate positions 0
/// and 1), `LLVMOP` is applied pairwise to the matching components, and the two
/// results are packed into a fresh value of the converted result type.
///
/// \tparam LLVMOP the LLVM dialect binary op created for each component
///         (instantiated in this file with `FAddOp` and `FSubOp`).
/// \tparam OPTY the FIR op being lowered; used for its location, context and
///         result type.
/// \param sumop the FIR operation being converted.
/// \param opnds the already-converted operands; `opnds[0]` and `opnds[1]` are
///        read.
/// \param rewriter rewriter used to create the new operations.
/// \param lowering type converter used to obtain the converted element and
///        result types.
/// \return the final `llvm.insertvalue` whose result is the complete value.
503+
template <typename LLVMOP, typename OPTY>
504+
mlir::LLVM::InsertValueOp complexSum(OPTY sumop, mlir::ValueRange opnds,
505+
mlir::ConversionPatternRewriter &rewriter,
506+
fir::LLVMTypeConverter &lowering) {
507+
mlir::Value a = opnds[0];
508+
mlir::Value b = opnds[1];
509+
auto loc = sumop.getLoc();
510+
// Spell the pointer explicitly (`auto *`), as the sibling patterns do.
auto *ctx = sumop.getContext();
511+
// Position attributes selecting component 0 and component 1.
auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0));
512+
auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1));
513+
mlir::Type eleTy = lowering.convertType(getComplexEleTy(sumop.getType()));
514+
mlir::Type ty = lowering.convertType(sumop.getType());
515+
// a = (x0, y0), b = (x1, y1)
auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0);
516+
auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1);
517+
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0);
518+
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1);
519+
// Component-wise operation.
auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1);
520+
auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1);
521+
// Start from an undef aggregate and fill in both components.
auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
522+
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r0, rx, c0);
523+
return rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ry, c1);
524+
}
525+
526+
/// Convert `fir.addc` by delegating to `complexSum` instantiated with
/// `mlir::LLVM::FAddOp`, then replacing the FIR op with the value it returns.
struct AddcOpConversion : public FIROpConversion<fir::AddcOp> {
527+
using FIROpConversion::FIROpConversion;
528+
529+
mlir::LogicalResult
530+
matchAndRewrite(fir::AddcOp addc, OpAdaptor adaptor,
531+
mlir::ConversionPatternRewriter &rewriter) const override {
532+
// given: (x + iy) + (x' + iy')
533+
// result: (x + x') + i(y + y')
534+
auto r = complexSum<mlir::LLVM::FAddOp>(addc, adaptor.getOperands(),
535+
rewriter, lowerTy());
536+
rewriter.replaceOp(addc, r.getResult());
537+
return success();
538+
}
539+
};
540+
541+
/// Convert `fir.subc` by delegating to `complexSum` instantiated with
/// `mlir::LLVM::FSubOp`, then replacing the FIR op with the value it returns.
struct SubcOpConversion : public FIROpConversion<fir::SubcOp> {
542+
using FIROpConversion::FIROpConversion;
543+
544+
mlir::LogicalResult
545+
matchAndRewrite(fir::SubcOp subc, OpAdaptor adaptor,
546+
mlir::ConversionPatternRewriter &rewriter) const override {
547+
// given: (x + iy) - (x' + iy')
548+
// result: (x - x') + i(y - y')
549+
auto r = complexSum<mlir::LLVM::FSubOp>(subc, adaptor.getOperands(),
550+
rewriter, lowerTy());
551+
rewriter.replaceOp(subc, r.getResult());
552+
return success();
553+
}
554+
};
555+
556+
/// Inlined complex multiply.
///
/// Converts `fir.mulc` into a sequence of LLVM dialect extractvalue, fmul,
/// fadd/fsub and insertvalue operations; no runtime call is emitted.
557+
struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
558+
using FIROpConversion::FIROpConversion;
559+
560+
mlir::LogicalResult
561+
matchAndRewrite(fir::MulcOp mulc, OpAdaptor adaptor,
562+
mlir::ConversionPatternRewriter &rewriter) const override {
563+
// TODO: Can we use a call to __muldc3 ?
564+
// given: (x + iy) * (x' + iy')
565+
// result: (xx'-yy')+i(xy'+yx')
566+
mlir::Value a = adaptor.getOperands()[0];
567+
mlir::Value b = adaptor.getOperands()[1];
568+
auto loc = mulc.getLoc();
569+
auto *ctx = mulc.getContext();
570+
// Position attributes selecting component 0 (x) and component 1 (y).
auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0));
571+
auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1));
572+
mlir::Type eleTy = convertType(getComplexEleTy(mulc.getType()));
573+
mlir::Type ty = convertType(mulc.getType());
574+
// a = (x0, y0), b = (x1, y1)
auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0);
575+
auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1);
576+
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0);
577+
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1);
578+
auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
579+
auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
580+
auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
581+
// Component 1 of the result: xy' + yx'
auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx);
582+
auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
583+
// Component 0 of the result: xx' - yy'
auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy);
584+
// Start from an undef aggregate and fill in both components.
auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
585+
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, ra, rr, c0);
586+
auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ri, c1);
587+
rewriter.replaceOp(mulc, r0.getResult());
588+
return success();
589+
}
590+
};
591+
592+
/// Inlined complex division.
///
/// Converts `fir.divc` into a sequence of LLVM dialect extractvalue, fmul,
/// fadd/fsub, fdiv and insertvalue operations; no runtime call is emitted.
593+
struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
594+
using FIROpConversion::FIROpConversion;
595+
596+
mlir::LogicalResult
597+
matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
598+
mlir::ConversionPatternRewriter &rewriter) const override {
599+
// TODO: Can we use a call to __divdc3 instead?
600+
// Just generate inline code for now.
601+
// given: (x + iy) / (x' + iy')
602+
// result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
//
// NOTE(review): the numerators and d are formed directly from products of
// the operand components, with no scaling step. Presumably that can
// overflow or underflow for extreme magnitudes where a scaled algorithm
// would not; confirm this is acceptable, or resolve the TODO above.
603+
mlir::Value a = adaptor.getOperands()[0];
604+
mlir::Value b = adaptor.getOperands()[1];
605+
auto loc = divc.getLoc();
606+
auto *ctx = divc.getContext();
607+
// Position attributes selecting component 0 (x) and component 1 (y).
auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0));
608+
auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1));
609+
mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
610+
mlir::Type ty = convertType(divc.getType());
611+
// a = (x0, y0), b = (x1, y1)
auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0);
612+
auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1);
613+
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0);
614+
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1);
615+
auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
616+
auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
617+
auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
618+
auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
619+
auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
620+
auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
621+
// d = x'x' + y'y'
auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
622+
// Numerators: xx' + yy' and yx' - xy'
auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
623+
auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
624+
auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
625+
auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
626+
// Start from an undef aggregate and fill in both components.
auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
627+
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, ra, rr, c0);
628+
auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ri, c1);
629+
rewriter.replaceOp(divc, r0.getResult());
630+
return success();
631+
}
632+
};
633+
634+
/// Inlined complex negation.
///
/// Converts `fir.negc` by extracting both components of the operand, negating
/// each with `llvm.fneg`, and inserting them back.
635+
struct NegcOpConversion : public FIROpConversion<fir::NegcOp> {
636+
using FIROpConversion::FIROpConversion;
637+
638+
mlir::LogicalResult
639+
matchAndRewrite(fir::NegcOp neg, OpAdaptor adaptor,
640+
mlir::ConversionPatternRewriter &rewriter) const override {
641+
// given: -(x + iy)
642+
// result: -x - iy
643+
auto *ctxt = neg.getContext();
644+
auto eleTy = convertType(getComplexEleTy(neg.getType()));
645+
auto ty = convertType(neg.getType());
646+
auto loc = neg.getLoc();
647+
mlir::Value o0 = adaptor.getOperands()[0];
648+
// Position attributes selecting component 0 and component 1.
auto c0 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(0));
649+
auto c1 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(1));
650+
auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, o0, c0);
651+
auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, o0, c1);
652+
auto nrp = rewriter.create<mlir::LLVM::FNegOp>(loc, eleTy, rp);
653+
auto nip = rewriter.create<mlir::LLVM::FNegOp>(loc, eleTy, ip);
654+
// The operand itself is used as the base aggregate (no undef is created);
// both of its components are overwritten by the two inserts below.
auto r = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, o0, nrp, c0);
655+
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(neg, ty, r, nip, c1);
656+
return success();
657+
}
658+
};
659+
490660
} // namespace
491661

492662
namespace {
@@ -501,15 +671,21 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
501671
mlir::ModuleOp getModule() { return getOperation(); }
502672

503673
void runOnOperation() override final {
674+
auto mod = getModule();
675+
if (!forcedTargetTriple.empty()) {
676+
fir::setTargetTriple(mod, forcedTargetTriple);
677+
}
678+
504679
auto *context = getModule().getContext();
505680
fir::LLVMTypeConverter typeConverter{getModule()};
506681
mlir::OwningRewritePatternList pattern(context);
507-
pattern.insert<
508-
AddrOfOpConversion, CallOpConversion, ExtractValueOpConversion,
509-
HasValueOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
510-
InsertValueOpConversion, SelectOpConversion, SelectRankOpConversion,
511-
UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>(
512-
typeConverter);
682+
pattern.insert<AddcOpConversion, AddrOfOpConversion, CallOpConversion,
683+
DivcOpConversion, ExtractValueOpConversion,
684+
HasValueOpConversion, GlobalOpConversion,
685+
InsertOnRangeOpConversion, InsertValueOpConversion,
686+
NegcOpConversion, MulcOpConversion, SelectOpConversion,
687+
SelectRankOpConversion, SubcOpConversion, UndefOpConversion,
688+
UnreachableOpConversion, ZeroOpConversion>(typeConverter);
513689
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
514690
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
515691
pattern);

flang/lib/Optimizer/CodeGen/Target.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ struct GenericTarget : public CodeGenSpecifics {
3535
using CodeGenSpecifics::CodeGenSpecifics;
3636
using AT = CodeGenSpecifics::Attributes;
3737

38+
/// In-memory representation of a complex value with element type `eleTy`:
/// a tuple type holding `eleTy` twice.
///
/// \param eleTy element type; asserted to satisfy `fir::isa_real`.
/// \return an `mlir::TupleType` of `(eleTy, eleTy)`.
mlir::Type complexMemoryType(mlir::Type eleTy) const override {
39+
assert(fir::isa_real(eleTy));
40+
// { t, t } struct of 2 eleTy
41+
// Build the TypeRange inside the call expression. A named
// `mlir::TypeRange range = {eleTy, eleTy};` is a non-owning view over the
// backing array of a temporary initializer list, which is destroyed at the
// end of that declaration, so the view would dangle by the time it is used.
42+
return mlir::TupleType::get(eleTy.getContext(),
mlir::TypeRange{eleTy, eleTy});
43+
}
44+
3845
Marshalling boxcharArgumentType(mlir::Type eleTy, bool sret) const override {
3946
CodeGenSpecifics::Marshalling marshal;
4047
auto idxTy = mlir::IntegerType::get(eleTy.getContext(), S::defaultWidth);

flang/lib/Optimizer/CodeGen/Target.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class CodeGenSpecifics {
6565
CodeGenSpecifics() = delete;
6666
virtual ~CodeGenSpecifics() {}
6767

68+
/// Type representation of a `complex<eleTy>` type value in memory.
69+
virtual mlir::Type complexMemoryType(mlir::Type eleTy) const = 0;
70+
6871
/// Type representation of a `complex<eleTy>` type argument when passed by
6972
/// value. An argument value may need to be passed as a (safe) reference
7073
/// argument.

flang/lib/Optimizer/CodeGen/TypeConverter.h

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
1515

1616
#include "DescriptorModel.h"
17+
#include "Target.h"
1718
#include "flang/Lower/Todo.h" // remove when TODO's are done
1819
#include "flang/Optimizer/Support/FIRContext.h"
1920
#include "flang/Optimizer/Support/KindMapping.h"
@@ -28,7 +29,10 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
2829
public:
2930
LLVMTypeConverter(mlir::ModuleOp module)
3031
: mlir::LLVMTypeConverter(module.getContext()),
31-
kindMapping(getKindMapping(module)) {
32+
kindMapping(getKindMapping(module)),
33+
specifics(CodeGenSpecifics::get(module.getContext(),
34+
getTargetTriple(module),
35+
getKindMapping(module))) {
3236
LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
3337

3438
// Each conversion should return a value of type mlir::Type.
@@ -39,6 +43,10 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
3943
});
4044
addConversion(
4145
[&](fir::RecordType derived) { return convertRecordType(derived); });
46+
addConversion(
47+
[&](fir::ComplexType cmplx) { return convertComplexType(cmplx); });
48+
addConversion(
49+
[&](fir::RealType real) { return convertRealType(real.getFKind()); });
4250
addConversion(
4351
[&](fir::ReferenceType ref) { return convertPointerLike(ref); });
4452
addConversion(
@@ -140,6 +148,24 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
140148
/*isPacked=*/false));
141149
}
142150

151+
// Use the target specifics to figure out how to map complex to LLVM IR. The
152+
// use of complex values in function signatures is handled before conversion
153+
// to LLVM IR dialect here.
154+
//
155+
// fir.complex<T> | std.complex<T> --> llvm<"{t,t}">
//
// `C` is any type providing `getElementType()`.
156+
template <typename C>
157+
mlir::Type convertComplexType(C cmplx) {
158+
LLVM_DEBUG(llvm::dbgs() << "type convert: " << cmplx << '\n');
159+
auto eleTy = cmplx.getElementType();
160+
// Ask the target for the in-memory type, then run that type through the
// converter to obtain the final converted type.
return convertType(specifics->complexMemoryType(eleTy));
161+
}
162+
163+
// Convert a front-end kind value to either a std or LLVM IR dialect type.
164+
// fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
//
// The kind is first mapped to an llvm::Type::TypeID through the kind mapping
// and then translated by fromRealTypeID, which also receives the kind.
165+
mlir::Type convertRealType(fir::KindTy kind) {
166+
return fromRealTypeID(kindMapping.getRealTypeID(kind), kind);
167+
}
168+
143169
template <typename A>
144170
mlir::Type convertPointerLike(A &ty) {
145171
mlir::Type eleTy = ty.getEleTy();
@@ -187,8 +213,33 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
187213
return mlir::LLVM::LLVMPointerType::get(baseTy);
188214
}
189215

216+
/// Convert llvm::Type::TypeID to mlir::Type.
///
/// \param typeID the LLVM floating-point type ID to translate.
/// \param kind the originating `!fir.real` kind; used only in the diagnostic
///        text for unsupported IDs.
/// \return the matching builtin float type (f16, bf16, f32, f64, f80 or
///         f128). For any other ID an error is emitted at an unknown location
///         and a default-constructed `mlir::Type` is returned.
217+
mlir::Type fromRealTypeID(llvm::Type::TypeID typeID, fir::KindTy kind) {
218+
switch (typeID) {
219+
case llvm::Type::TypeID::HalfTyID:
220+
return mlir::FloatType::getF16(&getContext());
221+
case llvm::Type::TypeID::BFloatTyID:
222+
return mlir::FloatType::getBF16(&getContext());
223+
case llvm::Type::TypeID::FloatTyID:
224+
return mlir::FloatType::getF32(&getContext());
225+
case llvm::Type::TypeID::DoubleTyID:
226+
return mlir::FloatType::getF64(&getContext());
227+
case llvm::Type::TypeID::X86_FP80TyID:
228+
return mlir::FloatType::getF80(&getContext());
229+
case llvm::Type::TypeID::FP128TyID:
230+
return mlir::FloatType::getF128(&getContext());
231+
default:
232+
// Unsupported ID: report it and return a default-constructed type.
emitError(UnknownLoc::get(&getContext()))
233+
<< "unsupported type: !fir.real<" << kind << ">";
234+
return {};
235+
}
236+
}
237+
238+
/// Return a mutable reference to the kind mapping held by this converter.
KindMapping &getKindMap() { return kindMapping; }
239+
190240
private:
191241
KindMapping kindMapping;
242+
std::unique_ptr<CodeGenSpecifics> specifics;
192243
};
193244

194245
} // namespace fir

0 commit comments

Comments
 (0)