|
12 | 12 | ///
|
13 | 13 | /// A location to place directive utilities shared across multiple lowering
|
14 | 14 | /// files, e.g. utilities shared in OpenMP and OpenACC. The header file can
|
15 |
| -/// be used for both declarations and templated/inline implementations. |
| 15 | +/// be used for both declarations and templated/inline implementations |
16 | 16 | //===----------------------------------------------------------------------===//
|
17 | 17 |
|
18 | 18 | #ifndef FORTRAN_LOWER_DIRECTIVES_COMMON_H
|
19 | 19 | #define FORTRAN_LOWER_DIRECTIVES_COMMON_H
|
20 | 20 |
|
21 | 21 | #include "flang/Common/idioms.h"
|
| 22 | +#include "flang/Evaluate/tools.h" |
| 23 | +#include "flang/Lower/AbstractConverter.h" |
22 | 24 | #include "flang/Lower/Bridge.h"
|
23 | 25 | #include "flang/Lower/ConvertExpr.h"
|
24 | 26 | #include "flang/Lower/ConvertVariable.h"
|
25 | 27 | #include "flang/Lower/OpenACC.h"
|
26 | 28 | #include "flang/Lower/OpenMP.h"
|
27 | 29 | #include "flang/Lower/PFTBuilder.h"
|
28 | 30 | #include "flang/Lower/StatementContext.h"
|
| 31 | +#include "flang/Lower/Support/Utils.h" |
29 | 32 | #include "flang/Optimizer/Builder/BoxValue.h"
|
30 | 33 | #include "flang/Optimizer/Builder/FIRBuilder.h"
|
31 | 34 | #include "flang/Optimizer/Builder/Todo.h"
|
|
36 | 39 | #include "mlir/Dialect/OpenACC/OpenACC.h"
|
37 | 40 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
38 | 41 | #include "mlir/Dialect/SCF/IR/SCF.h"
|
| 42 | +#include "mlir/IR/Value.h" |
39 | 43 | #include "llvm/Frontend/OpenMP/OMPConstants.h"
|
| 44 | +#include <list> |
40 | 45 | #include <type_traits>
|
41 | 46 |
|
42 | 47 | namespace Fortran {
|
@@ -611,6 +616,309 @@ void createEmptyRegionBlocks(
|
611 | 616 | }
|
612 | 617 | }
|
613 | 618 |
|
| 619 | +inline mlir::Value |
| 620 | +getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter, |
| 621 | + fir::FirOpBuilder &builder, |
| 622 | + Fortran::lower::SymbolRef sym, mlir::Location loc) { |
| 623 | + mlir::Value symAddr = converter.getSymbolAddress(sym); |
| 624 | + // TODO: Might need revisiting to handle for non-shared clauses |
| 625 | + if (!symAddr) { |
| 626 | + if (const auto *details = |
| 627 | + sym->detailsIf<Fortran::semantics::HostAssocDetails>()) |
| 628 | + symAddr = converter.getSymbolAddress(details->symbol()); |
| 629 | + } |
| 630 | + |
| 631 | + if (!symAddr) |
| 632 | + llvm::report_fatal_error("could not retrieve symbol address"); |
| 633 | + |
| 634 | + if (auto boxTy = |
| 635 | + fir::unwrapRefType(symAddr.getType()).dyn_cast<fir::BaseBoxType>()) { |
| 636 | + if (boxTy.getEleTy().isa<fir::RecordType>()) |
| 637 | + TODO(loc, "derived type"); |
| 638 | + |
| 639 | + // Load the box when baseAddr is a `fir.ref<fir.box<T>>` or a |
| 640 | + // `fir.ref<fir.class<T>>` type. |
| 641 | + if (symAddr.getType().isa<fir::ReferenceType>()) |
| 642 | + return builder.create<fir::LoadOp>(loc, symAddr); |
| 643 | + } |
| 644 | + return symAddr; |
| 645 | +} |
| 646 | + |
| 647 | +/// Generate the bounds operation from the descriptor information. |
| 648 | +template <typename BoundsOp, typename BoundsType> |
| 649 | +llvm::SmallVector<mlir::Value> |
| 650 | +genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc, |
| 651 | + Fortran::lower::AbstractConverter &converter, |
| 652 | + fir::ExtendedValue dataExv, mlir::Value box) { |
| 653 | + llvm::SmallVector<mlir::Value> bounds; |
| 654 | + mlir::Type idxTy = builder.getIndexType(); |
| 655 | + mlir::Type boundTy = builder.getType<BoundsType>(); |
| 656 | + mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
| 657 | + assert(box.getType().isa<fir::BaseBoxType>() && |
| 658 | + "expect fir.box or fir.class"); |
| 659 | + for (unsigned dim = 0; dim < dataExv.rank(); ++dim) { |
| 660 | + mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim); |
| 661 | + mlir::Value baseLb = |
| 662 | + fir::factory::readLowerBound(builder, loc, dataExv, dim, one); |
| 663 | + auto dimInfo = |
| 664 | + builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, d); |
| 665 | + mlir::Value lb = builder.createIntegerConstant(loc, idxTy, 0); |
| 666 | + mlir::Value ub = |
| 667 | + builder.create<mlir::arith::SubIOp>(loc, dimInfo.getExtent(), one); |
| 668 | + mlir::Value bound = |
| 669 | + builder.create<BoundsOp>(loc, boundTy, lb, ub, mlir::Value(), |
| 670 | + dimInfo.getByteStride(), true, baseLb); |
| 671 | + bounds.push_back(bound); |
| 672 | + } |
| 673 | + return bounds; |
| 674 | +} |
| 675 | + |
| 676 | +/// Generate bounds operation for base array without any subscripts |
| 677 | +/// provided. |
| 678 | +template <typename BoundsOp, typename BoundsType> |
| 679 | +llvm::SmallVector<mlir::Value> |
| 680 | +genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, |
| 681 | + Fortran::lower::AbstractConverter &converter, |
| 682 | + fir::ExtendedValue dataExv, mlir::Value baseAddr) { |
| 683 | + mlir::Type idxTy = builder.getIndexType(); |
| 684 | + mlir::Type boundTy = builder.getType<BoundsType>(); |
| 685 | + llvm::SmallVector<mlir::Value> bounds; |
| 686 | + |
| 687 | + if (dataExv.rank() == 0) |
| 688 | + return bounds; |
| 689 | + |
| 690 | + mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
| 691 | + for (std::size_t dim = 0; dim < dataExv.rank(); ++dim) { |
| 692 | + mlir::Value baseLb = |
| 693 | + fir::factory::readLowerBound(builder, loc, dataExv, dim, one); |
| 694 | + mlir::Value ext = fir::factory::readExtent(builder, loc, dataExv, dim); |
| 695 | + mlir::Value lb = builder.createIntegerConstant(loc, idxTy, 0); |
| 696 | + |
| 697 | + // ub = extent - 1 |
| 698 | + mlir::Value ub = builder.create<mlir::arith::SubIOp>(loc, ext, one); |
| 699 | + mlir::Value bound = |
| 700 | + builder.create<BoundsOp>(loc, boundTy, lb, ub, ext, one, false, baseLb); |
| 701 | + bounds.push_back(bound); |
| 702 | + } |
| 703 | + return bounds; |
| 704 | +} |
| 705 | + |
| 706 | +/// Generate bounds operations for an array section when subscripts are |
| 707 | +/// provided. |
| 708 | +template <typename BoundsOp, typename BoundsType> |
| 709 | +llvm::SmallVector<mlir::Value> |
| 710 | +genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, |
| 711 | + Fortran::lower::AbstractConverter &converter, |
| 712 | + Fortran::lower::StatementContext &stmtCtx, |
| 713 | + const std::list<Fortran::parser::SectionSubscript> &subscripts, |
| 714 | + std::stringstream &asFortran, fir::ExtendedValue &dataExv, |
| 715 | + mlir::Value baseAddr) { |
| 716 | + int dimension = 0; |
| 717 | + mlir::Type idxTy = builder.getIndexType(); |
| 718 | + mlir::Type boundTy = builder.getType<BoundsType>(); |
| 719 | + llvm::SmallVector<mlir::Value> bounds; |
| 720 | + |
| 721 | + mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); |
| 722 | + mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
| 723 | + for (const auto &subscript : subscripts) { |
| 724 | + if (const auto *triplet{ |
| 725 | + std::get_if<Fortran::parser::SubscriptTriplet>(&subscript.u)}) { |
| 726 | + if (dimension != 0) |
| 727 | + asFortran << ','; |
| 728 | + mlir::Value lbound, ubound, extent; |
| 729 | + std::optional<std::int64_t> lval, uval; |
| 730 | + mlir::Value baseLb = |
| 731 | + fir::factory::readLowerBound(builder, loc, dataExv, dimension, one); |
| 732 | + bool defaultLb = baseLb == one; |
| 733 | + mlir::Value stride = one; |
| 734 | + bool strideInBytes = false; |
| 735 | + |
| 736 | + if (fir::unwrapRefType(baseAddr.getType()).isa<fir::BaseBoxType>()) { |
| 737 | + mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension); |
| 738 | + auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, |
| 739 | + baseAddr, d); |
| 740 | + stride = dimInfo.getByteStride(); |
| 741 | + strideInBytes = true; |
| 742 | + } |
| 743 | + |
| 744 | + const auto &lower{std::get<0>(triplet->t)}; |
| 745 | + if (lower) { |
| 746 | + lval = Fortran::semantics::GetIntValue(lower); |
| 747 | + if (lval) { |
| 748 | + if (defaultLb) { |
| 749 | + lbound = builder.createIntegerConstant(loc, idxTy, *lval - 1); |
| 750 | + } else { |
| 751 | + mlir::Value lb = builder.createIntegerConstant(loc, idxTy, *lval); |
| 752 | + lbound = builder.create<mlir::arith::SubIOp>(loc, lb, baseLb); |
| 753 | + } |
| 754 | + asFortran << *lval; |
| 755 | + } else { |
| 756 | + const Fortran::lower::SomeExpr *lexpr = |
| 757 | + Fortran::semantics::GetExpr(*lower); |
| 758 | + mlir::Value lb = |
| 759 | + fir::getBase(converter.genExprValue(loc, *lexpr, stmtCtx)); |
| 760 | + lb = builder.createConvert(loc, baseLb.getType(), lb); |
| 761 | + lbound = builder.create<mlir::arith::SubIOp>(loc, lb, baseLb); |
| 762 | + asFortran << lexpr->AsFortran(); |
| 763 | + } |
| 764 | + } else { |
| 765 | + lbound = defaultLb ? zero : baseLb; |
| 766 | + } |
| 767 | + asFortran << ':'; |
| 768 | + const auto &upper{std::get<1>(triplet->t)}; |
| 769 | + if (upper) { |
| 770 | + uval = Fortran::semantics::GetIntValue(upper); |
| 771 | + if (uval) { |
| 772 | + if (defaultLb) { |
| 773 | + ubound = builder.createIntegerConstant(loc, idxTy, *uval - 1); |
| 774 | + } else { |
| 775 | + mlir::Value ub = builder.createIntegerConstant(loc, idxTy, *uval); |
| 776 | + ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb); |
| 777 | + } |
| 778 | + asFortran << *uval; |
| 779 | + } else { |
| 780 | + const Fortran::lower::SomeExpr *uexpr = |
| 781 | + Fortran::semantics::GetExpr(*upper); |
| 782 | + mlir::Value ub = |
| 783 | + fir::getBase(converter.genExprValue(loc, *uexpr, stmtCtx)); |
| 784 | + ub = builder.createConvert(loc, baseLb.getType(), ub); |
| 785 | + ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb); |
| 786 | + asFortran << uexpr->AsFortran(); |
| 787 | + } |
| 788 | + } |
| 789 | + if (lower && upper) { |
| 790 | + if (lval && uval && *uval < *lval) { |
| 791 | + mlir::emitError(loc, "zero sized array section"); |
| 792 | + break; |
| 793 | + } else if (std::get<2>(triplet->t)) { |
| 794 | + const auto &strideExpr{std::get<2>(triplet->t)}; |
| 795 | + if (strideExpr) { |
| 796 | + mlir::emitError(loc, "stride cannot be specified on " |
| 797 | + "an OpenMP array section"); |
| 798 | + break; |
| 799 | + } |
| 800 | + } |
| 801 | + } |
| 802 | + // ub = baseLb + extent - 1 |
| 803 | + if (!ubound) { |
| 804 | + mlir::Value ext = |
| 805 | + fir::factory::readExtent(builder, loc, dataExv, dimension); |
| 806 | + mlir::Value lbExt = |
| 807 | + builder.create<mlir::arith::AddIOp>(loc, ext, baseLb); |
| 808 | + ubound = builder.create<mlir::arith::SubIOp>(loc, lbExt, one); |
| 809 | + } |
| 810 | + mlir::Value bound = builder.create<BoundsOp>( |
| 811 | + loc, boundTy, lbound, ubound, extent, stride, strideInBytes, baseLb); |
| 812 | + bounds.push_back(bound); |
| 813 | + ++dimension; |
| 814 | + } |
| 815 | + } |
| 816 | + return bounds; |
| 817 | +} |
| 818 | + |
| 819 | +template <typename ObjectType, typename BoundsOp, typename BoundsType> |
| 820 | +mlir::Value gatherDataOperandAddrAndBounds( |
| 821 | + Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder, |
| 822 | + Fortran::semantics::SemanticsContext &semanticsContext, |
| 823 | + Fortran::lower::StatementContext &stmtCtx, const ObjectType &object, |
| 824 | + mlir::Location operandLocation, std::stringstream &asFortran, |
| 825 | + llvm::SmallVector<mlir::Value> &bounds) { |
| 826 | + mlir::Value baseAddr; |
| 827 | + |
| 828 | + std::visit( |
| 829 | + Fortran::common::visitors{ |
| 830 | + [&](const Fortran::parser::Designator &designator) { |
| 831 | + if (auto expr{Fortran::semantics::AnalyzeExpr(semanticsContext, |
| 832 | + designator)}) { |
| 833 | + if ((*expr).Rank() > 0 && |
| 834 | + Fortran::parser::Unwrap<Fortran::parser::ArrayElement>( |
| 835 | + designator)) { |
| 836 | + const auto *arrayElement = |
| 837 | + Fortran::parser::Unwrap<Fortran::parser::ArrayElement>( |
| 838 | + designator); |
| 839 | + const auto *dataRef = |
| 840 | + std::get_if<Fortran::parser::DataRef>(&designator.u); |
| 841 | + fir::ExtendedValue dataExv; |
| 842 | + if (Fortran::parser::Unwrap< |
| 843 | + Fortran::parser::StructureComponent>( |
| 844 | + arrayElement->base)) { |
| 845 | + auto exprBase = Fortran::semantics::AnalyzeExpr( |
| 846 | + semanticsContext, arrayElement->base); |
| 847 | + dataExv = converter.genExprAddr(operandLocation, *exprBase, |
| 848 | + stmtCtx); |
| 849 | + baseAddr = fir::getBase(dataExv); |
| 850 | + asFortran << (*exprBase).AsFortran(); |
| 851 | + } else { |
| 852 | + const Fortran::parser::Name &name = |
| 853 | + Fortran::parser::GetLastName(*dataRef); |
| 854 | + baseAddr = getDataOperandBaseAddr( |
| 855 | + converter, builder, *name.symbol, operandLocation); |
| 856 | + dataExv = converter.getSymbolExtendedValue(*name.symbol); |
| 857 | + asFortran << name.ToString(); |
| 858 | + } |
| 859 | + |
| 860 | + if (!arrayElement->subscripts.empty()) { |
| 861 | + asFortran << '('; |
| 862 | + bounds = genBoundsOps<BoundsType, BoundsOp>( |
| 863 | + builder, operandLocation, converter, stmtCtx, |
| 864 | + arrayElement->subscripts, asFortran, dataExv, baseAddr); |
| 865 | + } |
| 866 | + asFortran << ')'; |
| 867 | + } else if (Fortran::parser::Unwrap< |
| 868 | + Fortran::parser::StructureComponent>(designator)) { |
| 869 | + fir::ExtendedValue compExv = |
| 870 | + converter.genExprAddr(operandLocation, *expr, stmtCtx); |
| 871 | + baseAddr = fir::getBase(compExv); |
| 872 | + if (fir::unwrapRefType(baseAddr.getType()) |
| 873 | + .isa<fir::SequenceType>()) |
| 874 | + bounds = genBaseBoundsOps<BoundsType, BoundsOp>( |
| 875 | + builder, operandLocation, converter, compExv, baseAddr); |
| 876 | + asFortran << (*expr).AsFortran(); |
| 877 | + |
| 878 | + // If the component is an allocatable or pointer the result of |
| 879 | + // genExprAddr will be the result of a fir.box_addr operation. |
| 880 | + // Retrieve the box so we handle it like other descriptor. |
| 881 | + if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>( |
| 882 | + baseAddr.getDefiningOp())) { |
| 883 | + baseAddr = boxAddrOp.getVal(); |
| 884 | + bounds = genBoundsOpsFromBox<BoundsType, BoundsOp>( |
| 885 | + builder, operandLocation, converter, compExv, baseAddr); |
| 886 | + } |
| 887 | + } else { |
| 888 | + // Scalar or full array. |
| 889 | + if (const auto *dataRef{ |
| 890 | + std::get_if<Fortran::parser::DataRef>(&designator.u)}) { |
| 891 | + const Fortran::parser::Name &name = |
| 892 | + Fortran::parser::GetLastName(*dataRef); |
| 893 | + fir::ExtendedValue dataExv = |
| 894 | + converter.getSymbolExtendedValue(*name.symbol); |
| 895 | + baseAddr = getDataOperandBaseAddr( |
| 896 | + converter, builder, *name.symbol, operandLocation); |
| 897 | + if (fir::unwrapRefType(baseAddr.getType()) |
| 898 | + .isa<fir::BaseBoxType>()) |
| 899 | + bounds = genBoundsOpsFromBox<BoundsType, BoundsOp>( |
| 900 | + builder, operandLocation, converter, dataExv, baseAddr); |
| 901 | + if (fir::unwrapRefType(baseAddr.getType()) |
| 902 | + .isa<fir::SequenceType>()) |
| 903 | + bounds = genBaseBoundsOps<BoundsType, BoundsOp>( |
| 904 | + builder, operandLocation, converter, dataExv, baseAddr); |
| 905 | + asFortran << name.ToString(); |
| 906 | + } else { // Unsupported |
| 907 | + llvm::report_fatal_error( |
| 908 | + "Unsupported type of OpenACC operand"); |
| 909 | + } |
| 910 | + } |
| 911 | + } |
| 912 | + }, |
| 913 | + [&](const Fortran::parser::Name &name) { |
| 914 | + baseAddr = getDataOperandBaseAddr(converter, builder, *name.symbol, |
| 915 | + operandLocation); |
| 916 | + asFortran << name.ToString(); |
| 917 | + }}, |
| 918 | + object.u); |
| 919 | + return baseAddr; |
| 920 | +} |
| 921 | + |
614 | 922 | } // namespace lower
|
615 | 923 | } // namespace Fortran
|
616 | 924 |
|
|
0 commit comments