Skip to content

Commit c489108

Browse files
authored
[flang] Added hlfir.reshape definition/lowering/codegen. (#124226)
Lower Fortran RESHAPE intrinsic into hlfir.reshape, and then lower hlfir.reshape into a runtime call. A later patch will add hlfir.reshape inlining as hlfir.elemental.
1 parent 5ece348 commit c489108

File tree

9 files changed

+1001
-26
lines changed

9 files changed

+1001
-26
lines changed

flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ def IsFortranNumericalArrayObjectPred
125125
def AnyFortranNumericalArrayObject : Type<IsFortranNumericalArrayObjectPred,
126126
"any array-like object containing a numerical type">;
127127

128+
def AnyFortranNumericalArrayEntity
129+
: Type<And<[AnyFortranNumericalArrayObject.predicate,
130+
AnyFortranEntity.predicate]>,
131+
"any array-like entity containing a numerical type">;
132+
128133
def IsFortranNumericalOrLogicalArrayObjectPred
129134
: CPred<"::hlfir::isFortranNumericalOrLogicalArrayObject($_self)">;
130135
def AnyFortranNumericalOrLogicalArrayObject : Type<IsFortranNumericalOrLogicalArrayObjectPred,
@@ -135,6 +140,10 @@ def IsFortranArrayObjectPred
135140
def AnyFortranArrayObject : Type<IsFortranArrayObjectPred,
136141
"any array-like object">;
137142

143+
def AnyFortranArrayEntity
144+
: Type<And<[AnyFortranArrayObject.predicate, AnyFortranEntity.predicate]>,
145+
"any array-like entity">;
146+
138147
def IsPassByRefOrIntegerTypePred
139148
: CPred<"::hlfir::isPassByRefOrIntegerType($_self)">;
140149
def AnyPassByRefOrIntegerType : Type<IsPassByRefOrIntegerTypePred,

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,32 @@ def hlfir_CShiftOp
720720
let hasVerifier = 1;
721721
}
722722

723+
def hlfir_ReshapeOp
724+
: hlfir_Op<
725+
"reshape", [AttrSizedOperandSegments,
726+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
727+
let summary = "RESHAPE transformational intrinsic";
728+
let description = [{
729+
Reshapes an ARRAY to correspond to the given SHAPE.
730+
If PAD is specified the new array may be padded with elements
731+
from PAD array.
732+
If ORDER is specified the new array may be permuted accordingly.
733+
}];
734+
735+
let arguments = (ins AnyFortranArrayEntity:$array,
736+
AnyFortranNumericalArrayEntity:$shape,
737+
Optional<AnyFortranArrayEntity>:$pad,
738+
Optional<AnyFortranNumericalArrayEntity>:$order);
739+
740+
let results = (outs hlfir_ExprType);
741+
742+
let assemblyFormat = [{
743+
$array $shape (`pad` $pad^)? (`order` $order^)? attr-dict `:` functional-type(operands, results)
744+
}];
745+
746+
let hasVerifier = 1;
747+
}
748+
723749
// An allocation effect is needed because the value produced by the associate
724750
// is "deallocated" by hlfir.end_associate (the end_associate must not be
725751
// removed, and there must be only one hlfir.end_associate).

flang/lib/Lower/HlfirIntrinsics.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,17 @@ class HlfirCShiftLowering : public HlfirTransformationalIntrinsic {
170170
mlir::Type stmtResultType) override;
171171
};
172172

173+
class HlfirReshapeLowering : public HlfirTransformationalIntrinsic {
174+
public:
175+
using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
176+
177+
protected:
178+
mlir::Value
179+
lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
180+
const fir::IntrinsicArgumentLoweringRules *argLowering,
181+
mlir::Type stmtResultType) override;
182+
};
183+
173184
} // namespace
174185

