Skip to content

[flang] Improve designate/elemental indices match in opt-bufferization. #121371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ class ElementalAssignBufferization
/// determines if the transformation can be applied to this elemental
static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);

/// Returns the array indices for the given hlfir.designate.
///
/// When the designated object is a box, each index may have been computed
/// as `one_based_index + (lb - 1)`, where `lb` is result #0 of a
/// fir.box_dims taken on the same box for the corresponding dimension.
/// If every index of the designator matches this pattern, the one-based
/// indices are returned instead of the lb-based ones. If the object is
/// not a box, or at least one index does not match the pattern, the
/// designator's original indices are returned unchanged.
static llvm::SmallVector<mlir::Value>
getDesignatorIndices(hlfir::DesignateOp designate);

public:
using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;

Expand Down Expand Up @@ -430,6 +437,73 @@ bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
return false;
}

llvm::SmallVector<mlir::Value>
ElementalAssignBufferization::getDesignatorIndices(
    hlfir::DesignateOp designate) {
  mlir::Value memref = designate.getMemref();

  // Only a box carries lower bounds that the indices may have been
  // adjusted by. For any other object the indices are used as is.
  if (!mlir::isa<fir::BaseBoxType>(memref.getType()))
    return designate.getIndices();

  // The object is a box: try to undo the lower bound adjustment of every
  // index by recognizing the following pattern:
  //   %13 = fir.load %12 : !fir.ref<!fir.box<...>
  //   %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
  //   %17 = arith.subi %14#0, %c1 : index
  //   %18 = arith.addi %arg2, %17 : index
  //   %19 = hlfir.designate %13 (%18) : (!fir.box<...>, index) -> ...
  //
  // Here %arg2 is the one-based index that should be returned.

  // Returns true if `v` is the `lb - 1` value of dimension `dim` of
  // the designated box, i.e.:
  //   %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ...
  //   %17 = arith.subi %14#0, %c1 : index
  // where %13 is the memref of the designator and `v` is %17.
  auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) -> bool {
    auto subOp =
        mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp());
    if (!subOp)
      return false;

    // The subtrahend must be the constant one.
    auto subtrahend = fir::getIntIfConstant(subOp.getRhs());
    if (!subtrahend || *subtrahend != 1)
      return false;

    // The minuend must be result #0 of a fir.box_dims on the same box.
    auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
        subOp.getLhs().getDefiningOp());
    if (!dimsOp)
      return false;
    if (memref != dimsOp.getVal() || dimsOp.getResult(0) != subOp.getLhs())
      return false;

    // The queried dimension must be a constant equal to `dim`.
    auto queriedDim = fir::getIntIfConstant(dimsOp.getDim());
    return queriedDim && *queriedDim == dim;
  };

  llvm::SmallVector<mlir::Value> oneBasedIndices;
  for (auto indexedValue : llvm::enumerate(designate.getIndices())) {
    // An index that is not produced by arith.addi cannot match the pattern,
    // so the whole designator keeps its original indices.
    auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
        indexedValue.value().getDefiningOp());
    if (!addOp)
      return designate.getIndices();

    // The addition is commutative: `lb - 1` may be either operand,
    // and the other operand is then the one-based index.
    mlir::Value firstOperand = addOp->getOperand(0);
    mlir::Value secondOperand = addOp->getOperand(1);
    if (isNormalizedLb(firstOperand, indexedValue.index()))
      oneBasedIndices.push_back(secondOperand);
    else if (isNormalizedLb(secondOperand, indexedValue.index()))
      oneBasedIndices.push_back(firstOperand);
    else
      // This index is not adjusted to the array's lb,
      // so return the original designator indices.
      return designate.getIndices();
  }

  // Every index has been mapped back to its one-based form.
  return oneBasedIndices;
}

