Skip to content

Commit ab4a4d6

Browse files
committed
Update to address recent reviewer comments
- Added missing headers - Modified pass to work on FuncOps - Extended OMPDescriptorMapInfoGen pass test to cover more of the passes behaviour - Removed some usage of the Mutable interface inside of the pass to minimize it's usage.
1 parent f4d56d6 commit ab4a4d6

File tree

4 files changed

+80
-28
lines changed

4 files changed

+80
-28
lines changed

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> {
319319
}
320320

321321
def OMPDescriptorMapInfoGenPass
322-
: Pass<"omp-descriptor-map-info-gen", "mlir::ModuleOp"> {
322+
: Pass<"omp-descriptor-map-info-gen", "mlir::func::FuncOp"> {
323323
let summary = "expands OpenMP MapInfo operations containing descriptors";
324324
let description = [{
325325
Expands MapInfo operations containing descriptor types into multiple

flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
//===-- CodeGenOpenMP.cpp -------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10+
//
11+
//===----------------------------------------------------------------------===//
12+
113
#include "flang/Optimizer/CodeGen/CodeGenOpenMP.h"
214

315
#include "flang/Optimizer/Builder/FIRBuilder.h"

flang/lib/Optimizer/Transforms/OMPDescriptorMapInfoGen.cpp

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
1+
//===- OMPDescriptorMapInfoGen.cpp
2+
//---------------------------------------------------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
//===----------------------------------------------------------------------===//
11+
/// \file
12+
/// An OpenMP dialect related pass for FIR/HLFIR which expands MapInfoOp's
13+
/// containing descriptor related types (fir::BoxType's) into multiple
14+
/// MapInfoOp's containing the parent descriptor and pointer member components
15+
/// for individual mapping, treating the descriptor type as a record type for
16+
/// later lowering in the OpenMP dialect.
17+
//===----------------------------------------------------------------------===//
18+
119
#include "flang/Optimizer/Builder/FIRBuilder.h"
220
#include "flang/Optimizer/Dialect/FIRType.h"
321
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
422
#include "flang/Optimizer/Transforms/Passes.h"
523
#include "mlir/Dialect/Func/IR/FuncOps.h"
6-
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
724
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
825
#include "mlir/IR/BuiltinDialect.h"
926
#include "mlir/IR/BuiltinOps.h"
@@ -75,11 +92,6 @@ class OMPDescriptorMapInfoGenPass
7592
// TODO: map the addendum segment of the descriptor, similarly to the
7693
// above base address/data pointer member.
7794

78-
op.getVarPtrMutable().assign(descriptor);
79-
op.setVarType(fir::unwrapRefType(descriptor.getType()));
80-
op.getMembersMutable().append(baseAddr);
81-
op.getBoundsMutable().assign(llvm::SmallVector<mlir::Value>{});
82-
8395
if (auto mapClauseOwner =
8496
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
8597
llvm::SmallVector<mlir::Value> newMapOps;
@@ -98,23 +110,34 @@ class OMPDescriptorMapInfoGenPass
98110
if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(target))
99111
targetOp.getRegion().insertArgument(i, baseAddr.getType(), loc);
100112
}
101-
102113
newMapOps.push_back(mapOperandsArr[i]);
103114
}
104-
105115
mapClauseOwner.getMapOperandsMutable().assign(newMapOps);
106116
}
117+
118+
mlir::Value newDescParentMapOp = builder.create<mlir::omp::MapInfoOp>(
119+
op->getLoc(), op.getResult().getType(), descriptor,
120+
fir::unwrapRefType(descriptor.getType()), mlir::Value{},
121+
mlir::SmallVector<mlir::Value>{baseAddr},
122+
mlir::SmallVector<mlir::Value>{},
123+
builder.getIntegerAttr(builder.getIntegerType(64, false),
124+
op.getMapType().value()),
125+
op.getMapCaptureTypeAttr(), op.getNameAttr());
126+
op.replaceAllUsesWith(newDescParentMapOp);
127+
op->erase();
107128
}
108129

