Skip to content

[flang][NFC] simplify dispatching of reduction runtime calls #110479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 1, 2024

Conversation

jeanPerier
Copy link
Contributor

As part of the RFC to replace fir.complex usages by mlir complex, this patch updates the type dispatch in Reduction.cpp to use macros to avoid naming the types everywhere and to avoid typos when copy-pasting the if/else chains.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Sep 30, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 30, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (jeanPerier)

Changes

As part of the RFC to replace fir.complex usages by mlir complex, this patch updates the type dispatch in Reduction.cpp to use macros to avoid naming the types everywhere and to avoid typos when copy-pasting the if/else chains.


Patch is 45.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110479.diff

2 Files Affected:

  • (modified) flang/include/flang/Optimizer/Support/Utils.h (+4)
  • (modified) flang/lib/Optimizer/Builder/Runtime/Reduction.cpp (+154-464)
diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index 02bec4164fca08..06cf9c0be157c3 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -185,6 +185,10 @@ mlirTypeToCategoryKind(mlir::Location loc, mlir::Type type) {
     return {Fortran::common::TypeCategory::Complex, complexType.getFKind()};
   else if (auto logicalType = mlir::dyn_cast<fir::LogicalType>(type))
     return {Fortran::common::TypeCategory::Logical, logicalType.getFKind()};
+  else if (auto charType = mlir::dyn_cast<fir::CharacterType>(type))
+    return {Fortran::common::TypeCategory::Character, charType.getFKind()};
+  else if (mlir::isa<fir::RecordType>(type))
+    return {Fortran::common::TypeCategory::Derived, 0};
   else
     fir::emitFatalError(loc,
                         "unsupported type: " + fir::mlirTypeToString(type));
diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
index 18eff937278562..b39824428c78a9 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
@@ -120,6 +120,16 @@ struct ForcedMinvalInteger16 {
   }
 };
 
