Skip to content

Commit ac9ee61

Browse files
[acc] Improve LegalizeDataValues pass to handle data constructs (#112990)
Renames LegalizeData to LegalizeDataValues, since this pass fixes up SSA values; the name LegalizeData suggested that it fixed data mapping. This change also adds support for fixing up SSA values for the data clause operations of data constructs: compute regions nested within a data region now also use the SSA values produced by those data clause operations. Uses that are within a data region but not within a compute region are not updated. This change supports the requirement in the OpenACC spec, which notes that a visible data clause is not just one on the current compute construct but also one on a lexically containing data construct or a visible declare directive.
1 parent 3277c7c commit ac9ee61

File tree

7 files changed

+128
-29
lines changed

7 files changed

+128
-29
lines changed

flang/test/Fir/OpenACC/legalize-data.fir

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s
1+
// RUN: fir-opt -split-input-file --openacc-legalize-data-values %s | FileCheck %s
22

33
func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
44
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -22,3 +22,36 @@ func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
2222
// CHECK: acc.yield
2323
// CHECK: }
2424
// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
25+
26+
// -----
27+
28+
// Data-construct case: %1 (the acc.copyin result) is a data operand of
// acc.data, not of the nested acc.serial. The CHECK lines for this test expect
// only the use inside acc.serial to be rewritten to the copyin result; the
// assign placed directly in the acc.data region keeps the host value.
func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
29+
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
30+
%1 = acc.copyin varPtr(%0#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
31+
acc.data dataOperands(%1 : !fir.ref<i32>) {
32+
%c0_i32 = arith.constant 0 : i32
33+
// Inside the data region but outside any compute region: expected unchanged.
hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
34+
acc.serial {
35+
// Inside a compute region nested in the data region: expected to use %1.
hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
36+
acc.yield
37+
}
38+
acc.terminator
39+
}
40+
acc.copyout accPtr(%1 : !fir.ref<i32>) to varPtr(%0#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
41+
return
42+
}
43+
44+
// CHECK-LABEL: func.func @_QPsub1
45+
// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "i"})
46+
// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
47+
// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
48+
// CHECK: acc.data dataOperands(%[[COPYIN]] : !fir.ref<i32>) {
49+
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
50+
// CHECK: hlfir.assign %[[C0]] to %[[I]]#0 : i32, !fir.ref<i32>
51+
// CHECK: acc.serial {
52+
// CHECK: hlfir.assign %[[C0]] to %[[COPYIN]] : i32, !fir.ref<i32>
53+
// CHECK: acc.yield
54+
// CHECK: }
55+
// CHECK: acc.terminator
56+
// CHECK: }
57+
// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}

mlir/include/mlir/Dialect/OpenACC/OpenACC.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@
5656
mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp
5757
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
5858
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
59-
#define OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS \
59+
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
6060
mlir::acc::DataOp, mlir::acc::DeclareOp
6161
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
6262
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
6363
mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp, \
6464
mlir::acc::DeclareExitOp
6565
#define ACC_DATA_CONSTRUCT_OPS \
66-
OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
66+
ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
6767
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \
6868
ACC_COMPUTE_CONSTRUCT_OPS, ACC_DATA_CONSTRUCT_OPS
6969
#define ACC_COMPUTE_LOOP_AND_DATA_CONSTRUCT_OPS \

mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111

1212
#include "mlir/Pass/Pass.h"
1313

14-
#define GEN_PASS_DECL
15-
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
16-
1714
namespace mlir {
1815

1916
namespace func {
@@ -22,8 +19,8 @@ class FuncOp;
2219

2320
namespace acc {
2421

25-
/// Create a pass to replace ssa values in region with device/host values.
26-
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
22+
#define GEN_PASS_DECL
23+
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
2724

2825
/// Generate the code for registering conversion passes.
2926
#define GEN_PASS_REGISTRATION

mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14-
def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
15-
let summary = "Legalize the data in the compute region";
14+
def LegalizeDataValuesInRegion : Pass<"openacc-legalize-data-values", "mlir::func::FuncOp"> {
15+
let summary = "Legalizes SSA values in compute regions with results from data clause operations";
1616
let description = [{
17-
This pass replace uses of varPtr in the compute region with their accPtr
18-
gathered from the data clause operands.
17+
This pass replaces uses of the `varPtr` in compute regions (kernels,
18+
parallel, serial) with the result of data clause operations (`accPtr`).
1919
}];
2020
let options = [
2121
Option<"hostToDevice", "host-to-device", "bool", "true",
2222
"Replace varPtr uses with accPtr if true. Replace accPtr uses with "
23-
"varPtr if false">
23+
"varPtr if false">,
24+
Option<"applyToAccDataConstruct", "apply-to-acc-data-construct", "bool", "true",
25+
"Replaces varPtr uses with accPtr for acc compute regions contained "
26+
"within acc.data or acc.declare region.">
2427
];
25-
let constructor = "::mlir::acc::createLegalizeDataInRegion()";
2628
}
2729

2830
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES

mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
add_mlir_dialect_library(MLIROpenACCTransforms
2-
LegalizeData.cpp
2+
LegalizeDataValues.cpp
33

44
ADDITIONAL_HEADER_DIRS
55
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC

mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp renamed to mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- LegalizeData.cpp - -------------------------------------------------===//
1+
//===- LegalizeDataValues.cpp - -------------------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -12,10 +12,11 @@
1212
#include "mlir/Dialect/OpenACC/OpenACC.h"
1313
#include "mlir/Pass/Pass.h"
1414
#include "mlir/Transforms/RegionUtils.h"
15+
#include "llvm/Support/ErrorHandling.h"
1516

1617
namespace mlir {
1718
namespace acc {
18-
#define GEN_PASS_DEF_LEGALIZEDATAINREGION
19+
#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
1920
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
2021
} // namespace acc
2122
} // namespace mlir
@@ -24,6 +25,17 @@ using namespace mlir;
2425

2526
namespace {
2627

28+
/// Returns true if `op` is nested, at any depth, inside an operation matched
/// by ACC_COMPUTE_CONSTRUCT_OPS. The search starts at the parent of `op`, so
/// `op` itself being a compute construct does not count. The walk continues
/// up the parent chain until a match is found or there is no parent left.
static bool insideAccComputeRegion(mlir::Operation *op) {
29+
mlir::Operation *parent{op->getParentOp()};
30+
while (parent) {
31+
if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
32+
return true;
33+
}
34+
// Not a compute construct: keep climbing towards the outermost op.
parent = parent->getParentOp();
35+
}
36+
return false;
37+
}
38+
2739
static void collectPtrs(mlir::ValueRange operands,
2840
llvm::SmallVector<std::pair<Value, Value>> &values,
2941
bool hostToDevice) {
@@ -39,6 +51,25 @@ static void collectPtrs(mlir::ValueRange operands,
3951
}
4052
}
4153

54+
/// Replaces uses of `orig` with `replacement` for users located in
/// `outerRegion` or in a region nested within it.
///
/// `Op` is the type of the operation that owns `outerRegion` and selects the
/// policy at compile time:
///  - acc::DataOp / acc::DeclareOp: a use is rewritten only when its owner is
///    also nested inside a compute construct (see insideAccComputeRegion).
///  - any other `Op`: every use under `outerRegion` is rewritten.
template <typename Op>
55+
static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
56+
Region &outerRegion) {
57+
// use.set() rewires the operand while the use list of `orig` is being
// walked, hence the early-increment range.
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
58+
// Skip users that are not under `outerRegion`.
if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
59+
if constexpr (std::is_same_v<Op, acc::DataOp> ||
60+
std::is_same_v<Op, acc::DeclareOp>) {
61+
// For data construct regions, only replace uses in contained compute
62+
// regions.
63+
if (insideAccComputeRegion(use.getOwner())) {
64+
use.set(replacement);
65+
}
66+
} else {
67+
use.set(replacement);
68+
}
69+
}
70+
}
71+
}
72+
4273
template <typename Op>
4374
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
4475
llvm::SmallVector<std::pair<Value, Value>> values;
@@ -48,26 +79,35 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
4879
collectPtrs(op.getPrivateOperands(), values, hostToDevice);
4980
} else {
5081
collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
51-
if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
82+
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83+
!std::is_same_v<Op, acc::DataOp> &&
84+
!std::is_same_v<Op, acc::DeclareOp>) {
5285
collectPtrs(op.getReductionOperands(), values, hostToDevice);
5386
collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
5487
collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
5588
}
5689
}
5790

5891
for (auto p : values)
59-
replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
92+
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
93+
op.getRegion());
6094
}
6195

