Skip to content

Commit a216115

Browse files
authored
[Flang] Add a HLFIR Minloc intrinsic (#74436)
The adds a hlfir minloc intrinsic, similar to the minval intrinsic already added, to help in the lowering of minloc. The idea is to later add maxloc too, and from there add a simplification for producing minloc with inlined elemental and hopefully less temporaries.
1 parent dd85e67 commit a216115

File tree

10 files changed

+1263
-61
lines changed

10 files changed

+1263
-61
lines changed

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,32 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
458458
let hasVerifier = 1;
459459
}
460460

461+
def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
462+
DeclareOpInterfaceMethods<ArithFastMathInterface>,
463+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
464+
let summary = "MINLOC transformational intrinsic";
465+
let description = [{
466+
Minlocs of an array.
467+
}];
468+
469+
let arguments = (ins
470+
AnyFortranArrayObject:$array,
471+
Optional<AnyIntegerType>:$dim,
472+
Optional<AnyFortranLogicalOrI1ArrayObject>:$mask,
473+
Optional<Type<AnyLogicalLike.predicate>>:$back,
474+
DefaultValuedAttr<Arith_FastMathAttr,
475+
"::mlir::arith::FastMathFlags::none">:$fastmath
476+
);
477+
478+
let results = (outs AnyFortranValue);
479+
480+
let assemblyFormat = [{
481+
$array (`dim` $dim^)? (`mask` $mask^)? (`back` $back^)? attr-dict `:` functional-type(operands, results)
482+
}];
483+
484+
let hasVerifier = 1;
485+
}
486+
461487
def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
462488
DeclareOpInterfaceMethods<ArithFastMathInterface>,
463489
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {

flang/lib/Lower/HlfirIntrinsics.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,19 @@ using HlfirMinvalLowering = HlfirReductionIntrinsic<hlfir::MinvalOp, true>;
9393
using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
9494
using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;
9595

96+
template <typename OP>
97+
class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic {
98+
public:
99+
using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
100+
101+
protected:
102+
mlir::Value
103+
lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
104+
const fir::IntrinsicArgumentLoweringRules *argLowering,
105+
mlir::Type stmtResultType) override;
106+
};
107+
using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MinlocOp>;
108+
96109
template <typename OP>
97110
class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic {
98111
public:
@@ -180,6 +193,31 @@ mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress(
180193
return boxOrAbsent;
181194
}
182195

196+
static mlir::Value loadOptionalValue(
197+
mlir::Location loc, fir::FirOpBuilder &builder,
198+
const std::optional<Fortran::lower::PreparedActualArgument> &arg,
199+
hlfir::Entity actual) {
200+
if (!arg->handleDynamicOptional())
201+
return hlfir::loadTrivialScalar(loc, builder, actual);
202+
203+
mlir::Value isPresent = arg->getIsPresent();
204+
mlir::Type eleType = hlfir::getFortranElementType(actual.getType());
205+
return builder
206+
.genIfOp(loc, {eleType}, isPresent,
207+
/*withElseRegion=*/true)
208+
.genThen([&]() {
209+
assert(actual.isScalar() && fir::isa_trivial(eleType) &&
210+
"must be a numerical or logical scalar");
211+
hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual);
212+
builder.create<fir::ResultOp>(loc, val);
213+
})
214+
.genElse([&]() {
215+
mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType);
216+
builder.create<fir::ResultOp>(loc, zero);
217+
})
218+
.getResults()[0];
219+
}
220+
183221
llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
184222
const Fortran::lower::PreparedActualArguments &loweredActuals,
185223
const fir::IntrinsicArgumentLoweringRules *argLowering) {
@@ -206,6 +244,9 @@ llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
206244
else if (!argRules.handleDynamicOptional &&
207245
argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired)
208246
valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
247+
else if (argRules.handleDynamicOptional &&
248+
argRules.lowerAs == fir::LowerIntrinsicArgAs::Value)
249+
valArg = loadOptionalValue(loc, builder, arg, actual);
209250
else if (argRules.handleDynamicOptional)
210251
TODO(loc, "hlfir transformational intrinsic dynamically optional "
211252
"argument without box lowering");
@@ -260,6 +301,27 @@ mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl(
260301
return op;
261302
}
262303

304+
template <typename OP>
305+
mlir::Value HlfirMinMaxLocIntrinsic<OP>::lowerImpl(
306+
const Fortran::lower::PreparedActualArguments &loweredActuals,
307+
const fir::IntrinsicArgumentLoweringRules *argLowering,
308+
mlir::Type stmtResultType) {
309+
auto operands = getOperandVector(loweredActuals, argLowering);
310+
mlir::Value array = operands[0];
311+
mlir::Value dim = operands[1];
312+
mlir::Value mask = operands[2];
313+
mlir::Value back = operands[4];
314+
// dim, mask and back can be NULL if these arguments are not given.
315+
if (dim)
316+
dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
317+
if (back)
318+
back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back});
319+
320+
mlir::Type resultTy = computeResultType(array, stmtResultType);
321+
322+
return createOp<OP>(resultTy, array, dim, mask, back);
323+
}
324+
263325
template <typename OP>
264326
mlir::Value HlfirProductIntrinsic<OP>::lowerImpl(
265327
const Fortran::lower::PreparedActualArguments &loweredActuals,
@@ -364,6 +426,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
364426
if (name == "minval")
365427
return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering,
366428
stmtResultType);
429+
if (name == "minloc")
430+
return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering,
431+
stmtResultType);
367432
if (mlir::isa<fir::CharacterType>(stmtResultType)) {
368433
if (name == "min")
369434
return HlfirCharExtremumLowering{builder, loc,

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 78 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -661,20 +661,14 @@ void hlfir::ConcatOp::getEffects(
661661

662662
template <typename NumericalReductionOp>
663663
static mlir::LogicalResult
664-
verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
665-
mlir::Operation *op = reductionOp->getOperation();
666-
667-
auto results = op->getResultTypes();
668-
assert(results.size() == 1);
669-
664+
verifyArrayAndMaskForReductionOp(NumericalReductionOp reductionOp) {
670665
mlir::Value array = reductionOp->getArray();
671666
mlir::Value dim = reductionOp->getDim();
672667
mlir::Value mask = reductionOp->getMask();
673668

674669
fir::SequenceType arrayTy =
675670
hlfir::getFortranElementOrSequenceType(array.getType())
676671
.cast<fir::SequenceType>();
677-
mlir::Type numTy = arrayTy.getEleTy();
678672
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
679673

680674
if (mask) {
@@ -701,6 +695,27 @@ verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
701695
}
702696
}
703697
}
698+
return mlir::success();
699+
}
700+
701+
template <typename NumericalReductionOp>
702+
static mlir::LogicalResult
703+
verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
704+
mlir::Operation *op = reductionOp->getOperation();
705+
auto results = op->getResultTypes();
706+
assert(results.size() == 1);
707+
708+
auto res = verifyArrayAndMaskForReductionOp(reductionOp);
709+
if (failed(res))
710+
return res;
711+
712+
mlir::Value array = reductionOp->getArray();
713+
mlir::Value dim = reductionOp->getDim();
714+
fir::SequenceType arrayTy =
715+
hlfir::getFortranElementOrSequenceType(array.getType())
716+
.cast<fir::SequenceType>();
717+
mlir::Type numTy = arrayTy.getEleTy();
718+
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
704719

705720
mlir::Type resultType = results[0];
706721
if (hlfir::isFortranScalarNumericalType(resultType)) {
@@ -757,45 +772,21 @@ template <typename CharacterReductionOp>
757772
static mlir::LogicalResult
758773
verifyCharacterReductionOp(CharacterReductionOp reductionOp) {
759774
mlir::Operation *op = reductionOp->getOperation();
760-
761775
auto results = op->getResultTypes();
762776
assert(results.size() == 1);
763777

778+
auto res = verifyArrayAndMaskForReductionOp(reductionOp);
779+
if (failed(res))
780+
return res;
781+
764782
mlir::Value array = reductionOp->getArray();
765783
mlir::Value dim = reductionOp->getDim();
766-
mlir::Value mask = reductionOp->getMask();
767-
768784
fir::SequenceType arrayTy =
769785
hlfir::getFortranElementOrSequenceType(array.getType())
770786
.cast<fir::SequenceType>();
771787
mlir::Type numTy = arrayTy.getEleTy();
772788
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
773789

774-
if (mask) {
775-
fir::SequenceType maskSeq =
776-
hlfir::getFortranElementOrSequenceType(mask.getType())
777-
.dyn_cast<fir::SequenceType>();
778-
llvm::ArrayRef<int64_t> maskShape;
779-
780-
if (maskSeq)
781-
maskShape = maskSeq.getShape();
782-
783-
if (!maskShape.empty()) {
784-
if (maskShape.size() != arrayShape.size())
785-
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
786-
static_assert(fir::SequenceType::getUnknownExtent() ==
787-
hlfir::ExprType::getUnknownExtent());
788-
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
789-
for (std::size_t i = 0; i < arrayShape.size(); ++i) {
790-
int64_t arrayExtent = arrayShape[i];
791-
int64_t maskExtent = maskShape[i];
792-
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
793-
(maskExtent != unknownExtent))
794-
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
795-
}
796-
}
797-
}
798-
799790
auto resultExpr = results[0].cast<hlfir::ExprType>();
800791
mlir::Type resultType = resultExpr.getEleTy();
801792
assert(mlir::isa<fir::CharacterType>(resultType) &&
@@ -870,6 +861,58 @@ void hlfir::MinvalOp::getEffects(
870861
getIntrinsicEffects(getOperation(), effects);
871862
}
872863

864+
//===----------------------------------------------------------------------===//
865+
// MinlocOp
866+
//===----------------------------------------------------------------------===//
867+
868+
mlir::LogicalResult hlfir::MinlocOp::verify() {
869+
mlir::Operation *op = getOperation();
870+
auto results = op->getResultTypes();
871+
assert(results.size() == 1);
872+
873+
auto res = verifyArrayAndMaskForReductionOp(this);
874+
if (failed(res))
875+
return res;
876+
877+
mlir::Value array = getArray();
878+
mlir::Value dim = getDim();
879+
fir::SequenceType arrayTy =
880+
hlfir::getFortranElementOrSequenceType(array.getType())
881+
.cast<fir::SequenceType>();
882+
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
883+
884+
mlir::Type resultType = results[0];
885+
if (dim && arrayShape.size() == 1) {
886+
if (!fir::isa_integer(resultType))
887+
return emitOpError("result must be scalar integer");
888+
} else if (auto resultExpr =
889+
mlir::dyn_cast_or_null<hlfir::ExprType>(resultType)) {
890+
if (!resultExpr.isArray())
891+
return emitOpError("result must be an array");
892+
893+
if (!fir::isa_integer(resultExpr.getEleTy()))
894+
return emitOpError("result must have integer elements");
895+
896+
llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
897+
// With dim the result has rank n-1
898+
if (dim && resultShape.size() != (arrayShape.size() - 1))
899+
return emitOpError("result rank must be one less than ARRAY");
900+
// With dim the result has rank n
901+
if (!dim && resultShape.size() != 1)
902+
return emitOpError("result rank must be 1");
903+
} else {
904+
return emitOpError("result must be of numerical expr type");
905+
}
906+
return mlir::success();
907+
}
908+
909+
void hlfir::MinlocOp::getEffects(
910+
llvm::SmallVectorImpl<
911+
mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
912+
&effects) {
913+
getIntrinsicEffects(getOperation(), effects);
914+
}
915+
873916
//===----------------------------------------------------------------------===//
874917
// SetLengthOp
875918
//===----------------------------------------------------------------------===//

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,23 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
201201
return lowerArguments(operation, inArgs, rewriter, argLowering);
202202
};
203203

204+
auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
205+
mlir::PatternRewriter &rewriter, std::string opName,
206+
fir::FirOpBuilder builder) const {
207+
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
208+
inArgs.push_back({operation.getArray(), operation.getArray().getType()});
209+
inArgs.push_back({operation.getDim(), i32});
210+
inArgs.push_back({operation.getMask(), logicalType});
211+
mlir::Type T = hlfir::getFortranElementType(operation.getType());
212+
unsigned width = T.cast<mlir::IntegerType>().getWidth();
213+
mlir::Value kind =
214+
builder.createIntegerConstant(operation->getLoc(), i32, width / 8);
215+
inArgs.push_back({kind, i32});
216+
inArgs.push_back({operation.getBack(), i32});
217+
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
218+
return lowerArguments(operation, inArgs, rewriter, argLowering);
219+
};
220+
204221
auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
205222
mlir::PatternRewriter &rewriter,
206223
std::string opName) const {
@@ -224,6 +241,8 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
224241
opName = "maxval";
225242
} else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) {
226243
opName = "minval";
244+
} else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
245+
opName = "minloc";
227246
} else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
228247
opName = "any";
229248
} else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
@@ -246,6 +265,9 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
246265
std::is_same_v<OP, hlfir::MaxvalOp> ||
247266
std::is_same_v<OP, hlfir::MinvalOp>) {
248267
args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
268+
} else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
269+
args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
270+
builder);
249271
} else {
250272
args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
251273
}
@@ -269,6 +291,8 @@ using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;
269291

