Skip to content

[Flang][MLIR] Add !$omp unroll and omp.unroll_heuristic #144785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 158 additions & 6 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2128,6 +2128,161 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return loopOp;
}

/// Lower the DO loop nested in \p eval to an `omp.canonical_loop` operation.
///
/// The loop bounds and step are evaluated before the loop, converted to the
/// type of the loop variable, and used to compute the trip count that becomes
/// the range of the canonical loop. Inside the loop region the value of the
/// user's loop variable is recomputed from the logical iteration number as
/// `lb + iteration * step`.
///
/// \param converter  Lowering bridge used to generate expressions and ops.
/// \param symTable   Symbol map forwarded to the region body generator.
/// \param semaCtx    Semantics context forwarded to the region body generator.
/// \param eval       Evaluation of the OpenMP construct; its first nested
///                   evaluation is expected to be the associated DO construct.
/// \param loc        Location attached to all generated operations.
/// \param queue      Construct queue the directive belongs to.
/// \param item       Position of the directive within \p queue.
/// \param ivs        Induction variables of the loop nest; exactly one is
///                   supported at the moment.
/// \param directive  Directive on whose behalf the loop is generated.
/// \param dsp        Data-sharing processor handed to the body generator.
/// \returns The created operation. The builder's insertion point is left
///          right after it.
static mlir::omp::CanonicalLoopOp
genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                   semantics::SemanticsContext &semaCtx,
                   lower::pft::Evaluation &eval, mlir::Location loc,
                   const ConstructQueue &queue,
                   ConstructQueue::const_iterator item,
                   llvm::ArrayRef<const semantics::Symbol *> ivs,
                   llvm::omp::Directive directive, DataSharingProcessor &dsp) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  assert(ivs.size() == 1 && "Nested loops not yet implemented");
  const semantics::Symbol *iv = ivs[0];

  // The directive must be associated with a DO construct. Check the result of
  // getIf before dereferencing it, like for the other lookups below.
  auto &nestedEval = eval.getFirstNestedEvaluation();
  const auto *doConstruct = nestedEval.getIf<parser::DoConstruct>();
  assert(doConstruct && "Expected do construct in the nested evaluation");
  if (doConstruct->IsDoConcurrent()) {
    TODO(loc, "Do Concurrent in unroll construct");
  }

  // Get the loop bounds (and increment)
  auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
  auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
  assert(doStmt && "Expected do loop to be in the nested evaluation");
  auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
  assert(loopControl.has_value());
  auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
  assert(bounds && "Expected bounds for canonical loop");
  lower::StatementContext stmtCtx;
  mlir::Value loopLBVar = fir::getBase(
      converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
  mlir::Value loopUBVar = fir::getBase(
      converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
  mlir::Value loopStepVar = [&]() -> mlir::Value {
    if (bounds->step)
      return fir::getBase(
          converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
    // If `step` is not present, assume it is `1`.
    return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(),
                                              1);
  }();

  // Get the integer kind for the loop variable and cast the loop bounds
  size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
  mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
  loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
  loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
  loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);

  // Start lowering
  mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
  mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
  mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>(
      loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);

  // Ensure we are counting upwards. If not, negate step and swap lb and ub.
  mlir::Value negStep =
      firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar);
  mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>(
      loc, isDownwards, negStep, loopStepVar);
  mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>(
      loc, isDownwards, loopUBVar, loopLBVar);
  mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>(
      loc, isDownwards, loopLBVar, loopUBVar);

  // Compute the trip count assuming lb <= ub. This guarantees that the result
  // is non-negative and we can use unsigned arithmetic.
  // NOTE(review): `nuw` states that ub >= lb when both are read as unsigned
  // values. For a signed range that crosses zero (e.g. lb = -3, ub = 5) that
  // does not seem to hold even though lb <= ub -- confirm the flag is intended
  // before relying on it.
  mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>(
      loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
  mlir::Value tcMinusOne =
      firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr);
  mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>(
      loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);

  // Fall back to 0 if lb > ub
  mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>(
      loc, mlir::arith::CmpIPredicate::slt, ub, lb);
  mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>(
      loc, isZeroTC, zero, tcIfLooping);

  // Create the CLI handle.
  auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
  mlir::Value cli = newcli.getResult();

  // Region-entry callback: builds the entry block of the loop region and
  // materializes the user-visible loop variable in it.
  auto ivCallback = [&](mlir::Operation *op)
      -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
    mlir::Region &region = op->getRegion(0);

    // Create the op's region skeleton (BB taking the iv as argument)
    firOpBuilder.createBlock(&region, {}, {loopVarType}, {loc});

    // Compute the value of the loop variable from the logical iteration number.
    // The original (unswapped) bound and step are used here, so the variable
    // takes the values the source loop specifies.
    mlir::Value natIterNum = fir::getBase(region.front().getArgument(0));
    mlir::Value scaled =
        firOpBuilder.create<mlir::arith::MulIOp>(loc, natIterNum, loopStepVar);
    mlir::Value userVal =
        firOpBuilder.create<mlir::arith::AddIOp>(loc, loopLBVar, scaled);

    // The argument is not currently in memory, so make a temporary for the
    // argument, and store it there, then bind that location to the argument.
    mlir::Operation *storeOp =
        createAndSetPrivatizedLoopVar(converter, loc, userVal, iv);

    firOpBuilder.setInsertionPointAfter(storeOp);
    return {iv};
  };

  // Create the omp.canonical_loop operation
  auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval,
                        directive)
          .setClauses(&item->clauses)
          .setDataSharingProcessor(&dsp)
          .setGenRegionEntryCb(ivCallback),
      queue, item, tripcount, cli);

  firOpBuilder.setInsertionPointAfter(canonLoop);
  return canonLoop;
}

