Skip to content

Commit b9482ce

Browse files
authored
[flang] Improve designate/elemental indices match in opt-bufferization. (#121371)
This pattern appears in `tonto`: `rys1%w = rys1%w * ...`, where component `w` is a pointer. Due to the computations transforming the elemental's one-based indices to the array indices, the indices match check did not pass in opt-bufferization. This patch recognizes this indices adjusting pattern, and returns the one-based indices for the designator.
1 parent 5137c20 commit b9482ce

File tree

2 files changed

+144
-1
lines changed

2 files changed

+144
-1
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ class ElementalAssignBufferization
8787
/// determines if the transformation can be applied to this elemental
8888
static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);
8989

90+
/// Returns the array indices for the given hlfir.designate.
91+
/// It recognizes the computations used to transform the one-based indices
92+
/// into the array's lb-based indices, and returns the one-based indices
93+
/// in these cases.
94+
static llvm::SmallVector<mlir::Value>
95+
getDesignatorIndices(hlfir::DesignateOp designate);
96+
9097
public:
9198
using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;
9299

@@ -430,6 +437,73 @@ bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
430437
return false;
431438
}
432439

440+
llvm::SmallVector<mlir::Value>
441+
ElementalAssignBufferization::getDesignatorIndices(
442+
hlfir::DesignateOp designate) {
443+
mlir::Value memref = designate.getMemref();
444+
445+
// If the object is a box, then the indices may be adjusted
446+
// according to the box's lower bound(s). Scan through
447+
// the computations to try to find the one-based indices.
448+
if (mlir::isa<fir::BaseBoxType>(memref.getType())) {
449+
// Look for the following pattern:
450+
// %13 = fir.load %12 : !fir.ref<!fir.box<...>
451+
// %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
452+
// %17 = arith.subi %14#0, %c1 : index
453+
// %18 = arith.addi %arg2, %17 : index
454+
// %19 = hlfir.designate %13 (%18) : (!fir.box<...>, index) -> ...
455+
//
456+
// %arg2 is a one-based index.
457+
458+
auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) {
459+
// Return true, if v and dim are such that:
460+
// %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ...
461+
// %17 = arith.subi %14#0, %c1 : index
462+
// %19 = hlfir.designate %13 (...) : (!fir.box<...>, index) -> ...
463+
if (auto subOp =
464+
mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) {
465+
auto cst = fir::getIntIfConstant(subOp.getRhs());
466+
if (!cst || *cst != 1)
467+
return false;
468+
if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
469+
subOp.getLhs().getDefiningOp())) {
470+
if (memref != dimsOp.getVal() ||
471+
dimsOp.getResult(0) != subOp.getLhs())
472+
return false;
473+
auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim());
474+
return dimsOpDim && dimsOpDim == dim;
475+
}
476+
}
477+
return false;
478+
};
479+
480+
llvm::SmallVector<mlir::Value> newIndices;
481+
for (auto index : llvm::enumerate(designate.getIndices())) {
482+
if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
483+
index.value().getDefiningOp())) {
484+
for (unsigned opNum = 0; opNum < 2; ++opNum)
485+
if (isNormalizedLb(addOp->getOperand(opNum), index.index())) {
486+
newIndices.push_back(addOp->getOperand((opNum + 1) % 2));
487+
break;
488+
}
489+
490+
// If new one-based index was not added, exit early.
491+
if (newIndices.size() <= index.index())
492+
break;
493+
}
494+
}
495+
496+
// If any of the indices is not adjusted to the array's lb,
497+
// then return the original designator indices.
498+
if (newIndices.size() != designate.getIndices().size())
499+
return designate.getIndices();
500+
501+
return newIndices;
502+
}
503+
504+
return designate.getIndices();
505+
}
506+
433507
std::optional<ElementalAssignBufferization::MatchInfo>
434508
ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
435509
mlir::Operation::user_range users = elemental->getUsers();
@@ -557,7 +631,7 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
557631
<< " at " << elemental.getLoc() << "\n");
558632
return std::nullopt;
559633
}
560-
auto indices = designate.getIndices();
634+
auto indices = getDesignatorIndices(designate);
561635
auto elementalIndices = elemental.getIndices();
562636
if (indices.size() == elementalIndices.size() &&
563637
std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// RUN: fir-opt --opt-bufferization %s | FileCheck %s
2+
3+
// Verify that the hlfir.assign of hlfir.elemental is optimized
4+
// into element-per-element assignment:
5+
// subroutine test1(p)
6+
// real, pointer :: p(:)
7+
// p = p + 1.0
8+
// end subroutine test1
9+
10+
func.func @_QPtest1(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "p"}) {
11+
%c1 = arith.constant 1 : index
12+
%c0 = arith.constant 0 : index
13+
%cst = arith.constant 1.000000e+00 : f32
14+
%0 = fir.dummy_scope : !fir.dscope
15+
%1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest1Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
16+
%2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
17+
%3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> (index, index, index)
18+
%4 = fir.shape %3#1 : (index) -> !fir.shape<1>
19+
%5 = hlfir.elemental %4 unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
20+
^bb0(%arg1: index):
21+
%6 = arith.subi %3#0, %c1 : index
22+
%7 = arith.addi %arg1, %6 : index
23+
%8 = hlfir.designate %2 (%7) : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
24+
%9 = fir.load %8 : !fir.ref<f32>
25+
%10 = arith.addf %9, %cst fastmath<contract> : f32
26+
hlfir.yield_element %10 : f32
27+
}
28+
hlfir.assign %5 to %2 : !hlfir.expr<?xf32>, !fir.box<!fir.ptr<!fir.array<?xf32>>>
29+
hlfir.destroy %5 : !hlfir.expr<?xf32>
30+
return
31+
}
32+
// CHECK-LABEL: func.func @_QPtest1(
33+
// CHECK-NOT: hlfir.assign
34+
// CHECK: hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
35+
// CHECK-NOT: hlfir.assign
36+
37+
// subroutine test2(p)
38+
// real, pointer :: p(:,:)
39+
// p = p + 1.0
40+
// end subroutine test2
41+
func.func @_QPtest2(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>> {fir.bindc_name = "p"}) {
42+
%c1 = arith.constant 1 : index
43+
%c0 = arith.constant 0 : index
44+
%cst = arith.constant 1.000000e+00 : f32
45+
%0 = fir.dummy_scope : !fir.dscope
46+
%1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest2Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>)
47+
%2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
48+
%3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
49+
%4:3 = fir.box_dims %2, %c1 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
50+
%5 = fir.shape %3#1, %4#1 : (index, index) -> !fir.shape<2>
51+
%6 = hlfir.elemental %5 unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
52+
^bb0(%arg1: index, %arg2: index):
53+
%7 = arith.subi %3#0, %c1 : index
54+
%8 = arith.addi %arg1, %7 : index
55+
%9 = arith.subi %4#0, %c1 : index
56+
%10 = arith.addi %arg2, %9 : index
57+
%11 = hlfir.designate %2 (%8, %10) : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index, index) -> !fir.ref<f32>
58+
%12 = fir.load %11 : !fir.ref<f32>
59+
%13 = arith.addf %12, %cst fastmath<contract> : f32
60+
hlfir.yield_element %13 : f32
61+
}
62+
hlfir.assign %6 to %2 : !hlfir.expr<?x?xf32>, !fir.box<!fir.ptr<!fir.array<?x?xf32>>>
63+
hlfir.destroy %6 : !hlfir.expr<?x?xf32>
64+
return
65+
}
66+
// CHECK-LABEL: func.func @_QPtest2(
67+
// CHECK-NOT: hlfir.assign
68+
// CHECK: hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
69+
// CHECK-NOT: hlfir.assign

0 commit comments

Comments
 (0)