Skip to content

Commit 87a01d4

Browse files
committed
[flang] Lower REDUCE intrinsic with DIM argument
1 parent 8d713c0 commit 87a01d4

File tree

4 files changed

+443
-1
lines changed

4 files changed

+443
-1
lines changed

flang/include/flang/Optimizer/Builder/Runtime/Reduction.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,13 @@ mlir::Value genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
240240
mlir::Value maskBox, mlir::Value identity,
241241
mlir::Value ordered);
242242

243+
/// Generate call to `Reduce` intrinsic runtime routine. This is the version
244+
/// that takes arrays of any rank with a dim argument specified.
245+
void genReduceDim(fir::FirOpBuilder &builder, mlir::Location loc,
246+
mlir::Value arrayBox, mlir::Value operation, mlir::Value dim,
247+
mlir::Value maskBox, mlir::Value identity,
248+
mlir::Value ordered, mlir::Value resultBox);
249+
243250
} // namespace fir::runtime
244251

245252
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_REDUCTION_H

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5790,7 +5790,17 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
57905790
return fir::runtime::genReduce(builder, loc, array, operation, mask,
57915791
identity, ordered);
57925792
}
5793-
TODO(loc, "reduce with array result");
5793+
// Handle cases that have an array result.
5794+
// Create mutable fir.box to be passed to the runtime for the result.
5795+
mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, rank - 1);
5796+
fir::MutableBoxValue resultMutableBox =
5797+
fir::factory::createTempMutableBox(builder, loc, resultArrayType);
5798+
mlir::Value resultIrBox =
5799+
fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
5800+
mlir::Value dim = fir::getBase(args[2]);
5801+
fir::runtime::genReduceDim(builder, loc, array, operation, dim, mask,
5802+
identity, ordered, resultIrBox);
5803+
return readAndAddCleanUp(resultMutableBox, resultType, "REDUCE");
57945804
}
57955805

57965806
// REPEAT

flang/lib/Optimizer/Builder/Runtime/Reduction.cpp

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,50 @@ struct ForcedReduceReal16 {
486486
}
487487
};
488488

489+
/// Placeholder for DIM real*10 version of Reduce Intrinsic
490+
struct ForcedReduceReal10Dim {
491+
static constexpr const char *name =
492+
ExpandAndQuoteKey(RTNAME(ReduceReal10Dim));
493+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
494+
return [](mlir::MLIRContext *ctx) {
495+
auto ty = mlir::FloatType::getF80(ctx);
496+
auto boxTy =
497+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
498+
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
499+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
500+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
501+
auto refTy = fir::ReferenceType::get(ty);
502+
auto refBoxTy = fir::ReferenceType::get(boxTy);
503+
auto i1Ty = mlir::IntegerType::get(ctx, 1);
504+
return mlir::FunctionType::get(
505+
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
506+
{});
507+
};
508+
}
509+
};
510+
511+
/// Placeholder for DIM real*16 version of Reduce Intrinsic
512+
struct ForcedReduceReal16Dim {
513+
static constexpr const char *name =
514+
ExpandAndQuoteKey(RTNAME(ReduceReal16Dim));
515+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
516+
return [](mlir::MLIRContext *ctx) {
517+
auto ty = mlir::FloatType::getF128(ctx);
518+
auto boxTy =
519+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
520+
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
521+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
522+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
523+
auto refTy = fir::ReferenceType::get(ty);
524+
auto refBoxTy = fir::ReferenceType::get(boxTy);
525+
auto i1Ty = mlir::IntegerType::get(ctx, 1);
526+
return mlir::FunctionType::get(
527+
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
528+
{});
529+
};
530+
}
531+
};
532+
489533
/// Placeholder for integer*16 version of Reduce Intrinsic
490534
struct ForcedReduceInteger16 {
491535
static constexpr const char *name =
@@ -506,6 +550,28 @@ struct ForcedReduceInteger16 {
506550
}
507551
};
508552

