Skip to content

[Flang] Maxloc elemental intrinsic lowering. #79469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 41 additions & 29 deletions flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -812,54 +812,59 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
// inlined elemental.
// %e = hlfir.elemental %shape ({ ... })
// %m = hlfir.minloc %array mask %e
class MinMaxlocElementalConversion
: public mlir::OpRewritePattern<hlfir::MinlocOp> {
template <typename Op>
class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
public:
using mlir::OpRewritePattern<hlfir::MinlocOp>::OpRewritePattern;
using mlir::OpRewritePattern<Op>::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(hlfir::MinlocOp minloc,
mlir::PatternRewriter &rewriter) const override {
if (!minloc.getMask() || minloc.getDim() || minloc.getBack())
return rewriter.notifyMatchFailure(minloc, "Did not find valid minloc");
matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
return rewriter.notifyMatchFailure(mloc,
"Did not find valid minloc/maxloc");

auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
constexpr bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;

auto elemental =
mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
return rewriter.notifyMatchFailure(minloc, "Did not find elemental");
return rewriter.notifyMatchFailure(mloc, "Did not find elemental");

mlir::Value array = minloc.getArray();
mlir::Value array = mloc.getArray();

unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType()).getShape()[0];
unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
mlir::Type arrayType = array.getType();
if (!arrayType.isa<fir::BoxType>())
return rewriter.notifyMatchFailure(
minloc, "Currently requires a boxed type input");
mloc, "Currently requires a boxed type input");
mlir::Type elementType = hlfir::getFortranElementType(arrayType);
if (!fir::isa_trivial(elementType))
return rewriter.notifyMatchFailure(
minloc, "Character arrays are currently not handled");
mloc, "Character arrays are currently not handled");

mlir::Location loc = minloc.getLoc();
fir::FirOpBuilder builder{rewriter, minloc.getOperation()};
mlir::Location loc = mloc.getLoc();
fir::FirOpBuilder builder{rewriter, mloc.getOperation()};
mlir::Value resultArr = builder.createTemporary(
loc, fir::SequenceType::get(
rank, hlfir::getFortranElementType(minloc.getType())));
rank, hlfir::getFortranElementType(mloc.getType())));

auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
loc, elementType,
llvm::APFloat::getLargest(sem, /*Negative=*/false));
llvm::APFloat::getLargest(sem, /*Negative=*/!isMax));
}
unsigned bits = elementType.getIntOrFloatBitWidth();
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
return builder.createIntegerConstant(loc, elementType, maxInt);
int64_t limitInt =
isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
: llvm::APInt::getSignedMaxValue(bits).getSExtValue();
return builder.createIntegerConstant(loc, elementType, limitInt);
};