175186
mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress(
@@ -419,6 +430,17 @@ mlir::Value HlfirCShiftLowering::lowerImpl(
419430
return createOp<hlfir::CShiftOp>(resultType, operands);
420431
}
421432

433+
mlir::Value HlfirReshapeLowering::lowerImpl(
434+
const Fortran::lower::PreparedActualArguments &loweredActuals,
435+
const fir::IntrinsicArgumentLoweringRules *argLowering,
436+
mlir::Type stmtResultType) {
437+
auto operands = getOperandVector(loweredActuals, argLowering);
438+
assert(operands.size() == 4);
439+
mlir::Type resultType = computeResultType(operands[0], stmtResultType);
440+
return createOp<hlfir::ReshapeOp>(resultType, operands[0], operands[1],
441+
operands[2], operands[3]);
442+
}
443+
422444
std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
423445
fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name,
424446
const Fortran::lower::PreparedActualArguments &loweredActuals,
@@ -467,6 +489,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
467489
if (name == "cshift")
468490
return HlfirCShiftLowering{builder, loc}.lower(loweredActuals, argLowering,
469491
stmtResultType);
492+
if (name == "reshape")
493+
return HlfirReshapeLowering{builder, loc}.lower(loweredActuals, argLowering,
494+
stmtResultType);
470495
if (mlir::isa<fir::CharacterType>(stmtResultType)) {
471496
if (name == "min")
472497
return HlfirCharExtremumLowering{builder, loc,

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 94 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,33 @@ getIntrinsicEffects(mlir::Operation *self,
6767
}
6868
}
6969

70+
/// Verification helper for checking if two types are the same.
71+
/// Set \p allowCharacterLenMismatch to true, if character types
72+
/// of different known lengths should be treated as the same.
73+
template <typename Op>
74+
static llvm::LogicalResult areMatchingTypes(Op &op, mlir::Type type1,
75+
mlir::Type type2,
76+
bool allowCharacterLenMismatch) {
77+
if (auto charType1 = mlir::dyn_cast<fir::CharacterType>(type1))
78+
if (auto charType2 = mlir::dyn_cast<fir::CharacterType>(type2)) {
79+
// Character kinds must match.
80+
if (charType1.getFKind() != charType2.getFKind())
81+
return op.emitOpError("character KIND mismatch");
82+
83+
// Constant propagation can result in mismatching lengths
84+
// in the dead code, but we should not fail on this.
85+
if (!allowCharacterLenMismatch)
86+
if (charType1.getLen() != fir::CharacterType::unknownLen() &&
87+
charType2.getLen() != fir::CharacterType::unknownLen() &&
88+
charType1.getLen() != charType2.getLen())
89+
return op.emitOpError("character LEN mismatch");
90+
91+
return mlir::success();
92+
}
93+
94+
return type1 == type2 ? mlir::success() : mlir::failure();
95+
}
96+
7097
//===----------------------------------------------------------------------===//
7198
// DeclareOp
7299
//===----------------------------------------------------------------------===//
@@ -1360,23 +1387,12 @@ llvm::LogicalResult hlfir::CShiftOp::verify() {
13601387
mlir::Value shift = getShift();
13611388
mlir::Type shiftTy = hlfir::getFortranElementOrSequenceType(shift.getType());
13621389

1363-
if (eleTy != resultEleTy) {
1364-
if (mlir::isa<fir::CharacterType>(eleTy) &&
1365-
mlir::isa<fir::CharacterType>(resultEleTy)) {
1366-
auto eleCharTy = mlir::cast<fir::CharacterType>(eleTy);
1367-
auto resultCharTy = mlir::cast<fir::CharacterType>(resultEleTy);
1368-
if (eleCharTy.getFKind() != resultCharTy.getFKind())
1369-
return emitOpError("kind mismatch between input and output arrays");
1370-
if (eleCharTy.getLen() != fir::CharacterType::unknownLen() &&
1371-
resultCharTy.getLen() != fir::CharacterType::unknownLen() &&
1372-
eleCharTy.getLen() != resultCharTy.getLen())
1373-
return emitOpError(
1374-
"character LEN mismatch between input and output arrays");
1375-
} else {
1376-
return emitOpError(
1377-
"input and output arrays should have the same element type");
1378-
}
1379-
}
1390+
// TODO: turn allowCharacterLenMismatch into true.
1391+
if (auto match = areMatchingTypes(*this, eleTy, resultEleTy,
1392+
/*allowCharacterLenMismatch=*/false);
1393+
match.failed())
1394+
return emitOpError(
1395+
"input and output arrays should have the same element type");
13801396

13811397
if (arrayRank != resultRank)
13821398
return emitOpError("input and output arrays should have the same rank");
@@ -1444,6 +1460,67 @@ void hlfir::CShiftOp::getEffects(
14441460
getIntrinsicEffects(getOperation(), effects);
14451461
}
14461462

1463+
//===----------------------------------------------------------------------===//
1464+
// ReshapeOp
1465+
//===----------------------------------------------------------------------===//
1466+
1467+
llvm::LogicalResult hlfir::ReshapeOp::verify() {
1468+
auto results = getOperation()->getResultTypes();
1469+
assert(results.size() == 1);
1470+
hlfir::ExprType resultType = mlir::cast<hlfir::ExprType>(results[0]);
1471+
mlir::Value array = getArray();
1472+
auto arrayType = mlir::cast<fir::SequenceType>(
1473+
hlfir::getFortranElementOrSequenceType(array.getType()));
1474+
if (auto match = areMatchingTypes(
1475+
*this, hlfir::getFortranElementType(resultType),
1476+
arrayType.getElementType(), /*allowCharacterLenMismatch=*/true);
1477+
match.failed())
1478+
return emitOpError("ARRAY and the result must have the same element type");
1479+
if (hlfir::isPolymorphicType(resultType) !=
1480+
hlfir::isPolymorphicType(array.getType()))
1481+
return emitOpError("ARRAY must be polymorphic iff result is polymorphic");
1482+
1483+
mlir::Value shape = getShape();
1484+
auto shapeArrayType = mlir::cast<fir::SequenceType>(
1485+
hlfir::getFortranElementOrSequenceType(shape.getType()));
1486+
if (shapeArrayType.getDimension() != 1)
1487+
return emitOpError("SHAPE must be an array of rank 1");
1488+
if (!mlir::isa<mlir::IntegerType>(shapeArrayType.getElementType()))
1489+
return emitOpError("SHAPE must be an integer array");
1490+
if (shapeArrayType.hasDynamicExtents())
1491+
return emitOpError("SHAPE must have known size");
1492+
if (shapeArrayType.getConstantArraySize() != resultType.getRank())
1493+
return emitOpError("SHAPE's extent must match the result rank");
1494+
1495+
if (mlir::Value pad = getPad()) {
1496+
auto padArrayType = mlir::cast<fir::SequenceType>(
1497+
hlfir::getFortranElementOrSequenceType(pad.getType()));
1498+
if (auto match = areMatchingTypes(*this, arrayType.getElementType(),
1499+
padArrayType.getElementType(),
1500+
/*allowCharacterLenMismatch=*/true);
1501+
match.failed())
1502+
return emitOpError("ARRAY and PAD must be of the same type");
1503+
}
1504+
1505+
if (mlir::Value order = getOrder()) {
1506+
auto orderArrayType = mlir::cast<fir::SequenceType>(
1507+
hlfir::getFortranElementOrSequenceType(order.getType()));
1508+
if (orderArrayType.getDimension() != 1)
1509+
return emitOpError("ORDER must be an array of rank 1");
1510+
if (!mlir::isa<mlir::IntegerType>(orderArrayType.getElementType()))
1511+
return emitOpError("ORDER must be an integer array");
1512+
}
1513+
1514+
return mlir::success();
1515+
}
1516+
1517+
void hlfir::ReshapeOp::getEffects(
1518+
llvm::SmallVectorImpl<
1519+
mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
1520+
&effects) {
1521+
getIntrinsicEffects(getOperation(), effects);
1522+
}
1523+
14471524
//===----------------------------------------------------------------------===//
14481525
// AssociateOp
14491526
//===----------------------------------------------------------------------===//

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -494,20 +494,54 @@ class CShiftOpConversion : public HlfirIntrinsicConversion<hlfir::CShiftOp> {
494494
}
495495
};
496496

497+
class ReshapeOpConversion : public HlfirIntrinsicConversion<hlfir::ReshapeOp> {
498+
using HlfirIntrinsicConversion<hlfir::ReshapeOp>::HlfirIntrinsicConversion;
499+
500+
llvm::LogicalResult
501+
matchAndRewrite(hlfir::ReshapeOp reshape,
502+
mlir::PatternRewriter &rewriter) const override {
503+
fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
504+
const mlir::Location &loc = reshape->getLoc();
505+
506+
llvm::SmallVector<IntrinsicArgument, 4> inArgs;
507+
mlir::Value array = reshape.getArray();
508+
inArgs.push_back({array, array.getType()});
509+
mlir::Value shape = reshape.getShape();
510+
inArgs.push_back({shape, shape.getType()});
511+
mlir::Type noneType = builder.getNoneType();
512+
mlir::Value pad = reshape.getPad();
513+
inArgs.push_back({pad, pad ? pad.getType() : noneType});
514+
mlir::Value order = reshape.getOrder();
515+
inArgs.push_back({order, order ? order.getType() : noneType});
516+
517+
auto *argLowering = fir::getIntrinsicArgumentLowering("reshape");
518+
llvm::SmallVector<fir::ExtendedValue, 4> args =
519+
lowerArguments(reshape, inArgs, rewriter, argLowering);
520+
521+
mlir::Type scalarResultType =
522+
hlfir::getFortranElementType(reshape.getType());
523+
524+
auto [resultExv, mustBeFreed] =
525+
fir::genIntrinsicCall(builder, loc, "reshape", scalarResultType, args);
526+
527+
processReturnValue(reshape, resultExv, mustBeFreed, builder, rewriter);
528+
return mlir::success();
529+
}
530+
};
531+
497532
class LowerHLFIRIntrinsics
498533
: public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
499534
public:
500535
void runOnOperation() override {
501536
mlir::ModuleOp module = this->getOperation();
502537
mlir::MLIRContext *context = &getContext();
503538
mlir::RewritePatternSet patterns(context);
504-
patterns
505-
.insert<MatmulOpConversion, MatmulTransposeOpConversion,
506-
AllOpConversion, AnyOpConversion, SumOpConversion,
507-
ProductOpConversion, TransposeOpConversion, CountOpConversion,
508-
DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion,
509-
MinlocOpConversion, MaxlocOpConversion, CShiftOpConversion>(
510-
context);
539+
patterns.insert<
540+
MatmulOpConversion, MatmulTransposeOpConversion, AllOpConversion,
541+
AnyOpConversion, SumOpConversion, ProductOpConversion,
542+
TransposeOpConversion, CountOpConversion, DotProductOpConversion,
543+
MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion,
544+
MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context);
511545

512546
// While conceptually this pass is performing dialect conversion, we use
513547
// pattern rewrites here instead of dialect conversion because this pass

0 commit comments

Comments
 (0)