+// Maxloc/Minloc take descriptor, so these runtime signature are not ifdef
+// and the mkRTKey can safely be used here. Define alias so that the
+// REAL_INTRINSIC_INSTANCES macro works with them too
+using ForcedMaxlocReal10 = mkRTKey(MaxlocReal10);
+using ForcedMaxlocReal16 = mkRTKey(MaxlocReal16);
+using ForcedMaxlocInteger16 = mkRTKey(MaxlocInteger16);
+using ForcedMinlocReal10 = mkRTKey(MinlocReal10);
+using ForcedMinlocReal16 = mkRTKey(MinlocReal16);
+using ForcedMinlocInteger16 = mkRTKey(MinlocInteger16);
+
 /// Placeholder for real*10 version of Norm2 Intrinsic
 struct ForcedNorm2Real10 {
   static constexpr const char *name = ExpandAndQuoteKey(RTNAME(Norm2_10));
@@ -468,7 +478,7 @@ struct ForcedIParity16 {
 };
 
 /// Placeholder for real*10 version of Reduce Intrinsic
-struct ForcedReduceReal10 {
+struct ForcedReduceReal10Ref {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(ReduceReal10Ref));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
@@ -508,7 +518,7 @@ struct ForcedReduceReal10Value {
 };
 
 /// Placeholder for real*16 version of Reduce Intrinsic
-struct ForcedReduceReal16 {
+struct ForcedReduceReal16Ref {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(ReduceReal16Ref));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
@@ -548,7 +558,7 @@ struct ForcedReduceReal16Value {
 };
 
 /// Placeholder for DIM real*10 version of Reduce Intrinsic
-struct ForcedReduceReal10Dim {
+struct ForcedReduceReal10DimRef {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(ReduceReal10DimRef));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
@@ -592,7 +602,7 @@ struct ForcedReduceReal10DimValue {
 };
 
 /// Placeholder for DIM real*16 version of Reduce Intrinsic
-struct ForcedReduceReal16Dim {
+struct ForcedReduceReal16DimRef {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(ReduceReal16DimRef));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
@@ -636,7 +646,7 @@ struct ForcedReduceReal16DimValue {
 };
 
 /// Placeholder for integer*16 version of Reduce Intrinsic
-struct ForcedReduceInteger16 {
+struct ForcedReduceInteger16Ref {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(ReduceInteger16Ref));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
@@ -676,7 +686,7 @@ struct ForcedReduceInteger16Value {
 };
 
 /// Placeholder for DIM integer*16 version of Reduce Intrinsic
-struct ForcedReduceInteger16Dim {
+struct ForcedReduceInteger16DimRef {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(ReduceInteger16DimRef));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
@@ -720,7 +730,7 @@ struct ForcedReduceInteger16DimValue {
 };
 
 /// Placeholder for complex(10) version of Reduce Intrinsic
-struct ForcedReduceComplex10 {
+struct ForcedReduceComplex10Ref {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(CppReduceComplex10Ref));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
@@ -762,7 +772,7 @@ struct ForcedReduceComplex10Value {
 };
 
 /// Placeholder for Dim complex(10) version of Reduce Intrinsic
-struct ForcedReduceComplex10Dim {
+struct ForcedReduceComplex10DimRef {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(CppReduceComplex10DimRef));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
@@ -806,7 +816,7 @@ struct ForcedReduceComplex10DimValue {
 };
 
 /// Placeholder for complex(16) version of Reduce Intrinsic
-struct ForcedReduceComplex16 {
+struct ForcedReduceComplex16Ref {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(CppReduceComplex16Ref));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
@@ -848,7 +858,7 @@ struct ForcedReduceComplex16Value {
 };
 
 /// Placeholder for Dim complex(16) version of Reduce Intrinsic
-struct ForcedReduceComplex16Dim {
+struct ForcedReduceComplex16DimRef {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(CppReduceComplex16DimRef));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
@@ -891,6 +901,63 @@ struct ForcedReduceComplex16DimValue {
   }
 };
 
+#define INTRINSIC_INSTANCE(NAME, CAT, KIND, SUFFIX)                            \
+  if (!func && cat == TypeCategory::CAT && kind == KIND) {                     \
+    func = fir::runtime::getRuntimeFunc<mkRTKey(NAME##CAT##KIND##SUFFIX)>(     \
+        loc, builder);                                                         \
+  }
+#define FORCED_INTRINSIC_INSTANCE(NAME, CAT, KIND, SUFFIX)                     \
+  if (!func && cat == TypeCategory::CAT && kind == KIND) {                     \
+    func = fir::runtime::getRuntimeFunc<Forced##NAME##CAT##KIND##SUFFIX>(      \
+        loc, builder);                                                         \
+  }
+
+#define INTEGER_INTRINSIC_INSTANCES(NAME, SUFFIX)                              \
+  INTRINSIC_INSTANCE(NAME, Integer, 1, SUFFIX)                                 \
+  INTRINSIC_INSTANCE(NAME, Integer, 2, SUFFIX)                                 \
+  INTRINSIC_INSTANCE(NAME, Integer, 4, SUFFIX)                                 \
+  INTRINSIC_INSTANCE(NAME, Integer, 8, SUFFIX)                                 \
+  FORCED_INTRINSIC_INSTANCE(NAME, Integer, 16, SUFFIX)
+
+#define REAL_INTRINSIC_INSTANCES(NAME, SUFFIX)                                 \
+  INTRINSIC_INSTANCE(NAME, Real, 4, SUFFIX)                                    \
+  INTRINSIC_INSTANCE(NAME, Real, 8, SUFFIX)                                    \
+  FORCED_INTRINSIC_INSTANCE(NAME, Real, 10, SUFFIX)                            \
+  FORCED_INTRINSIC_INSTANCE(NAME, Real, 16, SUFFIX)
+
+#define COMPLEX_INTRINSIC_INSTANCES(NAME, SUFFIX)                              \
+  INTRINSIC_INSTANCE(Cpp##NAME, Complex, 4, SUFFIX)                            \
+  INTRINSIC_INSTANCE(Cpp##NAME, Complex, 8, SUFFIX)                            \
+  FORCED_INTRINSIC_INSTANCE(NAME, Complex, 10, SUFFIX)                         \
+  FORCED_INTRINSIC_INSTANCE(NAME, Complex, 16, SUFFIX)
+
+#define NUMERICAL_INTRINSIC_INSTANCES(NAME)                                    \
+  INTEGER_INTRINSIC_INSTANCES(NAME, )                                          \
+  REAL_INTRINSIC_INSTANCES(NAME, )                                             \
+  COMPLEX_INTRINSIC_INSTANCES(NAME, )
+
+#define LOGICAL_INTRINSIC_INSTANCES(NAME, SUFFIX)                              \
+  INTRINSIC_INSTANCE(NAME, Logical, 1, SUFFIX)                                 \
+  INTRINSIC_INSTANCE(NAME, Logical, 2, SUFFIX)                                 \
+  INTRINSIC_INSTANCE(NAME, Logical, 4, SUFFIX)                                 \
+  INTRINSIC_INSTANCE(NAME, Logical, 8, SUFFIX)
+
+#define NUMERICAL_AND_LOGICAL_INSTANCES(NAME, SUFFIX)                          \
+  INTEGER_INTRINSIC_INSTANCES(NAME, SUFFIX)                                    \
+  REAL_INTRINSIC_INSTANCES(NAME, SUFFIX)                                       \
+  COMPLEX_INTRINSIC_INSTANCES(NAME, SUFFIX)                                    \
+  LOGICAL_INTRINSIC_INSTANCES(NAME, SUFFIX)
+
+// REAL/COMPLEX 2 and 3 usually have no runtime implementation, so they have
+// special macros.
+#define REAL_2_3_INTRINSIC_INSTANCES(NAME, SUFFIX)                             \
+  INTRINSIC_INSTANCE(NAME, Real, 2, SUFFIX)                                    \
+  INTRINSIC_INSTANCE(NAME, Real, 3, SUFFIX)
+
+#define COMPLEX_2_3_INTRINSIC_INSTANCES(NAME, SUFFIX)                          \
+  INTRINSIC_INSTANCE(Cpp##NAME, Complex, 2, SUFFIX)                            \
+  INTRINSIC_INSTANCE(Cpp##NAME, Complex, 3, SUFFIX)
+
 /// Generate call to specialized runtime function that takes a mask and
 /// dim argument. The All, Any, and Count intrinsics use this pattern.
 template <typename FN>
@@ -1086,36 +1153,21 @@ void fir::runtime::genFindlocDim(fir::FirOpBuilder &builder, mlir::Location loc,
 /// that does not take a dim argument.
 void fir::runtime::genMaxloc(fir::FirOpBuilder &builder, mlir::Location loc,
                              mlir::Value resultBox, mlir::Value arrayBox,
-                             mlir::Value maskBox, mlir::Value kind,
+                             mlir::Value maskBox, mlir::Value kindVal,
                              mlir::Value back) {
-  mlir::func::FuncOp func;
   auto ty = arrayBox.getType();
   auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
   auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
   fir::factory::CharacterExprHelper charHelper{builder, loc};
-  if (eleTy.isF32())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxlocReal4)>(loc, builder);
-  else if (eleTy.isF64())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxlocReal8)>(loc, builder);
-  else if (eleTy.isF80())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxlocReal10)>(loc, builder);
-  else if (eleTy.isF128())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxlocReal16)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxlocInteger1)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxlocInteger2)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxlocInteger4)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxlocInteger8)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxlocInteger16)>(loc, builder);
-  else if (charHelper.isCharacterScalar(eleTy))
+  auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);
+  mlir::func::FuncOp func;
+  REAL_INTRINSIC_INSTANCES(Maxloc, )
+  INTEGER_INTRINSIC_INSTANCES(Maxloc, )
+  if (charHelper.isCharacterScalar(eleTy))
     func = fir::runtime::getRuntimeFunc<mkRTKey(MaxlocCharacter)>(loc, builder);
-  else
+  if (!func)
     fir::intrinsicTypeTODO(builder, eleTy, loc, "MAXLOC");
-  genReduction4Args(func, builder, loc, resultBox, arrayBox, maskBox, kind,
+  genReduction4Args(func, builder, loc, resultBox, arrayBox, maskBox, kindVal,
                     back);
 }
 
@@ -1135,31 +1187,15 @@ void fir::runtime::genMaxlocDim(fir::FirOpBuilder &builder, mlir::Location loc,
 mlir::Value fir::runtime::genMaxval(fir::FirOpBuilder &builder,
                                     mlir::Location loc, mlir::Value arrayBox,
                                     mlir::Value maskBox) {
-  mlir::func::FuncOp func;
   auto ty = arrayBox.getType();
   auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
   auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
   auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
-
-  if (eleTy.isF32())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalReal4)>(loc, builder);
-  else if (eleTy.isF64())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalReal8)>(loc, builder);
-  else if (eleTy.isF80())
-    func = fir::runtime::getRuntimeFunc<ForcedMaxvalReal10>(loc, builder);
-  else if (eleTy.isF128())
-    func = fir::runtime::getRuntimeFunc<ForcedMaxvalReal16>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger1)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger2)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger4)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger8)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
-    func = fir::runtime::getRuntimeFunc<ForcedMaxvalInteger16>(loc, builder);
-  else
+  auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);
+  mlir::func::FuncOp func;
+  REAL_INTRINSIC_INSTANCES(Maxval, )
+  INTEGER_INTRINSIC_INSTANCES(Maxval, )
+  if (!func)
     fir::intrinsicTypeTODO(builder, eleTy, loc, "MAXVAL");
 
   auto fTy = func.getFunctionType();