553+
/// Placeholder for DIM integer*16 version of Reduce Intrinsic
554+
struct ForcedReduceInteger16Dim {
555+
static constexpr const char *name =
556+
ExpandAndQuoteKey(RTNAME(ReduceInteger16Dim));
557+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
558+
return [](mlir::MLIRContext *ctx) {
559+
auto ty = mlir::IntegerType::get(ctx, 128);
560+
auto boxTy =
561+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
562+
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
563+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
564+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
565+
auto refTy = fir::ReferenceType::get(ty);
566+
auto refBoxTy = fir::ReferenceType::get(boxTy);
567+
auto i1Ty = mlir::IntegerType::get(ctx, 1);
568+
return mlir::FunctionType::get(
569+
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
570+
{});
571+
};
572+
}
573+
};
574+
509575
/// Placeholder for complex(10) version of Reduce Intrinsic
510576
struct ForcedReduceComplex10 {
511577
static constexpr const char *name =
@@ -527,6 +593,28 @@ struct ForcedReduceComplex10 {
527593
}
528594
};
529595

596+
/// Placeholder for Dim complex(10) version of Reduce Intrinsic
597+
struct ForcedReduceComplex10Dim {
598+
static constexpr const char *name =
599+
ExpandAndQuoteKey(RTNAME(CppReduceComplex10Dim));
600+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
601+
return [](mlir::MLIRContext *ctx) {
602+
auto ty = mlir::ComplexType::get(mlir::FloatType::getF80(ctx));
603+
auto boxTy =
604+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
605+
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
606+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
607+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
608+
auto refTy = fir::ReferenceType::get(ty);
609+
auto refBoxTy = fir::ReferenceType::get(boxTy);
610+
auto i1Ty = mlir::IntegerType::get(ctx, 1);
611+
return mlir::FunctionType::get(
612+
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
613+
{});
614+
};
615+
}
616+
};
617+
530618
/// Placeholder for complex(16) version of Reduce Intrinsic
531619
struct ForcedReduceComplex16 {
532620
static constexpr const char *name =
@@ -548,6 +636,28 @@ struct ForcedReduceComplex16 {
548636
}
549637
};
550638

