Skip to content

Commit 4be76ec

Browse files
committed
[WIP][flang][OpenMP] Experimental pass to map do concurrent to OMP
Adds a pass to map `do concurrent` to OpenMP worksharing constructs. For now, it only maps basic loops to `omp parallel do`. This is still a WIP with more work needed for testing and mapping more advanced loops.
1 parent 2642240 commit 4be76ec

File tree

6 files changed

+246
-1
lines changed

6 files changed

+246
-1
lines changed

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments,
7575
func.func @foo(%arg0: !fir.ref<!fir.array<?x?x!fir.char<1,?>>>, %arg1: !fir.ref<i64>) {
7676
%c10 = arith.constant 10 : index
7777
%c20 = arith.constant 20 : index
78-
%1 = fir.load %ag1 : fir.ref<i64>
78+
%1 = fir.load %arg1 : fir.ref<i64>
7979
%2 = fir.shape_shift %c10, %1, %c20, %1 : (index, index, index, index) -> !fir.shapeshift<2>
8080
%3 = hfir.declare %arg0(%2) typeparams %1 {uniq_name = "c"} (fir.ref<!fir.array<?x?x!fir.char<1,?>>>, fir.shapeshift<2>, index) -> (fir.box<!fir.array<?x?x!fir.char<1,?>>>, fir.ref<!fir.array<?x?x!fir.char<1,?>>>)
8181
// ... uses %3#0 as "c"

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ std::unique_ptr<mlir::Pass> createFunctionAttrPass();
9393
std::unique_ptr<mlir::Pass>
9494
createFunctionAttrPass(FunctionAttrTypes &functionAttr);
9595

96+
std::unique_ptr<mlir::Pass> createDoConcurrentConversionPass();
97+
9698
// declarative passes
9799
#define GEN_PASS_REGISTRATION
98100
#include "flang/Optimizer/Transforms/Passes.h.inc"

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,4 +370,24 @@ def FunctionAttr : Pass<"function-attr", "mlir::func::FuncOp"> {
370370
let constructor = "::fir::createFunctionAttrPass()";
371371
}
372372

373+
def DoConcurrentConversionPass : Pass<"fopenmp-do-concurrent-conversion", "mlir::func::FuncOp"> {
374+
let summary = "Map `DO CONCURRENT` loops to OpenMP worksharing loops.";
375+
376+
let description = [{ This is an experimental pass to map `DO CONCURRENT` loops
377+
to their corresponding equivalent OpenMP worksharing constructs.
378+
379+
For now the following is supported:
380+
- Mapping simple loops to `parallel do`.
381+
382+
Still TODO:
383+
- More extensive testing.
384+
- Mapping to `target teams distribute parallel do`.
385+
- Allowing the user to control mapping behavior: either to the host or
386+
target.
387+
}];
388+
389+
let constructor = "::fir::createDoConcurrentConversionPass()";
390+
let dependentDialects = ["mlir::omp::OpenMPDialect"];
391+
}
392+
373393
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_flang_library(FIRTransforms
2121
OMPMarkDeclareTarget.cpp
2222
VScaleAttr.cpp
2323
FunctionAttr.cpp
24+
DoConcurrentConversion.cpp
2425

2526
DEPENDS
2627
FIRDialect
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
//===- DoConcurrentConversion.cpp -- map `DO CONCURRENT` to OpenMP loops --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Optimizer/Dialect/FIRDialect.h"
10+
#include "flang/Optimizer/Dialect/FIROps.h"
11+
#include "flang/Optimizer/Dialect/FIRType.h"
12+
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
13+
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
14+
#include "flang/Optimizer/Transforms/Passes.h"
15+
#include "mlir/Dialect/Func/IR/FuncOps.h"
16+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
17+
#include "mlir/IR/Diagnostics.h"
18+
#include "mlir/IR/IRMapping.h"
19+
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Transforms/DialectConversion.h"
21+
22+
#include <memory>
23+
24+
namespace fir {
25+
#define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS
26+
#include "flang/Optimizer/Transforms/Passes.h.inc"
27+
} // namespace fir
28+
29+
#define DEBUG_TYPE "fopenmp-do-concurrent-conversion"
30+
31+
namespace {
32+
class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
33+
public:
34+
using mlir::OpConversionPattern<fir::DoLoopOp>::OpConversionPattern;
35+
36+
mlir::LogicalResult
37+
matchAndRewrite(fir::DoLoopOp doLoop, OpAdaptor adaptor,
38+
mlir::ConversionPatternRewriter &rewriter) const override {
39+
mlir::OpPrintingFlags flags;
40+
flags.printGenericOpForm();
41+
42+
mlir::omp::ParallelOp parallelOp =
43+
rewriter.create<mlir::omp::ParallelOp>(doLoop.getLoc());
44+
45+
rewriter.createBlock(&parallelOp.getRegion());
46+
mlir::Block &block = parallelOp.getRegion().back();
47+
48+
rewriter.setInsertionPointToEnd(&block);
49+
rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc());
50+
51+
rewriter.setInsertionPointToStart(&block);
52+
53+
// Clone the LB, UB, step defining ops inside the parallel region.
54+
llvm::SmallVector<mlir::Value> lowerBound, upperBound, step;
55+
lowerBound.push_back(
56+
rewriter.clone(*doLoop.getLowerBound().getDefiningOp())->getResult(0));
57+
upperBound.push_back(
58+
rewriter.clone(*doLoop.getUpperBound().getDefiningOp())->getResult(0));
59+
step.push_back(
60+
rewriter.clone(*doLoop.getStep().getDefiningOp())->getResult(0));
61+
62+
auto wsLoopOp = rewriter.create<mlir::omp::WsLoopOp>(
63+
doLoop.getLoc(), lowerBound, upperBound, step);
64+
wsLoopOp.setInclusive(true);
65+
66+
auto outlineableOp =
67+
mlir::dyn_cast<mlir::omp::OutlineableOpenMPOpInterface>(*parallelOp);
68+
assert(outlineableOp);
69+
rewriter.setInsertionPointToStart(outlineableOp.getAllocaBlock());
70+
71+
// For the induction variable, we need to privatize its allocation and
72+
// binding inside the parallel region.
73+
llvm::SmallSetVector<mlir::Operation *, 2> workList;
74+
// Therefore, we first discover the induction variable by discovering
75+
// `fir.store`s where the source is the loop's block argument.
76+
workList.insert(doLoop.getInductionVar().getUsers().begin(),
77+
doLoop.getInductionVar().getUsers().end());
78+
llvm::SmallSetVector<fir::StoreOp, 2> inductionVarTargetStores;
79+
80+
// Walk the use-chain of the loop's block argument until we hit `fir.store`.
81+
while (!workList.empty()) {
82+
mlir::Operation *item = workList.front();
83+
84+
if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(item)) {
85+
inductionVarTargetStores.insert(storeOp);
86+
} else {
87+
workList.insert(item->getUsers().begin(), item->getUsers().end());
88+
}
89+
90+
workList.remove(item);
91+
}
92+
93+
// For each collected `fir.store`, find the target memref's allocas and
94+
// declare ops.
95+
llvm::SmallSetVector<mlir::Operation *, 4> declareAndAllocasToClone;
96+
for (auto storeOp : inductionVarTargetStores) {
97+
mlir::Operation *storeTarget = storeOp.getMemref().getDefiningOp();
98+
99+
for (auto operand : storeTarget->getOperands()) {
100+
declareAndAllocasToClone.insert(operand.getDefiningOp());
101+
}
102+
declareAndAllocasToClone.insert(storeTarget);
103+
}
104+
105+
mlir::IRMapping mapper;
106+
107+
// Clone the memref defining ops into the parallel region.
108+
for (mlir::Operation *opToClone : declareAndAllocasToClone) {
109+
rewriter.clone(*opToClone, mapper);
110+
}
111+
112+
// Clone the loop's body inside the worksharing construct using the mapped
113+
// memref values.
114+
rewriter.cloneRegionBefore(doLoop.getRegion(), wsLoopOp.getRegion(),
115+
wsLoopOp.getRegion().begin(), mapper);
116+
117+
mlir::Operation *terminator = wsLoopOp.getRegion().back().getTerminator();
118+
rewriter.setInsertionPointToEnd(&wsLoopOp.getRegion().back());
119+
rewriter.create<mlir::omp::YieldOp>(terminator->getLoc());
120+
rewriter.eraseOp(terminator);
121+
122+
rewriter.eraseOp(doLoop);
123+
124+
return mlir::success();
125+
}
126+
};
127+
128+
class DoConcurrentConversionPass
129+
: public fir::impl::DoConcurrentConversionPassBase<
130+
DoConcurrentConversionPass> {
131+
public:
132+
void runOnOperation() override {
133+
mlir::func::FuncOp func = getOperation();
134+
135+
if (func.isDeclaration()) {
136+
return;
137+
}
138+
139+
auto *context = &getContext();
140+
mlir::RewritePatternSet patterns(context);
141+
patterns.insert<DoConcurrentConversion>(context);
142+
mlir::ConversionTarget target(*context);
143+
target.addLegalDialect<fir::FIROpsDialect, hlfir::hlfirDialect,
144+
mlir::arith::ArithDialect, mlir::func::FuncDialect,
145+
mlir::omp::OpenMPDialect>();
146+
147+
target.addDynamicallyLegalOp<fir::DoLoopOp>(
148+
[](fir::DoLoopOp op) { return !op.getUnordered(); });
149+
150+
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
151+
std::move(patterns)))) {
152+
mlir::emitError(mlir::UnknownLoc::get(context),
153+
"error in converting do-concurrent op");
154+
signalPassFailure();
155+
}
156+
}
157+
};
158+
} // namespace
159+
160+
std::unique_ptr<mlir::Pass> fir::createDoConcurrentConversionPass() {
161+
return std::make_unique<DoConcurrentConversionPass>();
162+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Tests mapping of a basic `do concurrent` loop to `!$omp parallel do`.
2+
3+
// RUN: fir-opt --fopenmp-do-concurrent-conversion %s | FileCheck %s
4+
5+
// CHECK-LABEL: func.func @do_concurrent_basic
6+
func.func @do_concurrent_basic() attributes {fir.bindc_name = "do_concurrent_basic"} {
7+
// CHECK: %[[ARR:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
8+
// CHECK: %[[C1:.*]] = arith.constant 1 : i32
9+
// CHECK: %[[C10:.*]] = arith.constant 10 : i32
10+
11+
%0 = fir.alloca i32 {bindc_name = "i"}
12+
%1:2 = hlfir.declare %0 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
13+
%2 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xi32>>
14+
%c10 = arith.constant 10 : index
15+
%3 = fir.shape %c10 : (index) -> !fir.shape<1>
16+
%4:2 = hlfir.declare %2(%3) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
17+
%c1_i32 = arith.constant 1 : i32
18+
%7 = fir.convert %c1_i32 : (i32) -> index
19+
%c10_i32 = arith.constant 10 : i32
20+
%8 = fir.convert %c10_i32 : (i32) -> index
21+
%c1 = arith.constant 1 : index
22+
23+
// CHECK-NOT: fir.do_loop
24+
25+
// CHECK: omp.parallel {
26+
27+
// CHECK-NEXT: %[[ITER_VAR:.*]] = fir.alloca i32 {bindc_name = "i"}
28+
// CHECK-NEXT: %[[BINDING:.*]]:2 = hlfir.declare %[[ITER_VAR]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
29+
30+
// CHECK: %[[LB:.*]] = fir.convert %[[C1]] : (i32) -> index
31+
// CHECK: %[[UB:.*]] = fir.convert %[[C10]] : (i32) -> index
32+
// CHECK: %[[STEP:.*]] = arith.constant 1 : index
33+
34+
// CHECK: omp.wsloop for (%[[ARG0:.*]]) : index = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) {
35+
// CHECK-NEXT: %[[IV_IDX:.*]] = fir.convert %[[ARG0]] : (index) -> i32
36+
// CHECK-NEXT: fir.store %[[IV_IDX]] to %[[BINDING]]#1 : !fir.ref<i32>
37+
// CHECK-NEXT: %[[IV_VAL1:.*]] = fir.load %[[BINDING]]#0 : !fir.ref<i32>
38+
// CHECK-NEXT: %[[IV_VAL2:.*]] = fir.load %[[BINDING]]#0 : !fir.ref<i32>
39+
// CHECK-NEXT: %[[IV_VAL_I64:.*]] = fir.convert %[[IV_VAL2]] : (i32) -> i64
40+
// CHECK-NEXT: %[[ARR_ACCESS:.*]] = hlfir.designate %[[ARR]]#0 (%[[IV_VAL_I64]]) : (!fir.ref<!fir.array<10xi32>>, i64) -> !fir.ref<i32>
41+
// CHECK-NEXT: hlfir.assign %[[IV_VAL1]] to %[[ARR_ACCESS]] : i32, !fir.ref<i32>
42+
// CHECK-NEXT: omp.yield
43+
// CHECK-NEXT: }
44+
45+
// CHECK-NEXT: omp.terminator
46+
// CHECK-NEXT: }
47+
fir.do_loop %arg0 = %7 to %8 step %c1 unordered {
48+
%13 = fir.convert %arg0 : (index) -> i32
49+
fir.store %13 to %1#1 : !fir.ref<i32>
50+
%14 = fir.load %1#0 : !fir.ref<i32>
51+
%15 = fir.load %1#0 : !fir.ref<i32>
52+
%16 = fir.convert %15 : (i32) -> i64
53+
%17 = hlfir.designate %4#0 (%16) : (!fir.ref<!fir.array<10xi32>>, i64) -> !fir.ref<i32>
54+
hlfir.assign %14 to %17 : i32, !fir.ref<i32>
55+
}
56+
57+
// CHECK-NOT: fir.do_loop
58+
59+
return
60+
}

0 commit comments

Comments
 (0)