14
14
#include " PassDetail.h"
15
15
#include " flang/Optimizer/Dialect/FIROps.h"
16
16
#include " flang/Optimizer/Dialect/FIRType.h"
17
+ #include " flang/Optimizer/Support/FIRContext.h"
17
18
#include " mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
18
19
#include " mlir/Conversion/LLVMCommon/Pattern.h"
19
20
#include " mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -487,6 +488,175 @@ struct InsertOnRangeOpConversion
487
488
return success ();
488
489
}
489
490
};
491
+
492
+ static mlir::Type getComplexEleTy (mlir::Type complex) {
493
+ if (auto cc = complex.dyn_cast <mlir::ComplexType>())
494
+ return cc.getElementType ();
495
+ return complex.cast <fir::ComplexType>().getElementType ();
496
+ }
497
+
498
+ //
499
+ // Primitive operations on Complex types
500
+ //
501
+
502
+ // / Generate inline code for complex addition/subtraction
503
+ template <typename LLVMOP, typename OPTY>
504
+ mlir::LLVM::InsertValueOp complexSum (OPTY sumop, mlir::ValueRange opnds,
505
+ mlir::ConversionPatternRewriter &rewriter,
506
+ fir::LLVMTypeConverter &lowering) {
507
+ mlir::Value a = opnds[0 ];
508
+ mlir::Value b = opnds[1 ];
509
+ auto loc = sumop.getLoc ();
510
+ auto ctx = sumop.getContext ();
511
+ auto c0 = mlir::ArrayAttr::get (ctx, rewriter.getI32IntegerAttr (0 ));
512
+ auto c1 = mlir::ArrayAttr::get (ctx, rewriter.getI32IntegerAttr (1 ));
513
+ mlir::Type eleTy = lowering.convertType (getComplexEleTy (sumop.getType ()));
514
+ mlir::Type ty = lowering.convertType (sumop.getType ());
515
+ auto x0 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0);
516
+ auto y0 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1);
517
+ auto x1 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0);
518
+ auto y1 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1);
519
+ auto rx = rewriter.create <LLVMOP>(loc, eleTy, x0, x1);
520
+ auto ry = rewriter.create <LLVMOP>(loc, eleTy, y0, y1);
521
+ auto r0 = rewriter.create <mlir::LLVM::UndefOp>(loc, ty);
522
+ auto r1 = rewriter.create <mlir::LLVM::InsertValueOp>(loc, ty, r0, rx, c0);
523
+ return rewriter.create <mlir::LLVM::InsertValueOp>(loc, ty, r1, ry, c1);
524
+ }
525
+
526
+ struct AddcOpConversion : public FIROpConversion <fir::AddcOp> {
527
+ using FIROpConversion::FIROpConversion;
528
+
529
+ mlir::LogicalResult
530
+ matchAndRewrite (fir::AddcOp addc, OpAdaptor adaptor,
531
+ mlir::ConversionPatternRewriter &rewriter) const override {
532
+ // given: (x + iy) + (x' + iy')
533
+ // result: (x + x') + i(y + y')
534
+ auto r = complexSum<mlir::LLVM::FAddOp>(addc, adaptor.getOperands (),
535
+ rewriter, lowerTy ());
536
+ rewriter.replaceOp (addc, r.getResult ());
537
+ return success ();
538
+ }
539
+ };
540
+
541
+ struct SubcOpConversion : public FIROpConversion <fir::SubcOp> {
542
+ using FIROpConversion::FIROpConversion;
543
+
544
+ mlir::LogicalResult
545
+ matchAndRewrite (fir::SubcOp subc, OpAdaptor adaptor,
546
+ mlir::ConversionPatternRewriter &rewriter) const override {
547
+ // given: (x + iy) - (x' + iy')
548
+ // result: (x - x') + i(y - y')
549
+ auto r = complexSum<mlir::LLVM::FSubOp>(subc, adaptor.getOperands (),
550
+ rewriter, lowerTy ());
551
+ rewriter.replaceOp (subc, r.getResult ());
552
+ return success ();
553
+ }
554
+ };
555
+
556
+ // / Inlined complex multiply
557
+ struct MulcOpConversion : public FIROpConversion <fir::MulcOp> {
558
+ using FIROpConversion::FIROpConversion;
559
+
560
+ mlir::LogicalResult
561
+ matchAndRewrite (fir::MulcOp mulc, OpAdaptor adaptor,
562
+ mlir::ConversionPatternRewriter &rewriter) const override {
563
+ // TODO: Can we use a call to __muldc3 ?
564
+ // given: (x + iy) * (x' + iy')
565
+ // result: (xx'-yy')+i(xy'+yx')
566
+ mlir::Value a = adaptor.getOperands ()[0 ];
567
+ mlir::Value b = adaptor.getOperands ()[1 ];
568
+ auto loc = mulc.getLoc ();
569
+ auto *ctx = mulc.getContext ();
570
+ auto c0 = mlir::ArrayAttr::get (ctx, rewriter.getI32IntegerAttr (0 ));
571
+ auto c1 = mlir::ArrayAttr::get (ctx, rewriter.getI32IntegerAttr (1 ));
572
+ mlir::Type eleTy = convertType (getComplexEleTy (mulc.getType ()));
573
+ mlir::Type ty = convertType (mulc.getType ());
574
+ auto x0 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0);
575
+ auto y0 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1);
576
+ auto x1 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0);
577
+ auto y1 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1);
578
+ auto xx = rewriter.create <mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
579
+ auto yx = rewriter.create <mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
580
+ auto xy = rewriter.create <mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
581
+ auto ri = rewriter.create <mlir::LLVM::FAddOp>(loc, eleTy, xy, yx);
582
+ auto yy = rewriter.create <mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
583
+ auto rr = rewriter.create <mlir::LLVM::FSubOp>(loc, eleTy, xx, yy);
584
+ auto ra = rewriter.create <mlir::LLVM::UndefOp>(loc, ty);
585
+ auto r1 = rewriter.create <mlir::LLVM::InsertValueOp>(loc, ty, ra, rr, c0);
586
+ auto r0 = rewriter.create <mlir::LLVM::InsertValueOp>(loc, ty, r1, ri, c1);
587
+ rewriter.replaceOp (mulc, r0.getResult ());
588
+ return success ();
589
+ }
590
+ };
591
+
592
+ // / Inlined complex division
593
+ struct DivcOpConversion : public FIROpConversion <fir::DivcOp> {
594
+ using FIROpConversion::FIROpConversion;
595
+
596
+ mlir::LogicalResult
597
+ matchAndRewrite (fir::DivcOp divc, OpAdaptor adaptor,
598
+ mlir::ConversionPatternRewriter &rewriter) const override {
599
+ // TODO: Can we use a call to __divdc3 instead?
600
+ // Just generate inline code for now.
601
+ // given: (x + iy) / (x' + iy')
602
+ // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
603
+ mlir::Value a = adaptor.getOperands ()[0 ];
604
+ mlir::Value b = adaptor.getOperands ()[1 ];
605
+ auto loc = divc.getLoc ();
606
+ auto *ctx = divc.getContext ();
607
+ auto c0 = mlir::ArrayAttr::get (ctx, rewriter.getI32IntegerAttr (0 ));
608
+ auto c1 = mlir::ArrayAttr::get (ctx, rewriter.getI32IntegerAttr (1 ));
609
+ mlir::Type eleTy = convertType (getComplexEleTy (divc.getType ()));
610
+ mlir::Type ty = convertType (divc.getType ());
611
+ auto x0 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0);
612
+ auto y0 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1);
613
+ auto x1 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0);
614
+ auto y1 = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1);
615
+ auto xx = rewriter.create <mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
616
+ auto x1x1 = rewriter.create <mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
617
+ auto yx = rewriter.create <mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
618
+ auto xy = rewriter.create <mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
619
+ auto yy = rewriter.create <mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
620
+ auto y1y1 = rewriter.create <mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
621
+ auto d = rewriter.create <mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
622
+ auto rrn = rewriter.create <mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
623
+ auto rin = rewriter.create <mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
624
+ auto rr = rewriter.create <mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
625
+ auto ri = rewriter.create <mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
626
+ auto ra = rewriter.create <mlir::LLVM::UndefOp>(loc, ty);
627
+ auto r1 = rewriter.create <mlir::LLVM::InsertValueOp>(loc, ty, ra, rr, c0);
628
+ auto r0 = rewriter.create <mlir::LLVM::InsertValueOp>(loc, ty, r1, ri, c1);
629
+ rewriter.replaceOp (divc, r0.getResult ());
630
+ return success ();
631
+ }
632
+ };
633
+
634
+ // / Inlined complex negation
635
+ struct NegcOpConversion : public FIROpConversion <fir::NegcOp> {
636
+ using FIROpConversion::FIROpConversion;
637
+
638
+ mlir::LogicalResult
639
+ matchAndRewrite (fir::NegcOp neg, OpAdaptor adaptor,
640
+ mlir::ConversionPatternRewriter &rewriter) const override {
641
+ // given: -(x + iy)
642
+ // result: -x - iy
643
+ auto *ctxt = neg.getContext ();
644
+ auto eleTy = convertType (getComplexEleTy (neg.getType ()));
645
+ auto ty = convertType (neg.getType ());
646
+ auto loc = neg.getLoc ();
647
+ mlir::Value o0 = adaptor.getOperands ()[0 ];
648
+ auto c0 = mlir::ArrayAttr::get (ctxt, rewriter.getI32IntegerAttr (0 ));
649
+ auto c1 = mlir::ArrayAttr::get (ctxt, rewriter.getI32IntegerAttr (1 ));
650
+ auto rp = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, o0, c0);
651
+ auto ip = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, eleTy, o0, c1);
652
+ auto nrp = rewriter.create <mlir::LLVM::FNegOp>(loc, eleTy, rp);
653
+ auto nip = rewriter.create <mlir::LLVM::FNegOp>(loc, eleTy, ip);
654
+ auto r = rewriter.create <mlir::LLVM::InsertValueOp>(loc, ty, o0, nrp, c0);
655
+ rewriter.replaceOpWithNewOp <mlir::LLVM::InsertValueOp>(neg, ty, r, nip, c1);
656
+ return success ();
657
+ }
658
+ };
659
+
490
660
} // namespace
491
661
492
662
namespace {
@@ -501,15 +671,21 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
501
671
mlir::ModuleOp getModule () { return getOperation (); }
502
672
503
673
void runOnOperation () override final {
674
+ auto mod = getModule ();
675
+ if (!forcedTargetTriple.empty ()) {
676
+ fir::setTargetTriple (mod, forcedTargetTriple);
677
+ }
678
+
504
679
auto *context = getModule ().getContext ();
505
680
fir::LLVMTypeConverter typeConverter{getModule ()};
506
681
mlir::OwningRewritePatternList pattern (context);
507
- pattern.insert <
508
- AddrOfOpConversion, CallOpConversion, ExtractValueOpConversion,
509
- HasValueOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
510
- InsertValueOpConversion, SelectOpConversion, SelectRankOpConversion,
511
- UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>(
512
- typeConverter);
682
+ pattern.insert <AddcOpConversion, AddrOfOpConversion, CallOpConversion,
683
+ DivcOpConversion, ExtractValueOpConversion,
684
+ HasValueOpConversion, GlobalOpConversion,
685
+ InsertOnRangeOpConversion, InsertValueOpConversion,
686
+ NegcOpConversion, MulcOpConversion, SelectOpConversion,
687
+ SelectRankOpConversion, SubcOpConversion, UndefOpConversion,
688
+ UnreachableOpConversion, ZeroOpConversion>(typeConverter);
513
689
mlir::populateStdToLLVMConversionPatterns (typeConverter, pattern);
514
690
mlir::arith::populateArithmeticToLLVMConversionPatterns (typeConverter,
515
691
pattern);
0 commit comments