/// Lower an `!$omp unroll` construct: emit the associated loop as an
/// `omp.canonical_loop` and attach an `omp.unroll_heuristic` operation to its
/// canonical loop info handle.
///
/// The `partial` and `full` clauses are reported as not yet implemented.
static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
                        Fortran::lower::SymMap &symTable,
                        lower::StatementContext &stmtCtx,
                        Fortran::semantics::SemanticsContext &semaCtx,
                        Fortran::lower::pft::Evaluation &eval,
                        mlir::Location loc, const ConstructQueue &queue,
                        ConstructQueue::const_iterator item) {
  // Gather the induction variables of the associated loop.
  llvm::SmallVector<const semantics::Symbol *> inductionVars;
  mlir::omp::LoopRelatedClauseOps loopClauseOps;
  collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopClauseOps,
                         inductionVars);

  // Clauses of the unroll directive are not implemented yet.
  ClauseProcessor clauseProcessor(converter, semaCtx, item->clauses);
  clauseProcessor.processTODO<clause::Partial, clause::Full>(
      loc, llvm::omp::Directive::OMPD_unroll);

  // Unroll does not support data-sharing clauses, but running the first
  // processing step is still required to fill the symbol table.
  DataSharingProcessor dataSharing(converter, semaCtx, item->clauses, eval,
                                   /*shouldCollectPreDeterminedSymbols=*/true,
                                   /*useDelayedPrivatization=*/false, symTable);
  dataSharing.processStep1();

  // Generate the loop the directive is associated with.
  mlir::omp::CanonicalLoopOp loopOp = genCanonicalLoopOp(
      converter, symTable, semaCtx, eval, loc, queue, item, inductionVars,
      llvm::omp::Directive::OMPD_unroll, dataSharing);

  // Request heuristic unrolling of that loop through its CLI handle.
  auto cliHandle = loopOp.getCli();
  converter.getFirOpBuilder().create<mlir::omp::UnrollHeuristicOp>(loc,
                                                                  cliHandle);
}

