Skip to content

[flang] Added hlfir.reshape definition/lowering/codegen. #124226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ def IsFortranNumericalArrayObjectPred
def AnyFortranNumericalArrayObject : Type<IsFortranNumericalArrayObjectPred,
"any array-like object containing a numerical type">;

// A numerical array-like object that is also an HLFIR entity,
// i.e. it additionally satisfies AnyFortranEntity.predicate.
def AnyFortranNumericalArrayEntity
    : Type<And<[AnyFortranNumericalArrayObject.predicate,
                AnyFortranEntity.predicate]>,
           "any array-like entity containing a numerical type">;

def IsFortranNumericalOrLogicalArrayObjectPred
: CPred<"::hlfir::isFortranNumericalOrLogicalArrayObject($_self)">;
def AnyFortranNumericalOrLogicalArrayObject : Type<IsFortranNumericalOrLogicalArrayObjectPred,
Expand All @@ -135,6 +140,10 @@ def IsFortranArrayObjectPred
def AnyFortranArrayObject : Type<IsFortranArrayObjectPred,
"any array-like object">;

// An array-like object of any element type that is also an HLFIR
// entity (satisfies both AnyFortranArrayObject and AnyFortranEntity).
def AnyFortranArrayEntity
    : Type<And<[AnyFortranArrayObject.predicate, AnyFortranEntity.predicate]>,
           "any array-like entity">;

def IsPassByRefOrIntegerTypePred
: CPred<"::hlfir::isPassByRefOrIntegerType($_self)">;
def AnyPassByRefOrIntegerType : Type<IsPassByRefOrIntegerTypePred,
Expand Down
26 changes: 26 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,32 @@ def hlfir_CShiftOp
let hasVerifier = 1;
}

// HLFIR operation modeling the Fortran RESHAPE transformational
// intrinsic. Verified by hlfir::ReshapeOp::verify() and lowered to a
// runtime call by ReshapeOpConversion.
def hlfir_ReshapeOp
    : hlfir_Op<
          "reshape", [AttrSizedOperandSegments,
                      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  let summary = "RESHAPE transformational intrinsic";
  let description = [{
    Reshapes an ARRAY to correspond to the given SHAPE.
    If PAD is specified the new array may be padded with elements
    from PAD array.
    If ORDER is specified the new array may be permuted accordingly.
  }];

  // ARRAY and SHAPE are required; PAD and ORDER are optional.
  // AttrSizedOperandSegments is needed so the parser/printer can tell
  // which of the two optional operands are present.
  let arguments = (ins AnyFortranArrayEntity:$array,
      AnyFortranNumericalArrayEntity:$shape,
      Optional<AnyFortranArrayEntity>:$pad,
      Optional<AnyFortranNumericalArrayEntity>:$order);

  // The result is always a new HLFIR expression value.
  let results = (outs hlfir_ExprType);

  let assemblyFormat = [{
    $array $shape (`pad` $pad^)? (`order` $order^)? attr-dict `:` functional-type(operands, results)
  }];

  let hasVerifier = 1;
}

// An allocation effect is needed because the value produced by the associate
// is "deallocated" by hlfir.end_associate (the end_associate must not be
// removed, and there must be only one hlfir.end_associate).
Expand Down
25 changes: 25 additions & 0 deletions flang/lib/Lower/HlfirIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,17 @@ class HlfirCShiftLowering : public HlfirTransformationalIntrinsic {
mlir::Type stmtResultType) override;
};

/// Lowering helper that maps a prepared RESHAPE intrinsic call to an
/// hlfir.reshape operation (mirrors the sibling helpers such as
/// HlfirCShiftLowering above).
class HlfirReshapeLowering : public HlfirTransformationalIntrinsic {
public:
  using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;

protected:
  /// Build the hlfir.reshape op from the lowered actual arguments and
  /// return its result value.
  mlir::Value
  lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
            const fir::IntrinsicArgumentLoweringRules *argLowering,
            mlir::Type stmtResultType) override;
};

} // namespace

mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress(
Expand Down Expand Up @@ -419,6 +430,17 @@ mlir::Value HlfirCShiftLowering::lowerImpl(
return createOp<hlfir::CShiftOp>(resultType, operands);
}

/// Lower a prepared RESHAPE call into an hlfir.reshape operation.
/// The operand vector holds ARRAY, SHAPE, PAD and ORDER, in that
/// order; the optional trailing operands may be null when absent.
mlir::Value HlfirReshapeLowering::lowerImpl(
    const Fortran::lower::PreparedActualArguments &loweredActuals,
    const fir::IntrinsicArgumentLoweringRules *argLowering,
    mlir::Type stmtResultType) {
  auto args = getOperandVector(loweredActuals, argLowering);
  assert(args.size() == 4);
  // The result type is derived from the ARRAY argument and the
  // statement's declared result type.
  mlir::Type resultTy = computeResultType(args[0], stmtResultType);
  return createOp<hlfir::ReshapeOp>(resultTy, args[0], args[1], args[2],
                                    args[3]);
}

std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name,
const Fortran::lower::PreparedActualArguments &loweredActuals,
Expand Down Expand Up @@ -467,6 +489,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
if (name == "cshift")
return HlfirCShiftLowering{builder, loc}.lower(loweredActuals, argLowering,
stmtResultType);
if (name == "reshape")
return HlfirReshapeLowering{builder, loc}.lower(loweredActuals, argLowering,
stmtResultType);
if (mlir::isa<fir::CharacterType>(stmtResultType)) {
if (name == "min")
return HlfirCharExtremumLowering{builder, loc,
Expand Down
111 changes: 94 additions & 17 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,33 @@ getIntrinsicEffects(mlir::Operation *self,
}
}

/// Verification helper for checking if two types are the same.
/// Set \p allowCharacterLenMismatch to true, if character types
/// of different known lengths should be treated as the same.
/// Verification helper for checking if two types are the same.
/// Set \p allowCharacterLenMismatch to true, if character types
/// of different known lengths should be treated as the same.
template <typename Op>
static llvm::LogicalResult areMatchingTypes(Op &op, mlir::Type type1,
                                            mlir::Type type2,
                                            bool allowCharacterLenMismatch) {
  auto charTy1 = mlir::dyn_cast<fir::CharacterType>(type1);
  auto charTy2 = mlir::dyn_cast<fir::CharacterType>(type2);
  if (charTy1 && charTy2) {
    // Character kinds must match.
    if (charTy1.getFKind() != charTy2.getFKind())
      return op.emitOpError("character KIND mismatch");

    if (allowCharacterLenMismatch)
      return mlir::success();

    // Constant propagation can result in mismatching lengths
    // in the dead code, but we should not fail on this.
    bool bothLensKnown =
        charTy1.getLen() != fir::CharacterType::unknownLen() &&
        charTy2.getLen() != fir::CharacterType::unknownLen();
    if (bothLensKnown && charTy1.getLen() != charTy2.getLen())
      return op.emitOpError("character LEN mismatch");

    return mlir::success();
  }

  // Non-character (or mixed character/non-character) types must be
  // strictly identical.
  return type1 == type2 ? mlir::success() : mlir::failure();
}

//===----------------------------------------------------------------------===//
// DeclareOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1360,23 +1387,12 @@ llvm::LogicalResult hlfir::CShiftOp::verify() {
mlir::Value shift = getShift();
mlir::Type shiftTy = hlfir::getFortranElementOrSequenceType(shift.getType());

if (eleTy != resultEleTy) {
if (mlir::isa<fir::CharacterType>(eleTy) &&
mlir::isa<fir::CharacterType>(resultEleTy)) {
auto eleCharTy = mlir::cast<fir::CharacterType>(eleTy);
auto resultCharTy = mlir::cast<fir::CharacterType>(resultEleTy);
if (eleCharTy.getFKind() != resultCharTy.getFKind())
return emitOpError("kind mismatch between input and output arrays");
if (eleCharTy.getLen() != fir::CharacterType::unknownLen() &&
resultCharTy.getLen() != fir::CharacterType::unknownLen() &&
eleCharTy.getLen() != resultCharTy.getLen())
return emitOpError(
"character LEN mismatch between input and output arrays");
} else {
return emitOpError(
"input and output arrays should have the same element type");
}
}
// TODO: turn allowCharacterLenMismatch into true.
if (auto match = areMatchingTypes(*this, eleTy, resultEleTy,
/*allowCharacterLenMismatch=*/false);
match.failed())
return emitOpError(
"input and output arrays should have the same element type");

if (arrayRank != resultRank)
return emitOpError("input and output arrays should have the same rank");
Expand Down Expand Up @@ -1444,6 +1460,67 @@ void hlfir::CShiftOp::getEffects(
getIntrinsicEffects(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

/// Verify hlfir.reshape: element-type compatibility between ARRAY,
/// PAD and the result, and structural constraints on SHAPE and ORDER.
llvm::LogicalResult hlfir::ReshapeOp::verify() {
  auto results = getOperation()->getResultTypes();
  assert(results.size() == 1);
  hlfir::ExprType resultType = mlir::cast<hlfir::ExprType>(results[0]);
  mlir::Value array = getArray();
  auto arrayType = mlir::cast<fir::SequenceType>(
      hlfir::getFortranElementOrSequenceType(array.getType()));
  // ARRAY and the result must share an element type; character
  // lengths are allowed to differ (see areMatchingTypes).
  if (auto match = areMatchingTypes(
          *this, hlfir::getFortranElementType(resultType),
          arrayType.getElementType(), /*allowCharacterLenMismatch=*/true);
      match.failed())
    return emitOpError("ARRAY and the result must have the same element type");
  // Polymorphism must agree between ARRAY and the result.
  if (hlfir::isPolymorphicType(resultType) !=
      hlfir::isPolymorphicType(array.getType()))
    return emitOpError("ARRAY must be polymorphic iff result is polymorphic");

  // SHAPE must be a rank-1 integer array whose compile-time size
  // equals the result rank.
  mlir::Value shape = getShape();
  auto shapeArrayType = mlir::cast<fir::SequenceType>(
      hlfir::getFortranElementOrSequenceType(shape.getType()));
  if (shapeArrayType.getDimension() != 1)
    return emitOpError("SHAPE must be an array of rank 1");
  if (!mlir::isa<mlir::IntegerType>(shapeArrayType.getElementType()))
    return emitOpError("SHAPE must be an integer array");
  if (shapeArrayType.hasDynamicExtents())
    return emitOpError("SHAPE must have known size");
  if (shapeArrayType.getConstantArraySize() != resultType.getRank())
    return emitOpError("SHAPE's extent must match the result rank");

  // When present, PAD must have the same element type as ARRAY
  // (character lengths may differ).
  if (mlir::Value pad = getPad()) {
    auto padArrayType = mlir::cast<fir::SequenceType>(
        hlfir::getFortranElementOrSequenceType(pad.getType()));
    if (auto match = areMatchingTypes(*this, arrayType.getElementType(),
                                      padArrayType.getElementType(),
                                      /*allowCharacterLenMismatch=*/true);
        match.failed())
      return emitOpError("ARRAY and PAD must be of the same type");
  }

  // When present, ORDER must be a rank-1 integer array.
  // NOTE(review): ORDER's extent is not checked against the result
  // rank here — presumably it may have dynamic extents; confirm this
  // is intentional.
  if (mlir::Value order = getOrder()) {
    auto orderArrayType = mlir::cast<fir::SequenceType>(
        hlfir::getFortranElementOrSequenceType(order.getType()));
    if (orderArrayType.getDimension() != 1)
      return emitOpError("ORDER must be an array of rank 1");
    if (!mlir::isa<mlir::IntegerType>(orderArrayType.getElementType()))
      return emitOpError("ORDER must be an integer array");
  }

  return mlir::success();
}

/// Memory effects for hlfir.reshape: delegate to the shared helper
/// used by the other HLFIR transformational intrinsics (e.g. cshift).
void hlfir::ReshapeOp::getEffects(
    llvm::SmallVectorImpl<
        mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
        &effects) {
  getIntrinsicEffects(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// AssociateOp
//===----------------------------------------------------------------------===//
Expand Down
48 changes: 41 additions & 7 deletions flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,20 +494,54 @@ class CShiftOpConversion : public HlfirIntrinsicConversion<hlfir::CShiftOp> {
}
};

/// Converts hlfir.reshape into a call to the RESHAPE runtime/intrinsic
/// implementation via the generic intrinsic lowering helpers, mirroring
/// the other Hlfir*OpConversion patterns in this file.
class ReshapeOpConversion : public HlfirIntrinsicConversion<hlfir::ReshapeOp> {
  using HlfirIntrinsicConversion<hlfir::ReshapeOp>::HlfirIntrinsicConversion;

  llvm::LogicalResult
  matchAndRewrite(hlfir::ReshapeOp reshape,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
    // mlir::Location is a cheap value type: copy it rather than binding
    // a const reference to the temporary returned by getLoc().
    mlir::Location loc = reshape->getLoc();

    // Collect the four RESHAPE arguments in intrinsic order:
    // ARRAY, SHAPE, PAD, ORDER. Absent optional operands are passed as
    // null values, with a none type as their placeholder type.
    llvm::SmallVector<IntrinsicArgument, 4> inArgs;
    mlir::Value array = reshape.getArray();
    inArgs.push_back({array, array.getType()});
    mlir::Value shape = reshape.getShape();
    inArgs.push_back({shape, shape.getType()});
    mlir::Type noneType = builder.getNoneType();
    mlir::Value pad = reshape.getPad();
    inArgs.push_back({pad, pad ? pad.getType() : noneType});
    mlir::Value order = reshape.getOrder();
    inArgs.push_back({order, order ? order.getType() : noneType});

    // Prepare the arguments according to the intrinsic's argument
    // lowering rules.
    auto *argLowering = fir::getIntrinsicArgumentLowering("reshape");
    llvm::SmallVector<fir::ExtendedValue, 4> args =
        lowerArguments(reshape, inArgs, rewriter, argLowering);

    // The intrinsic call is generated with the scalar element type of
    // the hlfir.reshape result.
    mlir::Type scalarResultType =
        hlfir::getFortranElementType(reshape.getType());

    auto [resultExv, mustBeFreed] =
        fir::genIntrinsicCall(builder, loc, "reshape", scalarResultType, args);

    // Replace the op with the call's result, freeing any temporary if
    // required.
    processReturnValue(reshape, resultExv, mustBeFreed, builder, rewriter);
    return mlir::success();
  }
};

class LowerHLFIRIntrinsics
: public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
public:
void runOnOperation() override {
mlir::ModuleOp module = this->getOperation();
mlir::MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
patterns
.insert<MatmulOpConversion, MatmulTransposeOpConversion,
AllOpConversion, AnyOpConversion, SumOpConversion,
ProductOpConversion, TransposeOpConversion, CountOpConversion,
DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion,
MinlocOpConversion, MaxlocOpConversion, CShiftOpConversion>(
context);
patterns.insert<
MatmulOpConversion, MatmulTransposeOpConversion, AllOpConversion,
AnyOpConversion, SumOpConversion, ProductOpConversion,
TransposeOpConversion, CountOpConversion, DotProductOpConversion,
MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion,
MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context);

// While conceptually this pass is performing dialect conversion, we use
// pattern rewrites here instead of dialect conversion because this pass
Expand Down
Loading
Loading