std::optional<ElementalAssignBufferization::MatchInfo>
ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
mlir::Operation::user_range users = elemental->getUsers();
Expand Down Expand Up @@ -557,7 +631,7 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
<< " at " << elemental.getLoc() << "\n");
return std::nullopt;
}
auto indices = designate.getIndices();
auto indices = getDesignatorIndices(designate);
auto elementalIndices = elemental.getIndices();
if (indices.size() == elementalIndices.size() &&
std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
Expand Down
69 changes: 69 additions & 0 deletions flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// RUN: fir-opt --opt-bufferization %s | FileCheck %s

// Verify that the hlfir.assign of hlfir.elemental is optimized
// into element-per-element assignment:
// subroutine test1(p)
// real, pointer :: p(:)
// p = p + 1.0
// end subroutine test1

// Rank-1 case: the elemental reads %2 through an lb-adjusted index and its
// result is assigned back to the same box %2.
func.func @_QPtest1(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "p"}) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 1.000000e+00 : f32
%0 = fir.dummy_scope : !fir.dscope
%1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest1Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
// %2 is the box that is both designated inside the elemental and assigned to.
%2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
// Dimension 0 of the box: %3#0 feeds the index adjustment below,
// %3#1 is used to build the shape of the elemental.
%3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> (index, index, index)
%4 = fir.shape %3#1 : (index) -> !fir.shape<1>
%5 = hlfir.elemental %4 unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
^bb0(%arg1: index):
// %7 = %arg1 + (%3#0 - 1): the elemental index %arg1 adjusted by the
// box_dims result for dimension 0. This is the pattern the pass must
// see through to match the designator index with %arg1.
%6 = arith.subi %3#0, %c1 : index
%7 = arith.addi %arg1, %6 : index
%8 = hlfir.designate %2 (%7) : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
%9 = fir.load %8 : !fir.ref<f32>
%10 = arith.addf %9, %cst fastmath<contract> : f32
hlfir.yield_element %10 : f32
}
// Array assignment of the elemental result back to the same box.
hlfir.assign %5 to %2 : !hlfir.expr<?xf32>, !fir.box<!fir.ptr<!fir.array<?xf32>>>
hlfir.destroy %5 : !hlfir.expr<?xf32>
return
}
// CHECK-LABEL: func.func @_QPtest1(
// CHECK-NOT: hlfir.assign
// CHECK: hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
// CHECK-NOT: hlfir.assign

// subroutine test2(p)
// real, pointer :: p(:,:)
// p = p + 1.0
// end subroutine test2
// Rank-2 case: both designator indices are lb-adjusted, each using the
// fir.box_dims of its own dimension.
func.func @_QPtest2(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>> {fir.bindc_name = "p"}) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 1.000000e+00 : f32
%0 = fir.dummy_scope : !fir.dscope
%1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest2Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>)
// %2 is the box that is both designated inside the elemental and assigned to.
%2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
// %3 describes dimension 0 and %4 describes dimension 1 of the box.
// Note that %c1 serves both as the dimension number here and as the
// constant one subtracted in the index computations below.
%3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
%4:3 = fir.box_dims %2, %c1 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
%5 = fir.shape %3#1, %4#1 : (index, index) -> !fir.shape<2>
%6 = hlfir.elemental %5 unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
^bb0(%arg1: index, %arg2: index):
// %8 = %arg1 + (%3#0 - 1): first index, adjusted using dimension 0.
%7 = arith.subi %3#0, %c1 : index
%8 = arith.addi %arg1, %7 : index
// %10 = %arg2 + (%4#0 - 1): second index, adjusted using dimension 1.
%9 = arith.subi %4#0, %c1 : index
%10 = arith.addi %arg2, %9 : index
%11 = hlfir.designate %2 (%8, %10) : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index, index) -> !fir.ref<f32>
%12 = fir.load %11 : !fir.ref<f32>
%13 = arith.addf %12, %cst fastmath<contract> : f32
hlfir.yield_element %13 : f32
}
// Array assignment of the elemental result back to the same box.
hlfir.assign %6 to %2 : !hlfir.expr<?x?xf32>, !fir.box<!fir.ptr<!fir.array<?x?xf32>>>
hlfir.destroy %6 : !hlfir.expr<?x?xf32>
return
}
// CHECK-LABEL: func.func @_QPtest2(
// CHECK-NOT: hlfir.assign
// CHECK: hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
// CHECK-NOT: hlfir.assign
Loading