62-
struct LegalizeDataInRegion
63-
: public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
96+
class LegalizeDataValuesInRegion
97+
: public acc::impl::LegalizeDataValuesInRegionBase<
98+
LegalizeDataValuesInRegion> {
99+
public:
100+
using LegalizeDataValuesInRegionBase<
101+
LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
64102

65103
void runOnOperation() override {
66104
func::FuncOp funcOp = getOperation();
67105
bool replaceHostVsDevice = this->hostToDevice.getValue();
68106

69107
funcOp.walk([&](Operation *op) {
70-
if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
108+
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
109+
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
110+
applyToAccDataConstruct))
71111
return;
72112

73113
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -78,14 +118,15 @@ struct LegalizeDataInRegion
78118
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
79119
} else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
80120
collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
121+
} else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
122+
collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
123+
} else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
124+
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
125+
} else {
126+
llvm_unreachable("unsupported acc region op");
81127
}
82128
});
83129
}
84130
};
85131

86132
} // end anonymous namespace
87-
88-
std::unique_ptr<OperationPass<func::FuncOp>>
89-
mlir::acc::createLegalizeDataInRegion() {
90-
return std::make_unique<LegalizeDataInRegion>();
91-
}

mlir/test/Dialect/OpenACC/legalize-data.mlir

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE
2-
// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
1+
// RUN: mlir-opt -split-input-file --openacc-legalize-data-values %s | FileCheck %s --check-prefixes=CHECK,DEVICE
2+
// RUN: mlir-opt -split-input-file --openacc-legalize-data-values=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
33