639+
/// Placeholder for Dim complex(16) version of Reduce Intrinsic
640+
struct ForcedReduceComplex16Dim {
641+
static constexpr const char *name =
642+
ExpandAndQuoteKey(RTNAME(CppReduceComplex16Dim));
643+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
644+
return [](mlir::MLIRContext *ctx) {
645+
auto ty = mlir::ComplexType::get(mlir::FloatType::getF128(ctx));
646+
auto boxTy =
647+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
648+
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
649+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
650+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
651+
auto refTy = fir::ReferenceType::get(ty);
652+
auto refBoxTy = fir::ReferenceType::get(boxTy);
653+
auto i1Ty = mlir::IntegerType::get(ctx, 1);
654+
return mlir::FunctionType::get(
655+
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
656+
{});
657+
};
658+
}
659+
};
660+
551661
/// Generate call to specialized runtime function that takes a mask and
552662
/// dim argument. The All, Any, and Count intrinsics use this pattern.
553663
template <typename FN>
@@ -1442,3 +1552,97 @@ mlir::Value fir::runtime::genReduce(fir::FirOpBuilder &builder,
14421552
maskBox, identity, ordered);
14431553
return builder.create<fir::CallOp>(loc, func, args).getResult(0);
14441554
}
1555+
1556+
void fir::runtime::genReduceDim(fir::FirOpBuilder &builder, mlir::Location loc,
1557+
mlir::Value arrayBox, mlir::Value operation,
1558+
mlir::Value dim, mlir::Value maskBox,
1559+
mlir::Value identity, mlir::Value ordered,
1560+
mlir::Value resultBox) {
1561+
mlir::func::FuncOp func;
1562+
auto ty = arrayBox.getType();
1563+
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
1564+
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
1565+
1566+
mlir::MLIRContext *ctx = builder.getContext();
1567+
fir::factory::CharacterExprHelper charHelper{builder, loc};
1568+
1569+
if (eleTy.isF16())
1570+
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal2Dim)>(loc, builder);
1571+
else if (eleTy.isBF16())
1572+
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal3Dim)>(loc, builder);
1573+
else if (eleTy.isF32())
1574+
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal4Dim)>(loc, builder);
1575+
else if (eleTy.isF64())
1576+
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal8Dim)>(loc, builder);
1577+
else if (eleTy.isF80())
1578+
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal10Dim)>(loc, builder);
1579+
else if (eleTy.isF128())
1580+
func = fir::runtime::getRuntimeFunc<ForcedReduceReal16Dim>(loc, builder);
1581+
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
1582+
func =
1583+
fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger1Dim)>(loc, builder);
1584+
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
1585+
func =
1586+
fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger2Dim)>(loc, builder);
1587+
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
1588+
func =
1589+
fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger4Dim)>(loc, builder);
1590+
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
1591+
func =
1592+
fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger8Dim)>(loc, builder);
1593+
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
1594+
func = fir::runtime::getRuntimeFunc<ForcedReduceInteger16Dim>(loc, builder);
1595+
else if (eleTy == fir::ComplexType::get(ctx, 2))
1596+
func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex2Dim)>(loc,
1597+
builder);
1598+
else if (eleTy == fir::ComplexType::get(ctx, 3))
1599+
func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex3Dim)>(loc,
1600+
builder);
1601+
else if (eleTy == fir::ComplexType::get(ctx, 4))
1602+
func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex4Dim)>(loc,
1603+
builder);
1604+
else if (eleTy == fir::ComplexType::get(ctx, 8))
1605+
func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex8Dim)>(loc,
1606+
builder);
1607+
else if (eleTy == fir::ComplexType::get(ctx, 10))
1608+
func = fir::runtime::getRuntimeFunc<ForcedReduceComplex10Dim>(loc, builder);
1609+
else if (eleTy == fir::ComplexType::get(ctx, 16))
1610+
func = fir::runtime::getRuntimeFunc<ForcedReduceComplex16Dim>(loc, builder);
1611+
else if (eleTy == fir::LogicalType::get(ctx, 1))
1612+
func =
1613+
fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical1Dim)>(loc, builder);
1614+
else if (eleTy == fir::LogicalType::get(ctx, 2))
1615+
func =
1616+
fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical2Dim)>(loc, builder);
1617+
else if (eleTy == fir::LogicalType::get(ctx, 4))
1618+
func =
1619+
fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical4Dim)>(loc, builder);
1620+
else if (eleTy == fir::LogicalType::get(ctx, 8))
1621+
func =
1622+
fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical8Dim)>(loc, builder);
1623+
else if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 1)
1624+
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter1Dim)>(loc,
1625+
builder);
1626+
else if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 2)
1627+
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter2Dim)>(loc,
1628+
builder);
1629+
else if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 4)
1630+
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter4Dim)>(loc,
1631+
builder);
1632+
else if (fir::isa_derived(eleTy))
1633+
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceDerivedTypeDim)>(loc,
1634+
builder);
1635+
else
1636+
fir::intrinsicTypeTODO(builder, eleTy, loc, "REDUCE");
1637+
1638+
auto fTy = func.getFunctionType();
1639+
auto sourceFile = fir::factory::locationToFilename(builder, loc);
1640+
1641+
auto sourceLine =
1642+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
1643+
auto opAddr = builder.create<fir::BoxAddrOp>(loc, fTy.getInput(2), operation);
1644+
auto args = fir::runtime::createArguments(
1645+
builder, loc, fTy, resultBox, arrayBox, opAddr, sourceFile, sourceLine,
1646+
dim, maskBox, identity, ordered);
1647+
builder.create<fir::CallOp>(loc, func, args);
1648+
}

0 commit comments

Comments
 (0)