@@ -1201,36 +1237,21 @@ void fir::runtime::genMaxvalChar(fir::FirOpBuilder &builder, mlir::Location loc,
 /// that does not take a dim argument.
 void fir::runtime::genMinloc(fir::FirOpBuilder &builder, mlir::Location loc,
                              mlir::Value resultBox, mlir::Value arrayBox,
-                             mlir::Value maskBox, mlir::Value kind,
+                             mlir::Value maskBox, mlir::Value kindVal,
                              mlir::Value back) {
-  mlir::func::FuncOp func;
   auto ty = arrayBox.getType();
   auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
   auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
+  auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);
+  mlir::func::FuncOp func;
+  REAL_INTRINSIC_INSTANCES(Minloc, )
+  INTEGER_INTRINSIC_INSTANCES(Minloc, )
   fir::factory::CharacterExprHelper charHelper{builder, loc};
-  if (eleTy.isF32())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinlocReal4)>(loc, builder);
-  else if (eleTy.isF64())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinlocReal8)>(loc, builder);
-  else if (eleTy.isF80())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinlocReal10)>(loc, builder);
-  else if (eleTy.isF128())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinlocReal16)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinlocInteger1)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinlocInteger2)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinlocInteger4)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinlocInteger8)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinlocInteger16)>(loc, builder);
-  else if (charHelper.isCharacterScalar(eleTy))
+  if (charHelper.isCharacterScalar(eleTy))
     func = fir::runtime::getRuntimeFunc<mkRTKey(MinlocCharacter)>(loc, builder);
