
[Flang] Add a HLFIR Minloc intrinsic #74436


Merged: 2 commits, Dec 12, 2023
26 changes: 26 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -458,6 +458,32 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
let hasVerifier = 1;
}

def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "MINLOC transformational intrinsic";
let description = [{
Minloc of an array.
}];

let arguments = (ins
AnyFortranArrayObject:$array,
Optional<AnyIntegerType>:$dim,
Optional<AnyFortranLogicalOrI1ArrayObject>:$mask,
Optional<Type<AnyLogicalLike.predicate>>:$back,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath
);

let results = (outs AnyFortranValue);

let assemblyFormat = [{
$array (`dim` $dim^)? (`mask` $mask^)? (`back` $back^)? attr-dict `:` functional-type(operands, results)
}];

let hasVerifier = 1;
}

def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
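A minimal sketch of the new operation's textual form, following the assembly format declared above; the operand and result types below are assumptions for illustration, not taken from this patch's tests:

%res = hlfir.minloc %array mask %mask : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<1xi32>

The optional dim and back operands follow the same pattern, e.g. hlfir.minloc %array dim %dim back %back.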
65 changes: 65 additions & 0 deletions flang/lib/Lower/HlfirIntrinsics.cpp
@@ -93,6 +93,19 @@ using HlfirMinvalLowering = HlfirReductionIntrinsic<hlfir::MinvalOp, true>;
using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;

template <typename OP>
class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic {
public:
using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;

protected:
mlir::Value
lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
const fir::IntrinsicArgumentLoweringRules *argLowering,
mlir::Type stmtResultType) override;
};
using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MinlocOp>;

template <typename OP>
class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic {
public:
Expand Down Expand Up @@ -180,6 +193,31 @@ mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress(
return boxOrAbsent;
}

// Load a dynamically optional scalar argument: yield its value when the
// argument is present at runtime, and a zero of the element type otherwise.
static mlir::Value loadOptionalValue(
mlir::Location loc, fir::FirOpBuilder &builder,
const std::optional<Fortran::lower::PreparedActualArgument> &arg,
hlfir::Entity actual) {
if (!arg->handleDynamicOptional())
return hlfir::loadTrivialScalar(loc, builder, actual);

mlir::Value isPresent = arg->getIsPresent();
mlir::Type eleType = hlfir::getFortranElementType(actual.getType());
return builder
.genIfOp(loc, {eleType}, isPresent,
/*withElseRegion=*/true)
.genThen([&]() {
assert(actual.isScalar() && fir::isa_trivial(eleType) &&
"must be a numerical or logical scalar");
hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual);
builder.create<fir::ResultOp>(loc, val);
})
.genElse([&]() {
mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType);
builder.create<fir::ResultOp>(loc, zero);
})
.getResults()[0];
}

llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
const Fortran::lower::PreparedActualArguments &loweredActuals,
const fir::IntrinsicArgumentLoweringRules *argLowering) {
@@ -206,6 +244,9 @@ llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
else if (!argRules.handleDynamicOptional &&
argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired)
valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
else if (argRules.handleDynamicOptional &&
argRules.lowerAs == fir::LowerIntrinsicArgAs::Value)
valArg = loadOptionalValue(loc, builder, arg, actual);
else if (argRules.handleDynamicOptional)
TODO(loc, "hlfir transformational intrinsic dynamically optional "
"argument without box lowering");
@@ -260,6 +301,27 @@ mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl(
return op;
}

template <typename OP>
mlir::Value HlfirMinMaxLocIntrinsic<OP>::lowerImpl(
const Fortran::lower::PreparedActualArguments &loweredActuals,
const fir::IntrinsicArgumentLoweringRules *argLowering,
mlir::Type stmtResultType) {
auto operands = getOperandVector(loweredActuals, argLowering);
mlir::Value array = operands[0];
mlir::Value dim = operands[1];
mlir::Value mask = operands[2];
mlir::Value back = operands[4];
// dim, mask, and back may be null when the corresponding optional arguments
// are not given. operands[3] is the KIND argument; it is already reflected in
// stmtResultType, so it is not used here.
if (dim)
dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
if (back)
back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back});

mlir::Type resultTy = computeResultType(array, stmtResultType);

return createOp<OP>(resultTy, array, dim, mask, back);
}

template <typename OP>
mlir::Value HlfirProductIntrinsic<OP>::lowerImpl(
const Fortran::lower::PreparedActualArguments &loweredActuals,
@@ -364,6 +426,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
if (name == "minval")
return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering,
stmtResultType);
if (name == "minloc")
return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering,
stmtResultType);
if (mlir::isa<fir::CharacterType>(stmtResultType)) {
if (name == "min")
return HlfirCharExtremumLowering{builder, loc,
113 changes: 78 additions & 35 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -661,20 +661,14 @@ void hlfir::ConcatOp::getEffects(

template <typename NumericalReductionOp>
static mlir::LogicalResult
verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
mlir::Operation *op = reductionOp->getOperation();

auto results = op->getResultTypes();
assert(results.size() == 1);

verifyArrayAndMaskForReductionOp(NumericalReductionOp reductionOp) {
mlir::Value array = reductionOp->getArray();
mlir::Value dim = reductionOp->getDim();
mlir::Value mask = reductionOp->getMask();

fir::SequenceType arrayTy =
hlfir::getFortranElementOrSequenceType(array.getType())
.cast<fir::SequenceType>();
mlir::Type numTy = arrayTy.getEleTy();
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();

if (mask) {
@@ -701,6 +695,27 @@ verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
}
}
}
return mlir::success();
}

template <typename NumericalReductionOp>
static mlir::LogicalResult
verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
mlir::Operation *op = reductionOp->getOperation();
auto results = op->getResultTypes();
assert(results.size() == 1);

auto res = verifyArrayAndMaskForReductionOp(reductionOp);
if (failed(res))
return res;

mlir::Value array = reductionOp->getArray();
mlir::Value dim = reductionOp->getDim();
fir::SequenceType arrayTy =
hlfir::getFortranElementOrSequenceType(array.getType())
.cast<fir::SequenceType>();
mlir::Type numTy = arrayTy.getEleTy();
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();

mlir::Type resultType = results[0];
if (hlfir::isFortranScalarNumericalType(resultType)) {
@@ -757,45 +772,21 @@ template <typename CharacterReductionOp>
static mlir::LogicalResult
verifyCharacterReductionOp(CharacterReductionOp reductionOp) {
mlir::Operation *op = reductionOp->getOperation();

auto results = op->getResultTypes();
assert(results.size() == 1);

auto res = verifyArrayAndMaskForReductionOp(reductionOp);
if (failed(res))
return res;

mlir::Value array = reductionOp->getArray();
mlir::Value dim = reductionOp->getDim();
mlir::Value mask = reductionOp->getMask();

fir::SequenceType arrayTy =
hlfir::getFortranElementOrSequenceType(array.getType())
.cast<fir::SequenceType>();
mlir::Type numTy = arrayTy.getEleTy();
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();

if (mask) {
fir::SequenceType maskSeq =
hlfir::getFortranElementOrSequenceType(mask.getType())
.dyn_cast<fir::SequenceType>();
llvm::ArrayRef<int64_t> maskShape;

if (maskSeq)
maskShape = maskSeq.getShape();

if (!maskShape.empty()) {
if (maskShape.size() != arrayShape.size())
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
static_assert(fir::SequenceType::getUnknownExtent() ==
hlfir::ExprType::getUnknownExtent());
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
for (std::size_t i = 0; i < arrayShape.size(); ++i) {
int64_t arrayExtent = arrayShape[i];
int64_t maskExtent = maskShape[i];
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
(maskExtent != unknownExtent))
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
}
}
}

auto resultExpr = results[0].cast<hlfir::ExprType>();
mlir::Type resultType = resultExpr.getEleTy();
assert(mlir::isa<fir::CharacterType>(resultType) &&
@@ -870,6 +861,58 @@ void hlfir::MinvalOp::getEffects(
getIntrinsicEffects(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// MinlocOp
//===----------------------------------------------------------------------===//

mlir::LogicalResult hlfir::MinlocOp::verify() {
mlir::Operation *op = getOperation();
auto results = op->getResultTypes();
assert(results.size() == 1);

auto res = verifyArrayAndMaskForReductionOp(this);
if (failed(res))
return res;

mlir::Value array = getArray();
mlir::Value dim = getDim();
fir::SequenceType arrayTy =
hlfir::getFortranElementOrSequenceType(array.getType())
.cast<fir::SequenceType>();
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();

mlir::Type resultType = results[0];
if (dim && arrayShape.size() == 1) {
if (!fir::isa_integer(resultType))
return emitOpError("result must be scalar integer");
} else if (auto resultExpr =
mlir::dyn_cast_or_null<hlfir::ExprType>(resultType)) {
if (!resultExpr.isArray())
return emitOpError("result must be an array");

if (!fir::isa_integer(resultExpr.getEleTy()))
return emitOpError("result must have integer elements");

llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
// With dim the result has rank n-1
if (dim && resultShape.size() != (arrayShape.size() - 1))
return emitOpError("result rank must be one less than ARRAY");
// Without dim the result is a rank-1 array
if (!dim && resultShape.size() != 1)
return emitOpError("result rank must be 1");
} else {
return emitOpError("result must be of numerical expr type");
}
return mlir::success();
}

void hlfir::MinlocOp::getEffects(
llvm::SmallVectorImpl<
mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
&effects) {
getIntrinsicEffects(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// SetLengthOp
//===----------------------------------------------------------------------===//
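To make the result-shape rules enforced by the verifier above concrete, a few assumed examples of operand and result types it accepts (types are illustrative only, not copied from this patch's tests):

// Rank-1 ARRAY with DIM: the result is a scalar integer.
%r0 = hlfir.minloc %a1 dim %d : (!fir.box<!fir.array<?xi32>>, i32) -> i32
// Rank-2 ARRAY with DIM: the result rank is one less than ARRAY.
%r1 = hlfir.minloc %a2 dim %d : (!fir.box<!fir.array<?x?xi32>>, i32) -> !hlfir.expr<?xi32>
// Rank-2 ARRAY without DIM: a rank-1 result with one index per dimension.
%r2 = hlfir.minloc %a2 : (!fir.box<!fir.array<?x?xi32>>) -> !hlfir.expr<2xi32>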
38 changes: 31 additions & 7 deletions flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -201,6 +201,23 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
return lowerArguments(operation, inArgs, rewriter, argLowering);
};

auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
mlir::PatternRewriter &rewriter, std::string opName,
fir::FirOpBuilder builder) const {
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
inArgs.push_back({operation.getArray(), operation.getArray().getType()});
inArgs.push_back({operation.getDim(), i32});
inArgs.push_back({operation.getMask(), logicalType});
mlir::Type T = hlfir::getFortranElementType(operation.getType());
unsigned width = T.cast<mlir::IntegerType>().getWidth();
mlir::Value kind =
builder.createIntegerConstant(operation->getLoc(), i32, width / 8);
inArgs.push_back({kind, i32});
inArgs.push_back({operation.getBack(), i32});
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
return lowerArguments(operation, inArgs, rewriter, argLowering);
};

auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
mlir::PatternRewriter &rewriter,
std::string opName) const {
@@ -224,6 +241,8 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
opName = "maxval";
} else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) {
opName = "minval";
} else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
opName = "minloc";
} else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
opName = "any";
} else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
@@ -246,6 +265,9 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
std::is_same_v<OP, hlfir::MaxvalOp> ||
std::is_same_v<OP, hlfir::MinvalOp>) {
args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
} else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
builder);
} else {
args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
}
Expand All @@ -269,6 +291,8 @@ using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;

using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>;

using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>;

using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;

using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
@@ -441,20 +465,20 @@ class LowerHLFIRIntrinsics
mlir::ModuleOp module = this->getOperation();
mlir::MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
patterns
.insert<MatmulOpConversion, MatmulTransposeOpConversion,
AllOpConversion, AnyOpConversion, SumOpConversion,
ProductOpConversion, TransposeOpConversion, CountOpConversion,
DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion>(
context);
patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
AllOpConversion, AnyOpConversion, SumOpConversion,
ProductOpConversion, TransposeOpConversion,
CountOpConversion, DotProductOpConversion,
MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion>(
context);
mlir::ConversionTarget target(*context);
target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect, fir::FIROpsDialect,
hlfir::hlfirDialect>();
target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp,
hlfir::AllOp, hlfir::DotProductOp, hlfir::CountOp,
hlfir::MaxvalOp, hlfir::MinvalOp>();
hlfir::MaxvalOp, hlfir::MinvalOp, hlfir::MinlocOp>();
target.markUnknownOpDynamicallyLegal(
[](mlir::Operation *) { return true; });
if (mlir::failed(