270292
using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>;
271293

294+
using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>;
295+
272296
using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;
273297

274298
using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
@@ -441,20 +465,20 @@ class LowerHLFIRIntrinsics
441465
mlir::ModuleOp module = this->getOperation();
442466
mlir::MLIRContext *context = &getContext();
443467
mlir::RewritePatternSet patterns(context);
444-
patterns
445-
.insert<MatmulOpConversion, MatmulTransposeOpConversion,
446-
AllOpConversion, AnyOpConversion, SumOpConversion,
447-
ProductOpConversion, TransposeOpConversion, CountOpConversion,
448-
DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion>(
449-
context);
468+
patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
469+
AllOpConversion, AnyOpConversion, SumOpConversion,
470+
ProductOpConversion, TransposeOpConversion,
471+
CountOpConversion, DotProductOpConversion,
472+
MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion>(
473+
context);
450474
mlir::ConversionTarget target(*context);
451475
target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
452476
mlir::func::FuncDialect, fir::FIROpsDialect,
453477
hlfir::hlfirDialect>();
454478
target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
455479
hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp,
456480
hlfir::AllOp, hlfir::DotProductOp, hlfir::CountOp,
457-
hlfir::MaxvalOp, hlfir::MinvalOp>();
481+
hlfir::MaxvalOp, hlfir::MinvalOp, hlfir::MinlocOp>();
458482
target.markUnknownOpDynamicallyLegal(
459483
[](mlir::Operation *) { return true; });
460484
if (mlir::failed(

0 commit comments

Comments
 (0)