109130
// This pass executes on mlir::ModuleOp's finding omp::MapInfoOp's containing
110131
// descriptor based types (allocatables, pointers, assumed shape etc.) and
111132
// expanding them into multiple omp::MapInfoOp's for each pointer member
112133
// contained within the descriptor.
113134
void runOnOperation() override {
114-
fir::KindMapping kindMap = fir::getKindMapping(getOperation());
115-
fir::FirOpBuilder builder{getOperation(), std::move(kindMap)};
135+
mlir::func::FuncOp func = getOperation();
136+
mlir::ModuleOp module = func->getParentOfType<mlir::ModuleOp>();
137+
fir::KindMapping kindMap = fir::getKindMapping(module);
138+
fir::FirOpBuilder builder{module, std::move(kindMap)};
116139

117-
getOperation()->walk([&](mlir::omp::MapInfoOp op) {
140+
func->walk([&](mlir::omp::MapInfoOp op) {
118141
if (fir::isTypeWithDescriptor(op.getVarType()) ||
119142
mlir::isa_and_present<fir::BoxAddrOp>(
120143
op.getVarPtr().getDefiningOp())) {
Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,44 @@
11
// RUN: fir-opt --omp-descriptor-map-info-gen %s | FileCheck %s
22

33
module attributes {omp.is_target_device = false} {
4-
func.func @test_descriptor_expansion_pass() {
5-
%0 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "test", uniq_name = "_QFEtest"}
4+
func.func @test_descriptor_expansion_pass(%arg0: !fir.box<!fir.array<?xi32>>) {
5+
%0 = fir.alloca !fir.box<!fir.heap<i32>>
66
%1 = fir.zero_bits !fir.heap<i32>
7-
%2 = fir.embox %1 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
8-
fir.store %2 to %0 : !fir.ref<!fir.box<!fir.heap<i32>>>
9-
%3:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEtest"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
10-
%4 = fir.allocmem i32 {fir.must_be_heap = true, uniq_name = "_QFEtest.alloc"}
11-
%5 = fir.embox %4 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
12-
fir.store %5 to %3#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
13-
%6 = omp.map_info var_ptr(%3#1 : !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.box<!fir.heap<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.box<!fir.heap<i32>>> {name = "test"}
14-
omp.target map_entries(%6 -> %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>) {
15-
^bb0(%arg0: !fir.ref<!fir.box<!fir.heap<i32>>>):
7+
%2:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<intent_out>, uniq_name = "test"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
8+
%3 = fir.embox %1 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
9+
fir.store %3 to %0 : !fir.ref<!fir.box<!fir.heap<i32>>>
10+
%4:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "test2"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
11+
%5 = fir.allocmem i32 {fir.must_be_heap = true}
12+
%6 = fir.embox %5 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
13+
fir.store %6 to %4#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
14+
%c0 = arith.constant 1 : index
15+
%c1 = arith.constant 0 : index
16+
%c2 = arith.constant 10 : index
17+
%dims:3 = fir.box_dims %2#1, %c1 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
18+
%bounds = omp.bounds lower_bound(%c1 : index) upper_bound(%c2 : index) extent(%dims#1 : index) stride(%dims#2 : index) start_idx(%c0 : index) {stride_in_bytes = true}
19+
%7 = fir.box_addr %2#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
20+
%8 = omp.map_info var_ptr(%4#1 : !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.box<!fir.heap<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.box<!fir.heap<i32>>>
21+
%9 = omp.map_info var_ptr(%7 : !fir.ref<!fir.array<?xi32>>, !fir.array<?xi32>) map_clauses(from) capture(ByRef) bounds(%bounds) -> !fir.ref<!fir.array<?xi32>>
22+
omp.target map_entries(%8 -> %arg1, %9 -> %arg2 : !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.array<?xi32>>) {
23+
^bb0(%arg1: !fir.ref<!fir.box<!fir.heap<i32>>>, %arg2: !fir.ref<!fir.array<?xi32>>):
1624
omp.terminator
1725
}
1826
return
1927
}
2028
}
2129

22-
// CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEtest"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
23-
// CHECK: %[[BASE_ADDR_OFF:.*]] = fir.box_offset %[[DECLARE]]#1 base_addr : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> !fir.llvm_ptr<!fir.ref<i32>>
30+
// CHECK: func.func @test_descriptor_expansion_pass(%[[ARG0:.*]]: !fir.box<!fir.array<?xi32>>) {
31+
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.box<!fir.array<?xi32>>
32+
// CHECK: %[[ALLOCA2:.*]] = fir.alloca !fir.box<!fir.heap<i32>>
33+
// CHECK: %[[DECLARE1:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs<intent_out>, uniq_name = "test"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
34+
// CHECK: %[[DECLARE2:.*]]:2 = hlfir.declare %[[ALLOCA2]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "test2"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
35+
// CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound(%{{.*}} : index) upper_bound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) start_idx(%{{.*}} : index) {stride_in_bytes = true}
36+
// CHECK: %[[BASE_ADDR_OFF:.*]] = fir.box_offset %[[DECLARE2]]#1 base_addr : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> !fir.llvm_ptr<!fir.ref<i32>>
2437
// CHECK: %[[DESC_MEMBER_MAP:.*]] = omp.map_info var_ptr(%[[BASE_ADDR_OFF]] : !fir.llvm_ptr<!fir.ref<i32>>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
25-
// CHECK: %[[DESC_PARENT_MAP:.*]] = omp.map_info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.box<!fir.heap<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[DESC_MEMBER_MAP]] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.heap<i32>>> {name = "test"}
26-
// CHECK: omp.target map_entries(%[[DESC_MEMBER_MAP]] -> %{{.*}}, %[[DESC_PARENT_MAP]] -> %{{.*}} : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.heap<i32>>>) {
27-
// CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<i32>>, %{{.*}}: !fir.ref<!fir.box<!fir.heap<i32>>>):
38+
// CHECK: %[[DESC_PARENT_MAP:.*]] = omp.map_info var_ptr(%[[DECLARE2]]#1 : !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.box<!fir.heap<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[DESC_MEMBER_MAP]] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.heap<i32>>>
39+
// CHECK: fir.store %[[DECLARE1]]#1 to %[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
40+
// CHECK: %[[BASE_ADDR_OFF_2:.*]] = fir.box_offset %[[ALLOCA]] base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
41+
// CHECK: %[[DESC_MEMBER_MAP_2:.*]] = omp.map_info var_ptr(%[[BASE_ADDR_OFF_2]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.array<?xi32>) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
42+
// CHECK: %[[DESC_PARENT_MAP_2:.*]] = omp.map_info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(from) capture(ByRef) members(%15 : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>>
43+
// CHECK: omp.target map_entries(%[[DESC_MEMBER_MAP]] -> %[[ARG1:.*]], %[[DESC_PARENT_MAP]] -> %[[ARG2:.*]], %[[DESC_MEMBER_MAP_2]] -> %[[ARG3:.*]], %[[DESC_PARENT_MAP_2]] -> %[[ARG4:.*]] : {{.*}}) {
44+
// CHECK: ^bb0(%[[ARG1]]: !fir.llvm_ptr<!fir.ref<i32>>, %[[ARG2]]: !fir.ref<!fir.box<!fir.heap<i32>>>, %[[ARG3]]: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %[[ARG4]]: !fir.ref<!fir.array<?xi32>>):

0 commit comments

Comments
 (0)