Skip to content

[acc] Improve LegalizeDataValues pass to handle data constructs #112990

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion flang/test/Fir/OpenACC/legalize-data.fir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s
// RUN: fir-opt -split-input-file --openacc-legalize-data-values %s | FileCheck %s

func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
Expand All @@ -22,3 +22,36 @@ func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
// CHECK: acc.yield
// CHECK: }
// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}

// -----

func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%1 = acc.copyin varPtr(%0#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
acc.data dataOperands(%1 : !fir.ref<i32>) {
%c0_i32 = arith.constant 0 : i32
hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
acc.serial {
hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
acc.yield
}
acc.terminator
}
acc.copyout accPtr(%1 : !fir.ref<i32>) to varPtr(%0#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
return
}

// CHECK-LABEL: func.func @_QPsub1
// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "i"})
// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
// CHECK: acc.data dataOperands(%[[COPYIN]] : !fir.ref<i32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
// CHECK: hlfir.assign %[[C0]] to %0#0 : i32, !fir.ref<i32>
// CHECK: acc.serial {
// CHECK: hlfir.assign %[[C0]] to %[[COPYIN]] : i32, !fir.ref<i32>
// CHECK: acc.yield
// CHECK: }
// CHECK: acc.terminator
// CHECK: }
// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/OpenACC/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@
mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
#define OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS \
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
mlir::acc::DataOp, mlir::acc::DeclareOp
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp, \
mlir::acc::DeclareExitOp
#define ACC_DATA_CONSTRUCT_OPS \
OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, ACC_DATA_CONSTRUCT_OPS
#define ACC_COMPUTE_LOOP_AND_DATA_CONSTRUCT_OPS \
Expand Down
7 changes: 2 additions & 5 deletions mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@

#include "mlir/Pass/Pass.h"

#define GEN_PASS_DECL
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"

namespace mlir {

namespace func {
Expand All @@ -22,8 +19,8 @@ class FuncOp;

namespace acc {

/// Create a pass to replace ssa values in region with device/host values.
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
#define GEN_PASS_DECL
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"

/// Generate the code for registering conversion passes.
#define GEN_PASS_REGISTRATION
Expand Down
14 changes: 8 additions & 6 deletions mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@

include "mlir/Pass/PassBase.td"

def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
let summary = "Legalize the data in the compute region";
def LegalizeDataValuesInRegion : Pass<"openacc-legalize-data-values", "mlir::func::FuncOp"> {
let summary = "Legalizes SSA values in compute regions with results from data clause operations";
let description = [{
This pass replace uses of varPtr in the compute region with their accPtr
gathered from the data clause operands.
This pass replace uses of the `varPtr` in compute regions (kernels,
parallel, serial) with the result of data clause operations (`accPtr`).
}];
let options = [
Option<"hostToDevice", "host-to-device", "bool", "true",
"Replace varPtr uses with accPtr if true. Replace accPtr uses with "
"varPtr if false">
"varPtr if false">,
Option<"applyToAccDataConstruct", "apply-to-acc-data-construct", "bool", "true",
"Replaces varPtr uses with accPtr for acc compute regions contained "
"within acc.data or acc.declare region.">
];
let constructor = "::mlir::acc::createLegalizeDataInRegion()";
}

#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIROpenACCTransforms
LegalizeData.cpp
LegalizeDataValues.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- LegalizeData.cpp - -------------------------------------------------===//
//===- LegalizeDataValues.cpp - -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -12,10 +12,11 @@
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir {
namespace acc {
#define GEN_PASS_DEF_LEGALIZEDATAINREGION
#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
} // namespace acc
} // namespace mlir
Expand All @@ -24,6 +25,17 @@ using namespace mlir;

namespace {

static bool insideAccComputeRegion(mlir::Operation *op) {
mlir::Operation *parent{op->getParentOp()};
while (parent) {
if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
return true;
}
parent = parent->getParentOp();
}
return false;
}

static void collectPtrs(mlir::ValueRange operands,
llvm::SmallVector<std::pair<Value, Value>> &values,
bool hostToDevice) {
Expand All @@ -39,6 +51,25 @@ static void collectPtrs(mlir::ValueRange operands,
}
}

template <typename Op>
static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
Region &outerRegion) {
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
if constexpr (std::is_same_v<Op, acc::DataOp> ||
std::is_same_v<Op, acc::DeclareOp>) {
// For data construct regions, only replace uses in contained compute
// regions.
if (insideAccComputeRegion(use.getOwner())) {
use.set(replacement);
}
} else {
use.set(replacement);
}
}
}
}

template <typename Op>
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
llvm::SmallVector<std::pair<Value, Value>> values;
Expand All @@ -48,26 +79,35 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
collectPtrs(op.getPrivateOperands(), values, hostToDevice);
} else {
collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
!std::is_same_v<Op, acc::DataOp> &&
!std::is_same_v<Op, acc::DeclareOp>) {
collectPtrs(op.getReductionOperands(), values, hostToDevice);
collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
}
}

for (auto p : values)
replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
op.getRegion());
}

struct LegalizeDataInRegion
: public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
class LegalizeDataValuesInRegion
: public acc::impl::LegalizeDataValuesInRegionBase<
LegalizeDataValuesInRegion> {
public:
using LegalizeDataValuesInRegionBase<
LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;

void runOnOperation() override {
func::FuncOp funcOp = getOperation();
bool replaceHostVsDevice = this->hostToDevice.getValue();

funcOp.walk([&](Operation *op) {
if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
applyToAccDataConstruct))
return;

if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
Expand All @@ -78,14 +118,15 @@ struct LegalizeDataInRegion
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
} else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
} else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
} else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
} else {
llvm_unreachable("unsupported acc region op");
}
});
}
};

} // end anonymous namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::acc::createLegalizeDataInRegion() {
return std::make_unique<LegalizeDataInRegion>();
}
30 changes: 28 additions & 2 deletions mlir/test/Dialect/OpenACC/legalize-data.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE
// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
// RUN: mlir-opt -split-input-file --openacc-legalize-data-values %s | FileCheck %s --check-prefixes=CHECK,DEVICE
// RUN: mlir-opt -split-input-file --openacc-legalize-data-values=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST

func.func @test(%a: memref<10xf32>, %i : index) {
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
Expand Down Expand Up @@ -61,6 +61,32 @@ func.func @test(%a: memref<10xf32>, %i : index) {

// -----

func.func @test(%a: memref<10xf32>, %i : index) {
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
acc.data dataOperands(%create : memref<10xf32>) {
%c0 = arith.constant 0.000000e+00 : f32
memref.store %c0, %a[%i] : memref<10xf32>
acc.serial {
%cs = memref.load %a[%i] : memref<10xf32>
acc.yield
}
acc.terminator
}
return
}

// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
// CHECK: acc.data dataOperands(%[[CREATE]] : memref<10xf32>) {
// CHECK: memref.store %{{.*}}, %[[A]][%[[I]]] : memref<10xf32>
// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
// CHECK: acc.terminator
// CHECK: }

// -----

func.func @test(%a: memref<10xf32>) {
%lb = arith.constant 0 : index
%st = arith.constant 1 : index
Expand Down
Loading