auto genBodyOp =
[&rank, &resultArr, &elemental](
[&rank, &resultArr, &elemental, isMax](
fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
mlir::Value reduction,
Expand Down Expand Up @@ -899,10 +904,16 @@ class MinMaxlocElementalConversion
mlir::Value cmp;
if (elementType.isa<mlir::FloatType>()) {
cmp = builder.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OLT, elem, reduction);
loc,
isMax ? mlir::arith::CmpFPredicate::OGT
: mlir::arith::CmpFPredicate::OLT,
elem, reduction);
} else if (elementType.isa<mlir::IntegerType>()) {
cmp = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::slt, elem, reduction);
loc,
isMax ? mlir::arith::CmpIPredicate::sgt
: mlir::arith::CmpIPredicate::slt,
elem, reduction);
} else {
llvm_unreachable("unsupported type");
}
Expand Down Expand Up @@ -975,15 +986,15 @@ class MinMaxlocElementalConversion
// AsExpr for the temporary resultArr.
llvm::SmallVector<hlfir::DestroyOp> destroys;
llvm::SmallVector<hlfir::AssignOp> assigns;
for (auto user : minloc->getUsers()) {
for (auto user : mloc->getUsers()) {
if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
destroys.push_back(destroy);
else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
assigns.push_back(assign);
}

// Check if the minloc was the only user of the elemental (apart from a
// destroy), and remove it if so.
// Check if the minloc/maxloc was the only user of the elemental (apart from
// a destroy), and remove it if so.
mlir::Operation::user_range elemUsers = elemental->getUsers();
hlfir::DestroyOp elemDestroy;
if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
Expand All @@ -996,7 +1007,7 @@ class MinMaxlocElementalConversion
rewriter.eraseOp(d);
for (auto a : assigns)
a.setOperand(0, resultArr);
rewriter.replaceOp(minloc, asExpr);
rewriter.replaceOp(mloc, asExpr);
if (elemDestroy) {
rewriter.eraseOp(elemDestroy);
rewriter.eraseOp(elemental);
Expand Down Expand Up @@ -1030,7 +1041,8 @@ class OptimizedBufferizationPass
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
patterns.insert<MinMaxlocElementalConversion>(context);
patterns.insert<MinMaxlocElementalConversion<hlfir::MinlocOp>>(context);
patterns.insert<MinMaxlocElementalConversion<hlfir::MaxlocOp>>(context);

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
func, std::move(patterns), config))) {
Expand Down
140 changes: 140 additions & 0 deletions flang/test/HLFIR/maxloc-elemental.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// RUN: fir-opt %s -opt-bufferization | FileCheck %s

func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
%c0 = arith.constant 0 : index
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
%1:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
%2:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%3 = fir.load %2#0 : !fir.ref<i32>
%4:3 = fir.box_dims %0#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
%5 = fir.shape %4#1 : (index) -> !fir.shape<1>
%6 = hlfir.elemental %5 unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
^bb0(%arg3: index):
%8 = hlfir.designate %0#0 (%arg3) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
%9 = fir.load %8 : !fir.ref<i32>
%10 = arith.cmpi sge, %9, %3 : i32
%11 = fir.convert %10 : (i1) -> !fir.logical<4>
hlfir.yield_element %11 : !fir.logical<4>
}
%7 = hlfir.maxloc %0#0 mask %6 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
hlfir.assign %7 to %1#0 : !hlfir.expr<1xi32>, !fir.box<!fir.array<?xi32>>
hlfir.destroy %7 : !hlfir.expr<1xi32>
hlfir.destroy %6 : !hlfir.expr<?x!fir.logical<4>>
return
}
// CHECK-LABEL: func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
// CHECK-NEXT: %c-2147483648_i32 = arith.constant -2147483648 : i32
// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %[[V0:.*]] = fir.alloca i32
// CHECK-NEXT: %[[RES:.*]] = fir.alloca !fir.array<1xi32>
// CHECK-NEXT: %[[V1:.*]]:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
// CHECK-NEXT: %[[V2:.*]]:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
// CHECK-NEXT: %[[V3:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
// CHECK-NEXT: %[[V4:.*]] = fir.load %[[V3]]#0 : !fir.ref<i32>
// CHECK-NEXT: %[[V8:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: fir.store %c0_i32 to %[[V8]] : !fir.ref<i32>
// CHECK-NEXT: fir.store %c0_i32 to %[[V0]] : !fir.ref<i32>
// CHECK-NEXT: %[[V9:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK-NEXT: %[[V10:.*]] = arith.subi %[[V9]]#1, %c1 : index
// CHECK-NEXT: %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10]] step %c1 iter_args(%arg4 = %c-2147483648_i32) -> (i32) {
// CHECK-NEXT: %[[V14:.*]] = arith.addi %arg3, %c1 : index
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V1]]#0 (%[[V14]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<i32>
// CHECK-NEXT: %[[V17:.*]] = arith.cmpi sge, %[[V16]], %[[V4]] : i32
// CHECK-NEXT: %[[V18:.*]] = fir.if %[[V17]] -> (i32) {
// CHECK-NEXT: fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
// CHECK-NEXT: %[[DIMS:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
// CHECK-NEXT: %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<i32>
// CHECK-NEXT: %[[V21:.*]] = arith.cmpi sgt, %[[V20]], %arg4 : i32
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[V21]] -> (i32) {
// CHECK-NEXT: %[[V23:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
// CHECK-NEXT: fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
// CHECK-NEXT: fir.result %[[V20]] : i32
// CHECK-NEXT: } else {
// CHECK-NEXT: fir.result %arg4 : i32
// CHECK-NEXT: }
// CHECK-NEXT: fir.result %[[V22]] : i32
// CHECK-NEXT: } else {
// CHECK-NEXT: fir.result %arg4 : i32
// CHECK-NEXT: }
// CHECK-NEXT: fir.result %[[V18]] : i32
// CHECK-NEXT: }
// CHECK-NEXT: %[[V12:.*]] = fir.load %[[V0]] : !fir.ref<i32>
// CHECK-NEXT: %[[V13:.*]] = arith.cmpi eq, %[[V12]], %c1_i32 : i32
// CHECK-NEXT: fir.if %[[V13]] {
// CHECK-NEXT: %[[V14:.*]] = arith.cmpi eq, %[[V11]], %c-2147483648_i32 : i32
// CHECK-NEXT: fir.if %[[V14]] {
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: fir.store %c1_i32 to %[[V15]] : !fir.ref<i32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: %[[BD:.*]]:3 = fir.box_dims %[[V2]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK-NEXT: fir.do_loop %arg3 = %c1 to %[[BD]]#1 step %c1 unordered {
// CHECK-NEXT: %[[V13:.*]] = hlfir.designate %[[RES]] (%arg3) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: %[[V14:.*]] = fir.load %[[V13]] : !fir.ref<i32>
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V2]]#0 (%arg3) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: hlfir.assign %[[V14]] to %[[V15]] : i32, !fir.ref<i32>
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }



func.func @_QPtest_float(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<f32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
%c0 = arith.constant 0 : index
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xf32>>) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
%1:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
%2:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
%3 = fir.load %2#0 : !fir.ref<f32>
%4:3 = fir.box_dims %0#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
%5 = fir.shape %4#1 : (index) -> !fir.shape<1>
%6 = hlfir.elemental %5 unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
^bb0(%arg3: index):
%8 = hlfir.designate %0#0 (%arg3) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
%9 = fir.load %8 : !fir.ref<f32>
%10 = arith.cmpf oge, %9, %3 : f32
%11 = fir.convert %10 : (i1) -> !fir.logical<4>
hlfir.yield_element %11 : !fir.logical<4>
}
%7 = hlfir.maxloc %0#0 mask %6 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xf32>>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
hlfir.assign %7 to %1#0 : !hlfir.expr<1xi32>, !fir.box<!fir.array<?xi32>>
hlfir.destroy %7 : !hlfir.expr<1xi32>
hlfir.destroy %6 : !hlfir.expr<?x!fir.logical<4>>
return
}
// CHECK-LABEL: _QPtest_float
// CHECK: %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10:.*]] step %c1 iter_args(%arg4 = %cst) -> (f32) {
// CHECK-NEXT: %[[V14:.*]] = arith.addi %arg3, %c1 : index
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V1:.*]]#0 (%[[V14]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK-NEXT: %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<f32>
// CHECK-NEXT: %[[V17:.*]] = arith.cmpf oge, %[[V16]], %[[V4:.*]] : f32
// CHECK-NEXT: %[[V18:.*]] = fir.if %[[V17]] -> (f32) {
// CHECK-NEXT: fir.store %c1_i32 to %[[V0:.*]] : !fir.ref<i32>
// CHECK-NEXT: %[[DIMS:.*]]:3 = fir.box_dims %2#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
// CHECK-NEXT: %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK-NEXT: %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<f32>
// CHECK-NEXT: %[[V21:.*]] = arith.cmpf ogt, %[[V20]], %arg4 fastmath<contract> : f32
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[V21]] -> (f32) {
// CHECK-NEXT: %[[V23:.*]] = hlfir.designate %{{.}} (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
// CHECK-NEXT: fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
// CHECK-NEXT: fir.result %[[V20]] : f32
// CHECK-NEXT: } else {
// CHECK-NEXT: fir.result %arg4 : f32
// CHECK-NEXT: }
// CHECK-NEXT: fir.result %[[V22]] : f32
// CHECK-NEXT: } else {
// CHECK-NEXT: fir.result %arg4 : f32
// CHECK-NEXT: }
// CHECK-NEXT: fir.result %[[V18]] : f32
// CHECK-NEXT: }