static mlir::omp::MaskedOp
genMaskedOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
lower::StatementContext &stmtCtx,
Expand Down Expand Up @@ -3516,12 +3671,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
item);
break;
case llvm::omp::Directive::OMPD_tile:
case llvm::omp::Directive::OMPD_unroll: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
TODO(loc, "Unhandled loop directive (" +
llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
}
case llvm::omp::Directive::OMPD_unroll:
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about OMPD_tile?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR intends to only implement one of the simplest directives for now: omp unroll. As outlined in the summary, full feature support requires significantly more work. I was working on it, but it was already taking forever.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry my comment was pretty ambiguous.

The old case handling for tile (to generate the "unhandled loop directive" NYI error) has been removed. I expect that will now lead to an assertion failure in the default case.

// case llvm::omp::Directive::OMPD_workdistribute:
case llvm::omp::Directive::OMPD_workshare:
newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
Expand Down
39 changes: 39 additions & 0 deletions flang/test/Lower/OpenMP/unroll-heuristic01.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 -o - %s 2>&1 | FileCheck %s


! Lowering input: a single DO loop whose lower bound, upper bound and
! increment are all dummy arguments, placed under a clause-less unroll
! directive.
subroutine omp_unroll_heuristic01(lb, ub, inc)
integer res, i, lb, ub, inc

! No clause is given on the directive.
!$omp unroll
do i = lb, ub, inc
res = i
end do
!$omp end unroll

end subroutine omp_unroll_heuristic01


!CHECK-LABEL: func.func @_QPomp_unroll_heuristic01(
!CHECK: %c0_i32 = arith.constant 0 : i32
!CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
!CHECK-NEXT: %13 = arith.cmpi slt, %12, %c0_i32 : i32
!CHECK-NEXT: %14 = arith.subi %c0_i32, %12 : i32
!CHECK-NEXT: %15 = arith.select %13, %14, %12 : i32
!CHECK-NEXT: %16 = arith.select %13, %11, %10 : i32
!CHECK-NEXT: %17 = arith.select %13, %10, %11 : i32
!CHECK-NEXT: %18 = arith.subi %17, %16 overflow<nuw> : i32
!CHECK-NEXT: %19 = arith.divui %18, %15 : i32
!CHECK-NEXT: %20 = arith.addi %19, %c1_i32 overflow<nuw> : i32
!CHECK-NEXT: %21 = arith.cmpi slt, %17, %16 : i32
!CHECK-NEXT: %22 = arith.select %21, %c0_i32, %20 : i32
!CHECK-NEXT: %canonloop_s0 = omp.new_cli
!CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%22) {
!CHECK-NEXT: %23 = arith.muli %iv, %12 : i32
!CHECK-NEXT: %24 = arith.addi %10, %23 : i32
!CHECK-NEXT: hlfir.assign %24 to %9#0 : i32, !fir.ref<i32>
!CHECK-NEXT: %25 = fir.load %9#0 : !fir.ref<i32>
!CHECK-NEXT: hlfir.assign %25 to %6#0 : i32, !fir.ref<i32>
!CHECK-NEXT: omp.terminator
!CHECK-NEXT: }
!CHECK-NEXT: omp.unroll_heuristic(%canonloop_s0)
!CHECK-NEXT: return
70 changes: 70 additions & 0 deletions flang/test/Lower/OpenMP/unroll-heuristic02.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 -o - %s 2>&1 | FileCheck %s


! Lowering input: two nested DO loops, each under its own clause-less unroll
! directive. All bounds and increments are dummy arguments.
subroutine omp_unroll_heuristic_nested02(outer_lb, outer_ub, outer_inc, inner_lb, inner_ub, inner_inc)
integer res, i, j, inner_lb, inner_ub, inner_inc, outer_lb, outer_ub, outer_inc

! Outer directive, applied to the loop over i.
!$omp unroll
do i = outer_lb, outer_ub, outer_inc
! Inner directive, applied to the loop over j.
!$omp unroll
do j = inner_lb, inner_ub, inner_inc
res = i + j
end do
!$omp end unroll
end do
!$omp end unroll

end subroutine omp_unroll_heuristic_nested02


