Skip to content

[Flang] Add Maxloc to fir simplify intrinsics pass #75463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 47 additions & 36 deletions flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ class SimplifyIntrinsicsPass
void simplifyLogicalDim1Reduction(fir::CallOp call,
const fir::KindMapping &kindMap,
GenReductionBodyTy genBodyFunc);
void simplifyMinlocReduction(fir::CallOp call,
const fir::KindMapping &kindMap);
void simplifyMinMaxlocReduction(fir::CallOp call,
const fir::KindMapping &kindMap, bool isMax);
void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
GenReductionBodyTy genBodyFunc,
fir::FirOpBuilder &builder,
Expand Down Expand Up @@ -353,16 +353,15 @@ genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
// Return the reduction value from the function.
builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
}
using MinlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
using MinMaxlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
mlir::Value, llvm::SmallVector<mlir::Value, Fortran::common::maxRank> &)>;

static void
genMinlocReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
InitValGeneratorTy initVal,
MinlocBodyOpGeneratorTy genBody, unsigned rank,
mlir::Type elementType, mlir::Location loc, bool hasMask,
mlir::Type maskElemType, mlir::Value resultArr) {
static void genMinMaxlocReductionLoop(
fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
InitValGeneratorTy initVal, MinMaxlocBodyOpGeneratorTy genBody,
unsigned rank, mlir::Type elementType, mlir::Location loc, bool hasMask,
mlir::Type maskElemType, mlir::Value resultArr) {

mlir::IndexType idxTy = builder.getIndexType();

Expand Down Expand Up @@ -751,21 +750,24 @@ static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
{boxRefType, boxType, boxType}, {});
}

static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp, unsigned rank,
int maskRank, mlir::Type elementType,
mlir::Type maskElemType,
mlir::Type resultElemTy) {
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp, bool isMax,
unsigned rank, int maskRank,
mlir::Type elementType,
mlir::Type maskElemType,
mlir::Type resultElemTy) {
auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/false));
loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
}
unsigned bits = elementType.getIntOrFloatBitWidth();
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
return builder.createIntegerConstant(loc, elementType, maxInt);
int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
: llvm::APInt::getSignedMaxValue(bits))
.getSExtValue();
return builder.createIntegerConstant(loc, elementType, initValue);
};

mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
Expand Down Expand Up @@ -797,18 +799,24 @@ static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
}

auto genBodyOp =
[&rank, &resultArr](
fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType,
mlir::Value elem1, mlir::Value elem2,
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
[&rank, &resultArr,
isMax](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType, mlir::Value elem1, mlir::Value elem2,
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
-> mlir::Value {
mlir::Value cmp;
if (elementType.isa<mlir::FloatType>()) {
cmp = builder.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OLT, elem1, elem2);
loc,
isMax ? mlir::arith::CmpFPredicate::OGT
: mlir::arith::CmpFPredicate::OLT,
elem1, elem2);
} else if (elementType.isa<mlir::IntegerType>()) {
cmp = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::slt, elem1, elem2);
loc,
isMax ? mlir::arith::CmpIPredicate::sgt
: mlir::arith::CmpIPredicate::slt,
elem1, elem2);
} else {
llvm_unreachable("unsupported type");
}
Expand Down Expand Up @@ -875,9 +883,8 @@ static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
// bit of a hack - maskRank is set to -1 for absent mask arg, so don't
// generate high level mask or element by element mask.
bool hasMask = maskRank > 0;

genMinlocReductionLoop(builder, funcOp, init, genBodyOp, rank, elementType,
loc, hasMask, maskElemType, resultArr);
genMinMaxlocReductionLoop(builder, funcOp, init, genBodyOp, rank, elementType,
loc, hasMask, maskElemType, resultArr);
}

/// Generate function type for the simplified version of RTNAME(DotProduct)
Expand Down Expand Up @@ -1150,8 +1157,8 @@ void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
intElementType);
}

void SimplifyIntrinsicsPass::simplifyMinlocReduction(
fir::CallOp call, const fir::KindMapping &kindMap) {
void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) {

mlir::Operation::operand_range args = call.getArgs();

Expand Down Expand Up @@ -1217,11 +1224,11 @@ void SimplifyIntrinsicsPass::simplifyMinlocReduction(
auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
return genRuntimeMinlocType(builder, rank);
};
auto bodyGenerator = [rank, maskRank, inputType, logicalElemType,
outType](fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp) {
genRuntimeMinlocBody(builder, funcOp, rank, maskRank, inputType,
logicalElemType, outType);
auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
isMax](fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp) {
genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
logicalElemType, outType);
};

mlir::func::FuncOp newFunc =
Expand Down Expand Up @@ -1367,7 +1374,11 @@ void SimplifyIntrinsicsPass::runOnOperation() {
return;
}
if (funcName.starts_with(RTNAME_STRING(Minloc))) {
simplifyMinlocReduction(call, kindMap);
simplifyMinMaxlocReduction(call, kindMap, false);
return;
}
if (funcName.starts_with(RTNAME_STRING(Maxloc))) {
simplifyMinMaxlocReduction(call, kindMap, true);
return;
}
}
Expand Down
Loading