Skip to content

Commit e0ef194

Browse files
committed
Add workshare loop wrapper lowerings
Bufferize test. Add test for should use workshare lowering.
1 parent e56dbd6 commit e0ef194

File tree

4 files changed

+208
-4
lines changed

4 files changed

+208
-4
lines changed

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
2727
#include "flang/Optimizer/HLFIR/HLFIROps.h"
2828
#include "flang/Optimizer/HLFIR/Passes.h"
29+
#include "flang/Optimizer/OpenMP/Passes.h"
2930
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3031
#include "mlir/IR/Dominance.h"
3132
#include "mlir/IR/PatternMatch.h"
@@ -792,7 +793,8 @@ struct ElementalOpConversion
792793
// Generate a loop nest looping around the hlfir.elemental shape and clone
793794
// hlfir.elemental region inside the inner loop.
794795
hlfir::LoopNest loopNest =
795-
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
796+
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(),
797+
flangomp::shouldUseWorkshareLowering(elemental));
796798
auto insPt = builder.saveInsertionPoint();
797799
builder.setInsertionPointToStart(loopNest.body);
798800
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
2121
#include "flang/Optimizer/HLFIR/HLFIROps.h"
2222
#include "flang/Optimizer/HLFIR/Passes.h"
23+
#include "flang/Optimizer/OpenMP/Passes.h"
2324
#include "flang/Optimizer/Transforms/Utils.h"
2425
#include "mlir/Dialect/Func/IR/FuncOps.h"
2526
#include "mlir/IR/Dominance.h"
@@ -482,7 +483,8 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
482483
// Generate a loop nest looping around the hlfir.elemental shape and clone
483484
// hlfir.elemental region inside the inner loop
484485
hlfir::LoopNest loopNest =
485-
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
486+
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(),
487+
flangomp::shouldUseWorkshareLowering(elemental));
486488
builder.setInsertionPointToStart(loopNest.body);
487489
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
488490
loopNest.oneBasedIndices);
@@ -553,7 +555,8 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
553555
llvm::SmallVector<mlir::Value> extents =
554556
hlfir::getIndexExtents(loc, builder, shape);
555557
hlfir::LoopNest loopNest =
556-
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
558+
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
559+
flangomp::shouldUseWorkshareLowering(assign));
557560
builder.setInsertionPointToStart(loopNest.body);
558561
auto arrayElement =
559562
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
@@ -648,7 +651,8 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
648651
llvm::SmallVector<mlir::Value> extents =
649652
hlfir::getIndexExtents(loc, builder, shape);
650653
hlfir::LoopNest loopNest =
651-
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
654+
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
655+
flangomp::shouldUseWorkshareLowering(assign));
652656
builder.setInsertionPointToStart(loopNest.body);
653657
auto rhsArrayElement =
654658
hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// RUN: fir-opt --bufferize-hlfir %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @simple(
4+
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
5+
// CHECK: omp.parallel {
6+
// CHECK: omp.workshare {
7+
// CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
8+
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
9+
// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
10+
// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
11+
// CHECK: %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
12+
// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
13+
// CHECK: %[[VAL_7:.*]] = arith.constant true
14+
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
15+
// CHECK: omp.workshare.loop_wrapper {
16+
// CHECK: omp.loop_nest (%[[VAL_9:.*]]) : index = (%[[VAL_8]]) to (%[[VAL_1]]) inclusive step (%[[VAL_8]]) {
17+
// CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_9]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
18+
// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
19+
// CHECK: %[[VAL_12:.*]] = arith.subi %[[VAL_11]], %[[VAL_2]] : i32
20+
// CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
21+
// CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_13]] temporary_lhs : i32, !fir.ref<i32>
22+
// CHECK: omp.yield
23+
// CHECK: }
24+
// CHECK: omp.terminator
25+
// CHECK: }
26+
// CHECK: %[[VAL_14:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
27+
// CHECK: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_7]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
28+
// CHECK: %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_6]]#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
29+
// CHECK: hlfir.assign %[[VAL_6]]#0 to %[[VAL_4]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
30+
// CHECK: fir.freemem %[[VAL_6]]#0 : !fir.heap<!fir.array<42xi32>>
31+
// CHECK: omp.terminator
32+
// CHECK: }
33+
// CHECK: omp.terminator
34+
// CHECK: }
35+
// CHECK: return
36+
// CHECK: }
37+
// Test input for the expectations above: inside omp.workshare nested in
// omp.parallel, an unordered hlfir.elemental computes (array element - 1) for
// each of the 42 elements and the result is assigned back to the same array.
func.func @simple(%arg: !fir.ref<!fir.array<42xi32>>) {
38+
omp.parallel {
39+
omp.workshare {
40+
%c42 = arith.constant 42 : index
41+
%c1_i32 = arith.constant 1 : i32
42+
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
43+
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
44+
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
45+
^bb0(%i: index):
46+
%ref = hlfir.designate %array#0 (%i) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
47+
%val = fir.load %ref : !fir.ref<i32>
48+
%sub = arith.subi %val, %c1_i32 : i32
49+
hlfir.yield_element %sub : i32
50+
}
51+
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
52+
hlfir.destroy %elemental : !hlfir.expr<42xi32>
53+
omp.terminator
54+
}
55+
omp.terminator
56+
}
57+
return
58+
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// RUN: fir-opt --bufferize-hlfir %s | FileCheck %s
2+
3+
// Checks that we correctly identify when to use the lowering to
4+
// omp.workshare.loop_wrapper
5+
6+
// CHECK-LABEL: @should_parallelize_0
7+
// CHECK: omp.workshare.loop_wrapper
8+
// Positive case: the hlfir.elemental sits directly in an omp.workshare region
// that has no enclosing omp.parallel in this function. The expectation above
// requires an omp.workshare.loop_wrapper in the output.
func.func @should_parallelize_0(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
9+
omp.workshare {
10+
%c42 = arith.constant 42 : index
11+
%c1_i32 = arith.constant 1 : i32
12+
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
13+
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
14+
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
15+
^bb0(%i: index):
16+
hlfir.yield_element %c1_i32 : i32
17+
}
18+
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
19+
hlfir.destroy %elemental : !hlfir.expr<42xi32>
20+
omp.terminator
21+
}
22+
return
23+
}
24+
25+
// CHECK-LABEL: @should_parallelize_1
26+
// CHECK: omp.workshare.loop_wrapper
27+
// Positive case: the hlfir.elemental sits directly in an omp.workshare region
// which is itself nested in omp.parallel. The expectation above requires an
// omp.workshare.loop_wrapper in the output.
func.func @should_parallelize_1(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
28+
omp.parallel {
29+
omp.workshare {
30+
%c42 = arith.constant 42 : index
31+
%c1_i32 = arith.constant 1 : i32
32+
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
33+
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
34+
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
35+
^bb0(%i: index):
36+
hlfir.yield_element %c1_i32 : i32
37+
}
38+
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
39+
hlfir.destroy %elemental : !hlfir.expr<42xi32>
40+
omp.terminator
41+
}
42+
omp.terminator
43+
}
44+
return
45+
}
46+
47+
48+
// CHECK-LABEL: @should_not_parallelize_0
49+
// CHECK-NOT: omp.workshare.loop_wrapper
50+
// Negative case: the hlfir.elemental is wrapped in omp.single inside
// omp.workshare. The expectation above requires that no
// omp.workshare.loop_wrapper is produced.
func.func @should_not_parallelize_0(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
51+
omp.workshare {
52+
omp.single {
53+
%c42 = arith.constant 42 : index
54+
%c1_i32 = arith.constant 1 : i32
55+
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
56+
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
57+
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
58+
^bb0(%i: index):
59+
hlfir.yield_element %c1_i32 : i32
60+
}
61+
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
62+
hlfir.destroy %elemental : !hlfir.expr<42xi32>
63+
omp.terminator
64+
}
65+
omp.terminator
66+
}
67+
return
68+
}
69+
70+
// CHECK-LABEL: @should_not_parallelize_1
71+
// CHECK-NOT: omp.workshare.loop_wrapper
72+
// Negative case: the hlfir.elemental is wrapped in omp.critical inside
// omp.workshare. The expectation above requires that no
// omp.workshare.loop_wrapper is produced.
func.func @should_not_parallelize_1(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
73+
omp.workshare {
74+
omp.critical {
75+
%c42 = arith.constant 42 : index
76+
%c1_i32 = arith.constant 1 : i32
77+
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
78+
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
79+
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
80+
^bb0(%i: index):
81+
hlfir.yield_element %c1_i32 : i32
82+
}
83+
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
84+
hlfir.destroy %elemental : !hlfir.expr<42xi32>
85+
omp.terminator
86+
}
87+
omp.terminator
88+
}
89+
return
90+
}
91+
92+
// CHECK-LABEL: @should_not_parallelize_2
93+
// CHECK-NOT: omp.workshare.loop_wrapper
94+
// Negative case: the hlfir.elemental is inside an omp.parallel that is nested
// in omp.workshare, so its immediate parent region is the omp.parallel. The
// expectation above requires that no omp.workshare.loop_wrapper is produced.
func.func @should_not_parallelize_2(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
95+
omp.workshare {
96+
omp.parallel {
97+
%c42 = arith.constant 42 : index
98+
%c1_i32 = arith.constant 1 : i32
99+
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
100+
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
101+
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
102+
^bb0(%i: index):
103+
hlfir.yield_element %c1_i32 : i32
104+
}
105+
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
106+
hlfir.destroy %elemental : !hlfir.expr<42xi32>
107+
omp.terminator
108+
}
109+
omp.terminator
110+
}
111+
return
112+
}
113+
114+
// CHECK-LABEL: @should_not_parallelize_3
115+
// CHECK-NOT: omp.workshare.loop_wrapper
116+
// Negative case with deeper nesting: workshare > parallel > workshare >
// parallel. The hlfir.elemental's immediate parent region is the innermost
// omp.parallel. The expectation above requires that no
// omp.workshare.loop_wrapper is produced.
func.func @should_not_parallelize_3(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
117+
omp.workshare {
118+
omp.parallel {
119+
omp.workshare {
120+
omp.parallel {
121+
%c42 = arith.constant 42 : index
122+
%c1_i32 = arith.constant 1 : i32
123+
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
124+
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
125+
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
126+
^bb0(%i: index):
127+
hlfir.yield_element %c1_i32 : i32
128+
}
129+
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
130+
hlfir.destroy %elemental : !hlfir.expr<42xi32>
131+
omp.terminator
132+
}
133+
omp.terminator
134+
}
135+
omp.terminator
136+
}
137+
omp.terminator
138+
}
139+
return
140+
}

0 commit comments

Comments
 (0)