!CHECK-LABEL: func.func @_QPomp_unroll_heuristic_nested02(%arg0: !fir.ref<i32> {fir.bindc_name = "outer_lb"}, %arg1: !fir.ref<i32> {fir.bindc_name = "outer_ub"}, %arg2: !fir.ref<i32> {fir.bindc_name = "outer_inc"}, %arg3: !fir.ref<i32> {fir.bindc_name = "inner_lb"}, %arg4: !fir.ref<i32> {fir.bindc_name = "inner_ub"}, %arg5: !fir.ref<i32> {fir.bindc_name = "inner_inc"}) {
!CHECK: %c0_i32 = arith.constant 0 : i32
!CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
!CHECK-NEXT: %18 = arith.cmpi slt, %17, %c0_i32 : i32
!CHECK-NEXT: %19 = arith.subi %c0_i32, %17 : i32
!CHECK-NEXT: %20 = arith.select %18, %19, %17 : i32
!CHECK-NEXT: %21 = arith.select %18, %16, %15 : i32
!CHECK-NEXT: %22 = arith.select %18, %15, %16 : i32
!CHECK-NEXT: %23 = arith.subi %22, %21 overflow<nuw> : i32
!CHECK-NEXT: %24 = arith.divui %23, %20 : i32
!CHECK-NEXT: %25 = arith.addi %24, %c1_i32 overflow<nuw> : i32
!CHECK-NEXT: %26 = arith.cmpi slt, %22, %21 : i32
!CHECK-NEXT: %27 = arith.select %26, %c0_i32, %25 : i32
!CHECK-NEXT: %canonloop_s0 = omp.new_cli
!CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%27) {
!CHECK-NEXT: %28 = arith.muli %iv, %17 : i32
!CHECK-NEXT: %29 = arith.addi %15, %28 : i32
!CHECK-NEXT: hlfir.assign %29 to %14#0 : i32, !fir.ref<i32>
!CHECK-NEXT: %30 = fir.alloca i32 {bindc_name = "j", pinned, uniq_name = "_QFomp_unroll_heuristic_nested02Ej"}
!CHECK-NEXT: %31:2 = hlfir.declare %30 {uniq_name = "_QFomp_unroll_heuristic_nested02Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK-NEXT: %32 = fir.load %4#0 : !fir.ref<i32>
!CHECK-NEXT: %33 = fir.load %5#0 : !fir.ref<i32>
!CHECK-NEXT: %34 = fir.load %3#0 : !fir.ref<i32>
!CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32
!CHECK-NEXT: %c1_i32_1 = arith.constant 1 : i32
!CHECK-NEXT: %35 = arith.cmpi slt, %34, %c0_i32_0 : i32
!CHECK-NEXT: %36 = arith.subi %c0_i32_0, %34 : i32
!CHECK-NEXT: %37 = arith.select %35, %36, %34 : i32
!CHECK-NEXT: %38 = arith.select %35, %33, %32 : i32
!CHECK-NEXT: %39 = arith.select %35, %32, %33 : i32
!CHECK-NEXT: %40 = arith.subi %39, %38 overflow<nuw> : i32
!CHECK-NEXT: %41 = arith.divui %40, %37 : i32
!CHECK-NEXT: %42 = arith.addi %41, %c1_i32_1 overflow<nuw> : i32
!CHECK-NEXT: %43 = arith.cmpi slt, %39, %38 : i32
!CHECK-NEXT: %44 = arith.select %43, %c0_i32_0, %42 : i32
!CHECK-NEXT: %canonloop_s0_s0 = omp.new_cli
!CHECK-NEXT: omp.canonical_loop(%canonloop_s0_s0) %iv_2 : i32 in range(%44) {
!CHECK-NEXT: %45 = arith.muli %iv_2, %34 : i32
!CHECK-NEXT: %46 = arith.addi %32, %45 : i32
!CHECK-NEXT: hlfir.assign %46 to %31#0 : i32, !fir.ref<i32>
!CHECK-NEXT: %47 = fir.load %14#0 : !fir.ref<i32>
!CHECK-NEXT: %48 = fir.load %31#0 : !fir.ref<i32>
!CHECK-NEXT: %49 = arith.addi %47, %48 : i32
!CHECK-NEXT: hlfir.assign %49 to %12#0 : i32, !fir.ref<i32>
!CHECK-NEXT: omp.terminator
!CHECK-NEXT: }
!CHECK-NEXT: omp.unroll_heuristic(%canonloop_s0_s0)
!CHECK-NEXT: omp.terminator
!CHECK-NEXT: }
!CHECK-NEXT: omp.unroll_heuristic(%canonloop_s0)
!CHECK-NEXT: return
43 changes: 43 additions & 0 deletions flang/test/Parser/OpenMP/unroll-heuristic.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
! RUN: %flang_fc1 -fopenmp -fopenmp-version=51 %s -fdebug-unparse | FileCheck --check-prefix=UNPARSE %s
! RUN: %flang_fc1 -fopenmp -fopenmp-version=51 %s -fdebug-dump-parse-tree | FileCheck --check-prefix=PTREE %s

! Parser input: a clause-less unroll directive around a DO loop with constant
! bounds and no explicit increment.
subroutine openmp_parse_unroll_heuristic
integer i

! The loop body is a single call statement.
!$omp unroll
do i = 1, 100
call func(i)
end do
!$omp end unroll
END subroutine openmp_parse_unroll_heuristic


!UNPARSE: !$OMP UNROLL
!UNPARSE-NEXT: DO i=1_4,100_4
!UNPARSE-NEXT: CALL func(i)
!UNPARSE-NEXT: END DO
!UNPARSE-NEXT: !$OMP END UNROLL

!PTREE: OpenMPConstruct -> OpenMPLoopConstruct
!PTREE-NEXT: | OmpBeginLoopDirective
!PTREE-NEXT: | | OmpLoopDirective -> llvm::omp::Directive = unroll
!PTREE-NEXT: | | OmpClauseList ->
!PTREE-NEXT: | DoConstruct
!PTREE-NEXT: | | NonLabelDoStmt
!PTREE-NEXT: | | | LoopControl -> LoopBounds
!PTREE-NEXT: | | | | Scalar -> Name = 'i'
!PTREE-NEXT: | | | | Scalar -> Expr = '1_4'
!PTREE-NEXT: | | | | | LiteralConstant -> IntLiteralConstant = '1'
!PTREE-NEXT: | | | | Scalar -> Expr = '100_4'
!PTREE-NEXT: | | | | | LiteralConstant -> IntLiteralConstant = '100'
!PTREE-NEXT: | | Block
!PTREE-NEXT: | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> CallStmt = 'CALL func(i)'
!PTREE-NEXT: | | | | | | Call
!PTREE-NEXT: | | | | | ProcedureDesignator -> Name = 'func'
!PTREE-NEXT: | | | | | ActualArgSpec
!PTREE-NEXT: | | | | | | ActualArg -> Expr = 'i'
!PTREE-NEXT: | | | | | | | Designator -> DataRef -> Name = 'i'
!PTREE-NEXT: | | EndDoStmt ->
!PTREE-NEXT: | OmpEndLoopDirective
!PTREE-NEXT: | | OmpLoopDirective -> llvm::omp::Directive = unroll
!PTREE-NEXT: | | OmpClauseList ->
6 changes: 1 addition & 5 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,10 @@
#ifndef MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_
#define MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_

#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"

#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc"

#include "mlir/Dialect/OpenMP/OpenMPClauseOps.h.inc"

namespace mlir {
Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
Expand All @@ -24,6 +25,11 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"

namespace mlir::omp {
/// Find the omp.new_cli, generator, and consumer of a canonical loop info.
///
/// \param cli  The canonical loop info value to decode.
/// \returns A tuple holding an omp.new_cli operation and two operand
///          pointers; presumably the generator and the consumer of \p cli,
///          in that order (TODO confirm against the definition, which is not
///          part of this header).
std::tuple<NewCliOp, OpOperand *, OpOperand *> decodeCli(mlir::Value cli);
} // namespace mlir::omp

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc"

Expand All @@ -33,8 +39,6 @@

#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"

#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"

#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_DIALECT_OPENMP_OPENMPINTERFACES_H_

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
Expand Down
Loading
Loading