Skip to content

Commit 0ac8cb1

Browse files
authored
[flang] Recognize fir.pack_array in LoopVersioning. (#133191)
This change enables LoopVersioning when `fir.pack_array` is met in the def-use chain. It fixes a couple of huge performance regressions caused by enabling `-frepack-arrays`.
1 parent c1bf5e6 commit 0ac8cb1

File tree

3 files changed

+65
-8
lines changed

3 files changed

+65
-8
lines changed

flang/docs/ArrayRepacking.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,8 @@ There is an existing optimization pass (controlled via `-f[no-]version-loops-for
432432
433433
The array repacking is targeting better data cache utilization, and is not intended to enable more unit-strided vectorization for the assumed-shape arrays. At the same time, combining array repacking with the loop versioning may provide better performance for programs where the actual array arguments are non-contiguous, but then their repacked copies can be accessed using unit strides.
434434
435+
It is suggested that the LoopVersioning pass is run before the lowering of `fir.pack_array` and `fir.unpack_array` operations, and recognizes `fir.pack_array` on the path from `fir.declare` to the function entry block argument. The pass generates the dynamic contiguity checks, and multiversions the loops. In case the repacking actually happens, the optimal versions of the loops are executed.
436+
435437
In cases where `fir.pack_array` is statically known to produce a copy that is contiguous in the innermost dimension, the loop versioning pass can skip the generation of the dynamic checks.
436438
437439
### Driver: user options

flang/lib/Optimizer/Transforms/LoopVersioning.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,28 @@ getRankAndElementSize(const fir::KindMapping &kindMap,
184184
return {0, 0};
185185
}
186186

187-
/// if a value comes from a fir.declare, follow it to the original source,
188-
/// otherwise return the value
189-
static mlir::Value unwrapFirDeclare(mlir::Value val) {
190-
// fir.declare is for source code variables. We don't have declares of
191-
// declares
192-
if (fir::DeclareOp declare = val.getDefiningOp<fir::DeclareOp>())
193-
return declare.getMemref();
187+
/// If a value comes from a fir.declare or fir.pack_array,
188+
/// follow it to the original source, otherwise return the value.
189+
static mlir::Value unwrapPassThroughOps(mlir::Value val) {
190+
// Instead of unwrapping fir.declare, we may try to start
191+
// the analysis in this pass from fir.declare's instead
192+
// of the function entry block arguments. This way the loop
193+
// versioning would work even after FIR inlining.
194+
while (true) {
195+
if (fir::DeclareOp declare = val.getDefiningOp<fir::DeclareOp>()) {
196+
val = declare.getMemref();
197+
continue;
198+
}
199+
// fir.pack_array might be met before fir.declare - this is how
200+
// it is originally generated.
201+
// It might also be met after fir.declare - after the optimization
202+
// passes that sink fir.pack_array closer to the uses.
203+
if (auto packArray = val.getDefiningOp<fir::PackArrayOp>()) {
204+
val = packArray.getArray();
205+
continue;
206+
}
207+
break;
208+
}
194209
return val;
195210
}
196211

@@ -242,7 +257,7 @@ static mlir::Value unwrapReboxOp(mlir::Value val) {
242257
/// normalize a value (removing fir.declare, fir.pack_array and fir.rebox) so that we can
243258
/// more conveniently spot values which came from function arguments
244259
static mlir::Value normaliseVal(mlir::Value val) {
245-
return unwrapFirDeclare(unwrapReboxOp(val));
260+
return unwrapPassThroughOps(unwrapReboxOp(val));
246261
}
247262

248263
/// some FIR operations accept a fir.shape, a fir.shift or a fir.shapeshift.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: fir-opt --loop-versioning %s | FileCheck %s
2+
3+
// Check that LoopVersioning kicks in when there is fir.pack_array
4+
// in between fir.declare and the block argument.
5+
6+
module attributes {dlti.dl_spec = #dlti.dl_spec<>} {
7+
func.func @_QPtest(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
8+
%c1 = arith.constant 1 : index
9+
%c0 = arith.constant 0 : index
10+
%cst = arith.constant 1.000000e+00 : f32
11+
%0 = fir.dummy_scope : !fir.dscope
12+
%1 = fir.pack_array %arg0 stack whole : (!fir.box<!fir.array<?xf32>>) -> !fir.box<!fir.array<?xf32>>
13+
%2 = fir.declare %1 dummy_scope %0 {uniq_name = "_QFtestEx"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> !fir.box<!fir.array<?xf32>>
14+
%3 = fir.rebox %2 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<!fir.array<?xf32>>
15+
%4:3 = fir.box_dims %3, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
16+
fir.do_loop %arg1 = %c1 to %4#1 step %c1 unordered {
17+
%5 = fir.array_coor %2 %arg1 : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
18+
fir.store %cst to %5 : !fir.ref<f32>
19+
}
20+
fir.unpack_array %1 to %arg0 stack : !fir.box<!fir.array<?xf32>>
21+
return
22+
}
23+
}
24+
// CHECK-LABEL: func.func @_QPtest(
25+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
26+
// CHECK: %[[VAL_5:.*]] = fir.pack_array %[[VAL_0]] stack whole : (!fir.box<!fir.array<?xf32>>) -> !fir.box<!fir.array<?xf32>>
27+
// CHECK: %[[VAL_6:.*]] = fir.declare %[[VAL_5]] dummy_scope %{{.*}} {uniq_name = "_QFtestEx"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> !fir.box<!fir.array<?xf32>>
28+
// CHECK: %[[VAL_10:.*]]:3 = fir.box_dims %[[VAL_6]], %{{.*}} : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
29+
// CHECK: %[[VAL_11:.*]] = arith.constant 4 : index
30+
// CHECK: %[[VAL_12:.*]] = arith.cmpi eq, %[[VAL_10]]#2, %[[VAL_11]] : index
31+
// CHECK: fir.if %[[VAL_12]] {
32+
// CHECK: fir.do_loop {{.*}} {
33+
// CHECK: fir.coordinate_of {{.*}} : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
34+
// CHECK: }
35+
// CHECK: } else {
36+
// CHECK: fir.do_loop {{.*}} {
37+
// CHECK: fir.array_coor {{.*}} : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
38+
// CHECK: }
39+
// CHECK: }
40+
// CHECK: fir.unpack_array %[[VAL_5]] to %[[VAL_0]] stack : !fir.box<!fir.array<?xf32>>

0 commit comments

Comments
 (0)