Skip to content

Commit 8701b17

Browse files
TIFitisskatrak
andauthored
[MLIR][OpenMP] Changes to function-filtering pass (#71850)
Currently, when deleting the device functions in the second stage of filtering during MLIR to LLVM translation we can end up with invalid calls to these functions. This is because of the removal of the EarlyOutliningPass which would have otherwise gotten rid of any such calls. This patch aims to alter the function filtering pass in the following way: - Any host function is completely removed. - Call to the host function are also removed and their uses replaced with Undef values. - Any host function with target region code is marked to be removed during the the second stage. - Calls to such functions are still removed and their uses replaced with Undef values. Co-authored-by: Sergio Afonso <[email protected]>
1 parent eb3c02f commit 8701b17

File tree

6 files changed

+52
-29
lines changed

6 files changed

+52
-29
lines changed

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,8 @@ def OMPFunctionFiltering : Pass<"omp-function-filtering"> {
330330
"for the target device.";
331331
let constructor = "::fir::createOMPFunctionFilteringPass()";
332332
let dependentDialects = [
333-
"mlir::func::FuncDialect"
333+
"mlir::func::FuncDialect",
334+
"fir::FIROpsDialect"
334335
];
335336
}
336337

flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "flang/Optimizer/Dialect/FIRDialect.h"
1415
#include "flang/Optimizer/Transforms/Passes.h"
1516

1617
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -33,6 +34,8 @@ class OMPFunctionFilteringPass
3334
OMPFunctionFilteringPass() = default;
3435

3536
void runOnOperation() override {
37+
MLIRContext *context = &getContext();
38+
OpBuilder opBuilder(context);
3639
auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation());
3740
if (!op || !op.getIsTargetDevice())
3841
return;
@@ -46,8 +49,6 @@ class OMPFunctionFilteringPass
4649
->walk<WalkOrder::PreOrder>(
4750
[&](omp::TargetOp) { return WalkResult::interrupt(); })
4851
.wasInterrupted();
49-
if (hasTargetRegion)
50-
return;
5152

5253
omp::DeclareTargetDeviceType declareType =
5354
omp::DeclareTargetDeviceType::host;
@@ -56,18 +57,31 @@ class OMPFunctionFilteringPass
5657
if (declareTargetOp && declareTargetOp.isDeclareTarget())
5758
declareType = declareTargetOp.getDeclareTargetDeviceType();
5859

59-
// Filtering a function here means removing its body and explicitly
60-
// setting its omp.declare_target attribute, so that following
61-
// translation/lowering/transformation passes will skip processing its
62-
// contents, but preventing the calls to undefined symbols that could
63-
// result if the function were deleted. The second stage of function
64-
// filtering, at the MLIR to LLVM IR translation level, will remove these
65-
// from the IR thanks to the mismatch between the omp.declare_target
66-
// attribute and the target device.
60+
// Filtering a function here means deleting it if it doesn't contain a
61+
// target region. Else we explicitly set the omp.declare_target
62+
// attribute. The second stage of function filtering at the MLIR to LLVM
63+
// IR translation level will remove functions that contain the target
64+
// region from the generated llvm IR.
6765
if (declareType == omp::DeclareTargetDeviceType::host) {
68-
funcOp.eraseBody();
69-
funcOp.setVisibility(SymbolTable::Visibility::Private);
70-
if (declareTargetOp)
66+
SymbolTable::UseRange funcUses = *funcOp.getSymbolUses(op);
67+
for (SymbolTable::SymbolUse use : funcUses) {
68+
Operation *callOp = use.getUser();
69+
// If the callOp has users then replace them with Undef values.
70+
if (!callOp->use_empty()) {
71+
SmallVector<Value> undefResults;
72+
for (Value res : callOp->getResults()) {
73+
opBuilder.setInsertionPoint(callOp);
74+
undefResults.emplace_back(
75+
opBuilder.create<fir::UndefOp>(res.getLoc(), res.getType()));
76+
}
77+
callOp->replaceAllUsesWith(undefResults);
78+
}
79+
// Remove the callOp
80+
callOp->erase();
81+
}
82+
if (!hasTargetRegion)
83+
funcOp.erase();
84+
else if (declareTargetOp)
7185
declareTargetOp.setDeclareTarget(declareType,
7286
omp::DeclareTargetCaptureClause::to);
7387
}

flang/test/Lower/OpenMP/FIR/array-bounds.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ end subroutine read_write_section
3232
module assumed_array_routines
3333
contains
3434
!ALL-LABEL: func.func @_QMassumed_array_routinesPassumed_shape_array(
35-
!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
35+
!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"})
3636
!ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_shape_arrayEi"}
3737
!ALL: %[[C0:.*]] = arith.constant 1 : index
3838
!ALL: %[[C1:.*]] = arith.constant 0 : index
@@ -56,7 +56,7 @@ subroutine assumed_shape_array(arr_read_write)
5656
end subroutine assumed_shape_array
5757

5858
!ALL-LABEL: func.func @_QMassumed_array_routinesPassumed_size_array(
59-
!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
59+
!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"})
6060
!ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_size_arrayEi"}
6161
!ALL: %[[C0:.*]] = arith.constant 1 : index
6262
!ALL: %[[C1:.*]] = arith.constant 1 : index

flang/test/Lower/OpenMP/function-filtering.f90

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ end function device_fn
2121

2222
! MLIR-HOST: func.func @{{.*}}host_fn(
2323
! MLIR-HOST: return
24-
! MLIR-DEVICE: func.func private @{{.*}}host_fn(
25-
! MLIR-DEVICE-NOT: return
24+
! MLIR-DEVICE-NOT: func.func {{.*}}host_fn(
2625

2726
! LLVM-HOST: define {{.*}} @{{.*}}host_fn{{.*}}(
2827
! LLVM-DEVICE-NOT: {{.*}} @{{.*}}host_fn{{.*}}(
@@ -32,9 +31,8 @@ function host_fn() result(x)
3231
x = 10
3332
end function host_fn
3433

35-
! MLIR-HOST: func.func @{{.*}}target_subr(
36-
! MLIR-HOST: return
37-
! MLIR-DEVICE: return
34+
! MLIR-ALL: func.func @{{.*}}target_subr(
35+
! MLIR-ALL: return
3836

3937
! LLVM-HOST: define {{.*}} @{{.*}}target_subr{{.*}}(
4038
! LLVM-ALL: define {{.*}} @__omp_offloading_{{.*}}_{{.*}}_target_subr__{{.*}}(

flang/test/Transforms/omp-function-filtering.mlir

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@
44
// CHECK: return
55
// CHECK: func.func @nohost
66
// CHECK: return
7-
// CHECK: func.func private @host
8-
// CHECK-NOT: return
9-
// CHECK: func.func private @none
10-
// CHECK-NOT: return
7+
// CHECK-NOT: func.func {{.*}}}} @host
8+
// CHECK-NOT: func.func {{.*}}}} @none
119
// CHECK: func.func @nohost_target
1210
// CHECK: return
1311
// CHECK: func.func @host_target
1412
// CHECK: return
1513
// CHECK: func.func @none_target
1614
// CHECK: return
15+
// CHECK: func.func @host_target_call
16+
// CHECK-NOT: call @none_target
17+
// CHECK: %[[UNDEF:.*]] = fir.undefined i32
18+
// CHECK: return %[[UNDEF]] : i32
1719
module attributes {omp.is_target_device = true} {
1820
func.func @any() -> ()
1921
attributes {
@@ -55,9 +57,19 @@ module attributes {omp.is_target_device = true} {
5557
omp.target {}
5658
func.return
5759
}
58-
func.func @none_target() -> () {
60+
func.func @none_target() -> i32 {
5961
omp.target {}
60-
func.return
62+
%0 = arith.constant 25 : i32
63+
func.return %0 : i32
64+
}
65+
func.func @host_target_call() -> i32
66+
attributes {
67+
omp.declare_target =
68+
#omp.declaretarget<device_type = (host), capture_clause = (to)>
69+
} {
70+
omp.target {}
71+
%0 = call @none_target() : () -> i32
72+
func.return %0 : i32
6173
}
6274
}
6375

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2496,8 +2496,6 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
24962496
if (declareType == omp::DeclareTargetDeviceType::host) {
24972497
llvm::Function *llvmFunc =
24982498
moduleTranslation.lookupFunction(funcOp.getName());
2499-
llvmFunc->replaceAllUsesWith(
2500-
llvm::UndefValue::get(llvmFunc->getType()));
25012499
llvmFunc->dropAllReferences();
25022500
llvmFunc->eraseFromParent();
25032501
}

0 commit comments

Comments
 (0)