Skip to content

Commit 209c5db

Browse files
author
Razvan Lupusoru
committed
[acc] Improve LegalizeDataValues pass to handle data constructs
1 parent 03dcd88 commit 209c5db

File tree

5 files changed

+66
-26
lines changed

5 files changed

+66
-26
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACC.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@
5656
mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp
5757
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
5858
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
59-
#define OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS \
59+
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
6060
mlir::acc::DataOp, mlir::acc::DeclareOp
6161
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
6262
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
6363
mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp, \
6464
mlir::acc::DeclareExitOp
6565
#define ACC_DATA_CONSTRUCT_OPS \
66-
OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
66+
ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
6767
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \
6868
ACC_COMPUTE_CONSTRUCT_OPS, ACC_DATA_CONSTRUCT_OPS
6969
#define ACC_COMPUTE_LOOP_AND_DATA_CONSTRUCT_OPS \

mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111

1212
#include "mlir/Pass/Pass.h"
1313

14-
#define GEN_PASS_DECL
15-
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
16-
1714
namespace mlir {
1815

1916
namespace func {
@@ -22,8 +19,8 @@ class FuncOp;
2219

2320
namespace acc {
2421

25-
/// Create a pass to replace ssa values in region with device/host values.
26-
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
22+
#define GEN_PASS_DECL
23+
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
2724

2825
/// Generate the code for registering conversion passes.
2926
#define GEN_PASS_REGISTRATION

mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14-
def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
15-
let summary = "Legalize the data in the compute region";
14+
def LegalizeDataValuesInRegion : Pass<"openacc-legalize-data-values", "mlir::func::FuncOp"> {
15+
let summary = "Legalizes SSA values in compute regions with results from data clause operations";
1616
let description = [{
17-
This pass replace uses of varPtr in the compute region with their accPtr
18-
gathered from the data clause operands.
17+
This pass replace uses of the `varPtr` in compute regions (kernels,
18+
parallel, serial) with the result of data clause operations (`accPtr`).
1919
}];
2020
let options = [
2121
Option<"hostToDevice", "host-to-device", "bool", "true",
2222
"Replace varPtr uses with accPtr if true. Replace accPtr uses with "
23-
"varPtr if false">
23+
"varPtr if false">,
24+
Option<"applyToAccDataConstruct", "apply-to-acc-data-construct", "bool", "true",
25+
"Replaces varPtr uses with accPtr for acc compute regions contained "
26+
"within acc.data or acc.declare region.">
2427
];
25-
let constructor = "::mlir::acc::createLegalizeDataInRegion()";
2628
}
2729

2830
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES

mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
add_mlir_dialect_library(MLIROpenACCTransforms
2-
LegalizeData.cpp
2+
LegalizeDataValues.cpp
33

44
ADDITIONAL_HEADER_DIRS
55
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC

mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp renamed to mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- LegalizeData.cpp - -------------------------------------------------===//
1+
//===- LegalizeDataValues.cpp - -------------------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -12,10 +12,11 @@
1212
#include "mlir/Dialect/OpenACC/OpenACC.h"
1313
#include "mlir/Pass/Pass.h"
1414
#include "mlir/Transforms/RegionUtils.h"
15+
#include "llvm/Support/ErrorHandling.h"
1516

1617
namespace mlir {
1718
namespace acc {
18-
#define GEN_PASS_DEF_LEGALIZEDATAINREGION
19+
#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
1920
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
2021
} // namespace acc
2122
} // namespace mlir
@@ -24,6 +25,17 @@ using namespace mlir;
2425

2526
namespace {
2627

28+
static bool insideAccComputeRegion(mlir::Operation *op) {
29+
mlir::Operation *parent{op->getParentOp()};
30+
while (parent) {
31+
if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
32+
return true;
33+
}
34+
parent = parent->getParentOp();
35+
}
36+
return false;
37+
}
38+
2739
static void collectPtrs(mlir::ValueRange operands,
2840
llvm::SmallVector<std::pair<Value, Value>> &values,
2941
bool hostToDevice) {
@@ -39,6 +51,25 @@ static void collectPtrs(mlir::ValueRange operands,
3951
}
4052
}
4153

54+
template <typename Op>
55+
static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
56+
Region &outerRegion) {
57+
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
58+
if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
59+
if constexpr (std::is_same_v<Op, acc::DataOp> ||
60+
std::is_same_v<Op, acc::DeclareOp>) {
61+
// For data construct regions, only replace uses in contained compute
62+
// regions.
63+
if (insideAccComputeRegion(use.getOwner())) {
64+
use.set(replacement);
65+
}
66+
} else {
67+
use.set(replacement);
68+
}
69+
}
70+
}
71+
}
72+
4273
template <typename Op>
4374
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
4475
llvm::SmallVector<std::pair<Value, Value>> values;
@@ -48,26 +79,35 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
4879
collectPtrs(op.getPrivateOperands(), values, hostToDevice);
4980
} else {
5081
collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
51-
if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
82+
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83+
!std::is_same_v<Op, acc::DataOp> &&
84+
!std::is_same_v<Op, acc::DeclareOp>) {
5285
collectPtrs(op.getReductionOperands(), values, hostToDevice);
5386
collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
5487
collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
5588
}
5689
}
5790

5891
for (auto p : values)
59-
replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
92+
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
93+
op.getRegion());
6094
}
6195

62-
struct LegalizeDataInRegion
63-
: public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
96+
class LegalizeDataValuesInRegion
97+
: public acc::impl::LegalizeDataValuesInRegionBase<
98+
LegalizeDataValuesInRegion> {
99+
public:
100+
using LegalizeDataValuesInRegionBase<
101+
LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
64102

65103
void runOnOperation() override {
66104
func::FuncOp funcOp = getOperation();
67105
bool replaceHostVsDevice = this->hostToDevice.getValue();
68106

69107
funcOp.walk([&](Operation *op) {
70-
if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
108+
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
109+
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
110+
applyToAccDataConstruct))
71111
return;
72112

73113
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -78,14 +118,15 @@ struct LegalizeDataInRegion
78118
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
79119
} else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
80120
collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
121+
} else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
122+
collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
123+
} else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
124+
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
125+
} else {
126+
llvm_unreachable("unsupported acc region op");
81127
}
82128
});
83129
}
84130
};
85131

86132
} // end anonymous namespace
87-
88-
std::unique_ptr<OperationPass<func::FuncOp>>
89-
mlir::acc::createLegalizeDataInRegion() {
90-
return std::make_unique<LegalizeDataInRegion>();
91-
}

0 commit comments

Comments
 (0)