-  else
+  if (!func)
     fir::intrinsicTypeTODO(builder, eleTy, loc, "MINLOC");
-  genReduction4Args(func, builder, loc, resultBox, arrayBox, maskBox, kind,
+  genReduction4Args(func, builder, loc, resultBox, arrayBox, maskBox, kindVal,
                     back);
 }
 
@@ -1275,31 +1296,16 @@ void fir::runtime::genMinvalChar(fir::FirOpBuilder &builder, mlir::Location loc,
 mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
                                     mlir::Location loc, mlir::Value arrayBox,
                                     mlir::Value maskBox) {
-  mlir::func::FuncOp func;
   auto ty = arrayBox.getType();
   auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
   auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
   auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
+  auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);
 
-  if (eleTy.isF32())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalReal4)>(loc, builder);
-  else if (eleTy.isF64())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalReal8)>(loc, builder);
-  else if (eleTy.isF80())
-    func = fir::runtime::getRuntimeFunc<ForcedMinvalReal10>(loc, builder);
-  else if (eleTy.isF128())
-    func = fir::runtime::getRuntimeFunc<ForcedMinvalReal16>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger1)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger2)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger4)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger8)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
-    func = fir::runtime::getRuntimeFunc<ForcedMinvalInteger16>(loc, builder);
-  else
+  mlir::func::FuncOp func;
+  REAL_INTRINSIC_INSTANCES(Minval, )
+  INTEGER_INTRINSIC_INSTANCES(Minval, )
+  if (!func)
     fir::intrinsicTypeTODO(builder, eleTy, loc, "MINVAL");
 
   auto fTy = func.getFunctionType();
@@ -1390,41 +1396,15 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
                                      mlir::Location loc, mlir::Value arrayBox,
                                      mlir::Value maskBox,
                                      mlir::Value resultBox) {
-  mlir::func::FuncOp func;
   auto ty = arrayBox.getType();
   auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
   auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
   auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
 
-  if (eleTy.isF32())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(ProductReal4)>(loc, builder);
-  else if (eleTy.isF64())
-    func = fir::runtime::getRuntimeFunc<mkRTKey(ProductReal8)>(loc, builder);
-  else if (eleTy.isF80())
-    func = fir::runtime::getRuntimeFunc<ForcedProductReal10>(loc, builder);
-  else if (eleTy.isF128())
-    func = fir::runtime::getRuntimeFunc<ForcedProductReal16>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger1)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger2)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger4)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
-    func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger8)>(loc, builder);
-  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
-    func = fir::runtime::getRuntimeFunc<ForcedProductInteger16>(loc, builder);
-  else if (eleTy == fir::ComplexType::get(builder.getContext(), 4))
-    func =
-        fir::runtime::getRuntimeFunc<mkRTKey(CppProductComplex4)>(loc, builder);
-  else if (eleTy == fir::ComplexType::get(builder.getContext(), 8))
-    func =
-        fir::runtime::getRuntimeFunc<mkRTKey(CppProductComplex8)>(loc, builder);
-  else if (eleTy == fir::ComplexType::get(builder.getContext(), 10))
-    func = fir::runtime...
[truncated]

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the cleanup

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jeanPerier jeanPerier merged commit cc3cc5e into llvm:main Oct 1, 2024
11 checks passed
@jeanPerier jeanPerier deleted the simplify-reduction-dispatch branch October 1, 2024 08:10
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
…0479)

As part of t[he RFC to replace fir.complex usages by mlir
complex](https://discourse.llvm.org/t/rfc-flang-replace-usages-of-fir-complex-by-mlir-complex-type/82292),
this patch updates the type dispatch in Reduction.cpp to use macros to
avoid naming the types everywhere and to avoid typos when copy-pasting
the if/else chains.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants