Skip to content

Commit 4c9717c

Browse files
authored
[mlir][openacc] Add private/reduction in legalize data pass (#80882)
This is a follow up to #80351 and adds private and reduction operands from acc.loop, acc.parallel and acc.serial operations.
1 parent 12aad1a commit 4c9717c

File tree

2 files changed

+138
-5
lines changed

2 files changed

+138
-5
lines changed

mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ using namespace mlir;
2424

2525
namespace {
2626

27-
template <typename Op>
28-
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
29-
llvm::SmallVector<std::pair<Value, Value>> values;
30-
for (auto operand : op.getDataClauseOperands()) {
27+
static void collectPtrs(mlir::ValueRange operands,
28+
llvm::SmallVector<std::pair<Value, Value>> &values,
29+
bool hostToDevice) {
30+
for (auto operand : operands) {
3131
Value varPtr = acc::getVarPtr(operand.getDefiningOp());
3232
Value accPtr = acc::getAccPtr(operand.getDefiningOp());
3333
if (varPtr && accPtr) {
@@ -37,6 +37,23 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
3737
values.push_back({accPtr, varPtr});
3838
}
3939
}
40+
}
41+
42+
template <typename Op>
43+
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
44+
llvm::SmallVector<std::pair<Value, Value>> values;
45+
46+
if constexpr (std::is_same_v<Op, acc::LoopOp>) {
47+
collectPtrs(op.getReductionOperands(), values, hostToDevice);
48+
collectPtrs(op.getPrivateOperands(), values, hostToDevice);
49+
} else {
50+
collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
51+
if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
52+
collectPtrs(op.getReductionOperands(), values, hostToDevice);
53+
collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
54+
collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
55+
}
56+
}
4057

4158
for (auto p : values)
4259
replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
@@ -50,7 +67,7 @@ struct LegalizeDataInRegion
5067
bool replaceHostVsDevice = this->hostToDevice.getValue();
5168

5269
funcOp.walk([&](Operation *op) {
53-
if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
70+
if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
5471
return;
5572

5673
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -59,6 +76,8 @@ struct LegalizeDataInRegion
5976
collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
6077
} else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
6178
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
79+
} else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
80+
collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
6281
}
6382
});
6483
}

mlir/test/Dialect/OpenACC/legalize-data.mlir

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,117 @@ func.func @test(%a: memref<10xf32>) {
8686
// CHECK: }
8787
// CHECK: acc.yield
8888
// CHECK: }
89+
90+
// -----
91+
92+
acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
93+
^bb0(%arg0: memref<10xf32>):
94+
%0 = memref.alloc() : memref<10xf32>
95+
acc.yield %0 : memref<10xf32>
96+
} destroy {
97+
^bb0(%arg0: memref<10xf32>):
98+
memref.dealloc %arg0 : memref<10xf32>
99+
acc.terminator
100+
}
101+
102+
func.func @test(%a: memref<10xf32>) {
103+
%lb = arith.constant 0 : index
104+
%st = arith.constant 1 : index
105+
%c10 = arith.constant 10 : index
106+
%p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
107+
acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
108+
acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
109+
%ci = memref.load %a[%i] : memref<10xf32>
110+
acc.yield
111+
}
112+
acc.yield
113+
}
114+
return
115+
}
116+
117+
// CHECK: func.func @test
118+
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
119+
// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
120+
// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
121+
// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
122+
// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
123+
// CHECK: acc.yield
124+
// CHECK: }
125+
// CHECK: acc.yield
126+
// CHECK: }
127+
128+
// -----
129+
130+
acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
131+
^bb0(%arg0: memref<10xf32>):
132+
%0 = memref.alloc() : memref<10xf32>
133+
acc.yield %0 : memref<10xf32>
134+
} destroy {
135+
^bb0(%arg0: memref<10xf32>):
136+
memref.dealloc %arg0 : memref<10xf32>
137+
acc.terminator
138+
}
139+
140+
func.func @test(%a: memref<10xf32>) {
141+
%lb = arith.constant 0 : index
142+
%st = arith.constant 1 : index
143+
%c10 = arith.constant 10 : index
144+
%p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
145+
acc.parallel {
146+
acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
147+
%ci = memref.load %a[%i] : memref<10xf32>
148+
acc.yield
149+
}
150+
acc.yield
151+
}
152+
return
153+
}
154+
155+
// CHECK: func.func @test
156+
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
157+
// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
158+
// CHECK: acc.parallel {
159+
// CHECK: acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
160+
// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
161+
// CHECK: acc.yield
162+
// CHECK: }
163+
// CHECK: acc.yield
164+
// CHECK: }
165+
166+
// -----
167+
168+
acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
169+
^bb0(%arg0: memref<10xf32>):
170+
%0 = memref.alloc() : memref<10xf32>
171+
acc.yield %0 : memref<10xf32>
172+
} destroy {
173+
^bb0(%arg0: memref<10xf32>):
174+
memref.dealloc %arg0 : memref<10xf32>
175+
acc.terminator
176+
}
177+
178+
func.func @test(%a: memref<10xf32>) {
179+
%lb = arith.constant 0 : index
180+
%st = arith.constant 1 : index
181+
%c10 = arith.constant 10 : index
182+
%p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
183+
acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
184+
acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
185+
%ci = memref.load %a[%i] : memref<10xf32>
186+
acc.yield
187+
}
188+
acc.yield
189+
}
190+
return
191+
}
192+
193+
// CHECK: func.func @test
194+
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
195+
// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
196+
// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
197+
// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
198+
// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
199+
// CHECK: acc.yield
200+
// CHECK: }
201+
// CHECK: acc.yield
202+
// CHECK: }

0 commit comments

Comments
 (0)