Skip to content

[flang] Recognize fir.pack_array in LoopVersioning. #133191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flang/docs/ArrayRepacking.md
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,8 @@ There is an existing optimization pass (controlled via `-f[no-]version-loops-for

The array repacking is targeting better data cache utilization, and is not intended to enable more unit-strided vectorization for the assumed-shape arrays. At the same time, combining array repacking with the loop versioning may provide better performance for programs where the actual array arguments are non-contiguous, but then their repacked copies can be accessed using unit strides.

It is suggested that the LoopVersioning pass is run before the lowering of `fir.pack_array` and `fir.unpack_array` operations, and recognizes `fir.pack_array` on the path from `fir.declare` to the function entry block argument. The pass generates the dynamic contiguity checks, and multiversions the loops. In case the repacking actually happens, the most optimal versions of the loops are executed.

In cases where `fir.pack_array` is statically known to produce a copy that is contiguous in the innermost dimension, the loop versioning pass can skip the generation of the dynamic checks.

### Driver: user options
Expand Down
31 changes: 23 additions & 8 deletions flang/lib/Optimizer/Transforms/LoopVersioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,28 @@ getRankAndElementSize(const fir::KindMapping &kindMap,
return {0, 0};
}

/// if a value comes from a fir.declare, follow it to the original source,
/// otherwise return the value
static mlir::Value unwrapFirDeclare(mlir::Value val) {
// fir.declare is for source code variables. We don't have declares of
// declares
if (fir::DeclareOp declare = val.getDefiningOp<fir::DeclareOp>())
return declare.getMemref();
/// If a value comes from a fir.declare of fir.pack_array,
/// follow it to the original source, otherwise return the value.
static mlir::Value unwrapPassThroughOps(mlir::Value val) {
// Instead of unwrapping fir.declare, we may try to start
// the analysis in this pass from fir.declare's instead
// of the function entry block arguments. This way the loop
// versioning would work even after FIR inlining.
while (true) {
if (fir::DeclareOp declare = val.getDefiningOp<fir::DeclareOp>()) {
val = declare.getMemref();
continue;
}
// fir.pack_array might be met before fir.declare - this is how
// it is orifinally generated.
// It might also be met after fir.declare - after the optimization
// passes that sink fir.pack_array closer to the uses.
if (auto packArray = val.getDefiningOp<fir::PackArrayOp>()) {
val = packArray.getArray();
continue;
}
break;
}
return val;
}

Expand Down Expand Up @@ -242,7 +257,7 @@ static mlir::Value unwrapReboxOp(mlir::Value val) {
/// normalize a value (removing fir.declare and fir.rebox) so that we can
/// more conveniently spot values which came from function arguments
static mlir::Value normaliseVal(mlir::Value val) {
return unwrapFirDeclare(unwrapReboxOp(val));
return unwrapPassThroughOps(unwrapReboxOp(val));
}

/// some FIR operations accept a fir.shape, a fir.shift or a fir.shapeshift.
Expand Down
40 changes: 40 additions & 0 deletions flang/test/Transforms/loop-versioning-with-repack-arrays.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: fir-opt --loop-versioning %s | FileCheck %s

// Check that LoopVersioning kicks in when there is fir.pack_array
// in between fir.declare and the block argument.

module attributes {dlti.dl_spec = #dlti.dl_spec<>} {
func.func @_QPtest(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 1.000000e+00 : f32
%0 = fir.dummy_scope : !fir.dscope
%1 = fir.pack_array %arg0 stack whole : (!fir.box<!fir.array<?xf32>>) -> !fir.box<!fir.array<?xf32>>
%2 = fir.declare %1 dummy_scope %0 {uniq_name = "_QFtestEx"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> !fir.box<!fir.array<?xf32>>
%3 = fir.rebox %2 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<!fir.array<?xf32>>
%4:3 = fir.box_dims %3, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
fir.do_loop %arg1 = %c1 to %4#1 step %c1 unordered {
%5 = fir.array_coor %2 %arg1 : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
fir.store %cst to %5 : !fir.ref<f32>
}
fir.unpack_array %1 to %arg0 stack : !fir.box<!fir.array<?xf32>>
return
}
}
// CHECK-LABEL: func.func @_QPtest(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
// CHECK: %[[VAL_5:.*]] = fir.pack_array %[[VAL_0]] stack whole : (!fir.box<!fir.array<?xf32>>) -> !fir.box<!fir.array<?xf32>>
// CHECK: %[[VAL_6:.*]] = fir.declare %[[VAL_5]] dummy_scope %{{.*}} {uniq_name = "_QFtestEx"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> !fir.box<!fir.array<?xf32>>
// CHECK: %[[VAL_10:.*]]:3 = fir.box_dims %[[VAL_6]], %{{.*}} : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
// CHECK: %[[VAL_11:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_12:.*]] = arith.cmpi eq, %[[VAL_10]]#2, %[[VAL_11]] : index
// CHECK: fir.if %[[VAL_12]] {
// CHECK: fir.do_loop {{.*}} {
// CHECK: fir.coordinate_of {{.*}} : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK: }
// CHECK: } else {
// CHECK: fir.do_loop {{.*}} {
// CHECK: fir.array_coor {{.*}} : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK: }
// CHECK: }
// CHECK: fir.unpack_array %[[VAL_5]] to %[[VAL_0]] stack : !fir.box<!fir.array<?xf32>>