1
- // ===- LegalizeData .cpp - ------ -------------------------------------------===//
1
+ // ===- LegalizeDataValues .cpp - -------------------------------------------===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
12
12
#include " mlir/Dialect/OpenACC/OpenACC.h"
13
13
#include " mlir/Pass/Pass.h"
14
14
#include " mlir/Transforms/RegionUtils.h"
15
+ #include " llvm/Support/ErrorHandling.h"
15
16
16
17
namespace mlir {
17
18
namespace acc {
18
- #define GEN_PASS_DEF_LEGALIZEDATAINREGION
19
+ #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
19
20
#include " mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
20
21
} // namespace acc
21
22
} // namespace mlir
@@ -24,6 +25,17 @@ using namespace mlir;
24
25
25
26
namespace {
26
27
28
+ static bool insideAccComputeRegion (mlir::Operation *op) {
29
+ mlir::Operation *parent{op->getParentOp ()};
30
+ while (parent) {
31
+ if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
32
+ return true ;
33
+ }
34
+ parent = parent->getParentOp ();
35
+ }
36
+ return false ;
37
+ }
38
+
27
39
static void collectPtrs (mlir::ValueRange operands,
28
40
llvm::SmallVector<std::pair<Value, Value>> &values,
29
41
bool hostToDevice) {
@@ -39,6 +51,25 @@ static void collectPtrs(mlir::ValueRange operands,
39
51
}
40
52
}
41
53
54
+ template <typename Op>
55
+ static void replaceAllUsesInAccComputeRegionsWith (Value orig, Value replacement,
56
+ Region &outerRegion) {
57
+ for (auto &use : llvm::make_early_inc_range (orig.getUses ())) {
58
+ if (outerRegion.isAncestor (use.getOwner ()->getParentRegion ())) {
59
+ if constexpr (std::is_same_v<Op, acc::DataOp> ||
60
+ std::is_same_v<Op, acc::DeclareOp>) {
61
+ // For data construct regions, only replace uses in contained compute
62
+ // regions.
63
+ if (insideAccComputeRegion (use.getOwner ())) {
64
+ use.set (replacement);
65
+ }
66
+ } else {
67
+ use.set (replacement);
68
+ }
69
+ }
70
+ }
71
+ }
72
+
42
73
template <typename Op>
43
74
static void collectAndReplaceInRegion (Op &op, bool hostToDevice) {
44
75
llvm::SmallVector<std::pair<Value, Value>> values;
@@ -48,26 +79,35 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
48
79
collectPtrs (op.getPrivateOperands (), values, hostToDevice);
49
80
} else {
50
81
collectPtrs (op.getDataClauseOperands (), values, hostToDevice);
51
- if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
82
+ if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83
+ !std::is_same_v<Op, acc::DataOp> &&
84
+ !std::is_same_v<Op, acc::DeclareOp>) {
52
85
collectPtrs (op.getReductionOperands (), values, hostToDevice);
53
86
collectPtrs (op.getGangPrivateOperands (), values, hostToDevice);
54
87
collectPtrs (op.getGangFirstPrivateOperands (), values, hostToDevice);
55
88
}
56
89
}
57
90
58
91
for (auto p : values)
59
- replaceAllUsesInRegionWith (std::get<0 >(p), std::get<1 >(p), op.getRegion ());
92
+ replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0 >(p), std::get<1 >(p),
93
+ op.getRegion ());
60
94
}
61
95
62
- struct LegalizeDataInRegion
63
- : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
96
+ class LegalizeDataValuesInRegion
97
+ : public acc::impl::LegalizeDataValuesInRegionBase<
98
+ LegalizeDataValuesInRegion> {
99
+ public:
100
+ using LegalizeDataValuesInRegionBase<
101
+ LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
64
102
65
103
void runOnOperation () override {
66
104
func::FuncOp funcOp = getOperation ();
67
105
bool replaceHostVsDevice = this ->hostToDevice .getValue ();
68
106
69
107
funcOp.walk ([&](Operation *op) {
70
- if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
108
+ if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
109
+ !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
110
+ applyToAccDataConstruct))
71
111
return ;
72
112
73
113
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -78,14 +118,15 @@ struct LegalizeDataInRegion
78
118
collectAndReplaceInRegion (kernelsOp, replaceHostVsDevice);
79
119
} else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
80
120
collectAndReplaceInRegion (loopOp, replaceHostVsDevice);
121
+ } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
122
+ collectAndReplaceInRegion (dataOp, replaceHostVsDevice);
123
+ } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
124
+ collectAndReplaceInRegion (declareOp, replaceHostVsDevice);
125
+ } else {
126
+ llvm_unreachable (" unsupported acc region op" );
81
127
}
82
128
});
83
129
}
84
130
};
85
131
86
132
} // end anonymous namespace
87
-
88
- std::unique_ptr<OperationPass<func::FuncOp>>
89
- mlir::acc::createLegalizeDataInRegion () {
90
- return std::make_unique<LegalizeDataInRegion>();
91
- }
0 commit comments