Skip to content

Commit b19c1e9

Browse files
committed
[Flang] Add Maxloc to fir simplify intrinsics pass
This takes the code from D144103 and extends it to maxloc, so that the simplifyMinMaxlocReduction method handles both the minloc and maxloc intrinsics by switching the comparison condition and the limit/initial value.
1 parent 3546f4d commit b19c1e9

File tree

2 files changed

+292
-34
lines changed

2 files changed

+292
-34
lines changed

flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ class SimplifyIntrinsicsPass
9999
void simplifyLogicalDim1Reduction(fir::CallOp call,
100100
const fir::KindMapping &kindMap,
101101
GenReductionBodyTy genBodyFunc);
102-
void simplifyMinlocReduction(fir::CallOp call,
103-
const fir::KindMapping &kindMap);
102+
void simplifyMinMaxlocReduction(fir::CallOp call,
103+
const fir::KindMapping &kindMap, bool isMax);
104104
void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
105105
GenReductionBodyTy genBodyFunc,
106106
fir::FirOpBuilder &builder,
@@ -357,12 +357,11 @@ using MinlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
357357
fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
358358
mlir::Value, llvm::SmallVector<mlir::Value, Fortran::common::maxRank> &)>;
359359

360-
static void
361-
genMinlocReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
362-
InitValGeneratorTy initVal,
363-
MinlocBodyOpGeneratorTy genBody, unsigned rank,
364-
mlir::Type elementType, mlir::Location loc, bool hasMask,
365-
mlir::Type maskElemType, mlir::Value resultArr) {
360+
static void genMinMaxlocReductionLoop(
361+
fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
362+
InitValGeneratorTy initVal, MinlocBodyOpGeneratorTy genBody, unsigned rank,
363+
mlir::Type elementType, mlir::Location loc, bool hasMask,
364+
mlir::Type maskElemType, mlir::Value resultArr) {
366365

367366
mlir::IndexType idxTy = builder.getIndexType();
368367

@@ -751,20 +750,23 @@ static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
751750
{boxRefType, boxType, boxType}, {});
752751
}
753752

754-
static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
755-
mlir::func::FuncOp &funcOp, unsigned rank,
756-
int maskRank, mlir::Type elementType,
757-
mlir::Type maskElemType,
758-
mlir::Type resultElemTy) {
759-
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
760-
mlir::Type elementType) {
753+
static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
754+
mlir::func::FuncOp &funcOp, bool isMax,
755+
unsigned rank, int maskRank,
756+
mlir::Type elementType,
757+
mlir::Type maskElemType,
758+
mlir::Type resultElemTy) {
759+
auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
760+
mlir::Type elementType) {
761761
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
762762
const llvm::fltSemantics &sem = ty.getFloatSemantics();
763763
return builder.createRealConstant(
764-
loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/false));
764+
loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
765765
}
766766
unsigned bits = elementType.getIntOrFloatBitWidth();
767-
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
767+
int64_t maxInt = (isMax ? llvm::APInt::getSignedMinValue(bits)
768+
: llvm::APInt::getSignedMaxValue(bits))
769+
.getSExtValue();
768770
return builder.createIntegerConstant(loc, elementType, maxInt);
769771
};
770772

@@ -797,18 +799,24 @@ static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
797799
}
798800

799801
auto genBodyOp =
800-
[&rank, &resultArr](
801-
fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType,
802-
mlir::Value elem1, mlir::Value elem2,
803-
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
802+
[&rank, &resultArr,
803+
isMax](fir::FirOpBuilder builder, mlir::Location loc,
804+
mlir::Type elementType, mlir::Value elem1, mlir::Value elem2,
805+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
804806
-> mlir::Value {
805807
mlir::Value cmp;
806808
if (elementType.isa<mlir::FloatType>()) {
807809
cmp = builder.create<mlir::arith::CmpFOp>(
808-
loc, mlir::arith::CmpFPredicate::OLT, elem1, elem2);
810+
loc,
811+
isMax ? mlir::arith::CmpFPredicate::OGT
812+
: mlir::arith::CmpFPredicate::OLT,
813+
elem1, elem2);
809814
} else if (elementType.isa<mlir::IntegerType>()) {
810815
cmp = builder.create<mlir::arith::CmpIOp>(
811-
loc, mlir::arith::CmpIPredicate::slt, elem1, elem2);
816+
loc,
817+
isMax ? mlir::arith::CmpIPredicate::sgt
818+
: mlir::arith::CmpIPredicate::slt,
819+
elem1, elem2);
812820
} else {
813821
llvm_unreachable("unsupported type");
814822
}
@@ -875,9 +883,8 @@ static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
875883
// bit of a hack - maskRank is set to -1 for absent mask arg, so don't
876884
// generate high level mask or element by element mask.
877885
bool hasMask = maskRank > 0;
878-
879-
genMinlocReductionLoop(builder, funcOp, init, genBodyOp, rank, elementType,
880-
loc, hasMask, maskElemType, resultArr);
886+
genMinMaxlocReductionLoop(builder, funcOp, init, genBodyOp, rank, elementType,
887+
loc, hasMask, maskElemType, resultArr);
881888
}
882889

883890
/// Generate function type for the simplified version of RTNAME(DotProduct)
@@ -1150,8 +1157,8 @@ void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
11501157
intElementType);
11511158
}
11521159

1153-
void SimplifyIntrinsicsPass::simplifyMinlocReduction(
1154-
fir::CallOp call, const fir::KindMapping &kindMap) {
1160+
void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1161+
fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) {
11551162

11561163
mlir::Operation::operand_range args = call.getArgs();
11571164

@@ -1217,11 +1224,11 @@ void SimplifyIntrinsicsPass::simplifyMinlocReduction(
12171224
auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
12181225
return genRuntimeMinlocType(builder, rank);
12191226
};
1220-
auto bodyGenerator = [rank, maskRank, inputType, logicalElemType,
1221-
outType](fir::FirOpBuilder &builder,
1222-
mlir::func::FuncOp &funcOp) {
1223-
genRuntimeMinlocBody(builder, funcOp, rank, maskRank, inputType,
1224-
logicalElemType, outType);
1227+
auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
1228+
isMax](fir::FirOpBuilder &builder,
1229+
mlir::func::FuncOp &funcOp) {
1230+
genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
1231+
logicalElemType, outType);
12251232
};
12261233

12271234
mlir::func::FuncOp newFunc =
@@ -1367,7 +1374,11 @@ void SimplifyIntrinsicsPass::runOnOperation() {
13671374
return;
13681375
}
13691376
if (funcName.starts_with(RTNAME_STRING(Minloc))) {
1370-
simplifyMinlocReduction(call, kindMap);
1377+
simplifyMinMaxlocReduction(call, kindMap, false);
1378+
return;
1379+
}
1380+
if (funcName.starts_with(RTNAME_STRING(Maxloc))) {
1381+
simplifyMinMaxlocReduction(call, kindMap, true);
13711382
return;
13721383
}
13731384
}

0 commit comments

Comments
 (0)