44
func.func @test(%a: memref<10xf32>, %i : index) {
55
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
@@ -61,6 +61,32 @@ func.func @test(%a: memref<10xf32>, %i : index) {
6161

6262
// -----
6363

64+
// Data-construct case: %create is a data operand of acc.data only. The CHECK
// lines for this test expect the store placed directly in the acc.data region
// to keep using %a, and the load inside the nested acc.serial to use %create
// under the DEVICE prefix (default options) and %a under the HOST prefix
// (host-to-device=false).
func.func @test(%a: memref<10xf32>, %i : index) {
65+
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
66+
acc.data dataOperands(%create : memref<10xf32>) {
67+
%c0 = arith.constant 0.000000e+00 : f32
68+
memref.store %c0, %a[%i] : memref<10xf32>
69+
acc.serial {
70+
%cs = memref.load %a[%i] : memref<10xf32>
71+
acc.yield
72+
}
73+
acc.terminator
74+
}
75+
return
76+
}
77+
78+
// CHECK-LABEL: func.func @test
79+
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
80+
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
81+
// CHECK: acc.data dataOperands(%[[CREATE]] : memref<10xf32>) {
82+
// CHECK: memref.store %{{.*}}, %[[A]][%[[I]]] : memref<10xf32>
83+
// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
84+
// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
85+
// CHECK: acc.terminator
86+
// CHECK: }
87+
88+
// -----
89+
6490
func.func @test(%a: memref<10xf32>) {
6591
%lb = arith.constant 0 : index
6692
%st = arith.constant 1 : index

0 commit comments

Comments
 (0)