Skip to content

Commit 437b0c8

Browse files
committed
[flang] Lower REDUCE intrinsic for scalar result
1 parent 0605e98 commit 437b0c8

File tree

6 files changed

+803
-20
lines changed

6 files changed

+803
-20
lines changed

flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h

Lines changed: 177 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "flang/Optimizer/Builder/FIRBuilder.h"
2323
#include "flang/Optimizer/Dialect/FIRDialect.h"
2424
#include "flang/Optimizer/Dialect/FIRType.h"
25+
#include "flang/Runtime/reduce.h"
2526
#include "mlir/IR/BuiltinTypes.h"
2627
#include "mlir/IR/MLIRContext.h"
2728
#include "llvm/ADT/SmallVector.h"
@@ -52,6 +53,34 @@ namespace fir::runtime {
5253
using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *);
5354
using FuncTypeBuilderFunc = mlir::FunctionType (*)(mlir::MLIRContext *);
5455

56+
#define REDUCTION_OPERATION_MODEL(T) \
57+
template <> \
58+
constexpr TypeBuilderFunc \
59+
getModel<Fortran::runtime::ReductionOperation<T>>() { \
60+
return [](mlir::MLIRContext *context) -> mlir::Type { \
61+
TypeBuilderFunc f{getModel<T>()}; \
62+
auto refTy = fir::ReferenceType::get(f(context)); \
63+
return mlir::FunctionType::get(context, {refTy, refTy}, refTy); \
64+
}; \
65+
}
66+
67+
#define REDUCTION_CHAR_OPERATION_MODEL(T) \
68+
template <> \
69+
constexpr TypeBuilderFunc \
70+
getModel<Fortran::runtime::ReductionCharOperation<T>>() { \
71+
return [](mlir::MLIRContext *context) -> mlir::Type { \
72+
TypeBuilderFunc f{getModel<T>()}; \
73+
auto voidTy = fir::LLVMPointerType::get( \
74+
context, mlir::IntegerType::get(context, 8)); \
75+
auto size_tTy = \
76+
mlir::IntegerType::get(context, 8 * sizeof(std::size_t)); \
77+
auto refTy = fir::ReferenceType::get(f(context)); \
78+
return mlir::FunctionType::get( \
79+
context, {refTy, size_tTy, refTy, refTy, size_tTy, size_tTy}, \
80+
voidTy); \
81+
}; \
82+
}
83+
5584
//===----------------------------------------------------------------------===//
5685
// Type builder models
5786
//===----------------------------------------------------------------------===//
@@ -75,14 +104,24 @@ constexpr TypeBuilderFunc getModel<unsigned int>() {
75104
return mlir::IntegerType::get(context, 8 * sizeof(unsigned int));
76105
};
77106
}
78-
79107
template <>
80108
constexpr TypeBuilderFunc getModel<short int>() {
81109
return [](mlir::MLIRContext *context) -> mlir::Type {
82110
return mlir::IntegerType::get(context, 8 * sizeof(short int));
83111
};
84112
}
85113
template <>
114+
constexpr TypeBuilderFunc getModel<short int *>() {
115+
return [](mlir::MLIRContext *context) -> mlir::Type {
116+
TypeBuilderFunc f{getModel<short int>()};
117+
return fir::ReferenceType::get(f(context));
118+
};
119+
}
120+
template <>
121+
constexpr TypeBuilderFunc getModel<const short int *>() {
122+
return getModel<short int *>();
123+
}
124+
template <>
86125
constexpr TypeBuilderFunc getModel<int>() {
87126
return [](mlir::MLIRContext *context) -> mlir::Type {
88127
return mlir::IntegerType::get(context, 8 * sizeof(int));
@@ -96,6 +135,17 @@ constexpr TypeBuilderFunc getModel<int &>() {
96135
};
97136
}
98137
template <>
138+
constexpr TypeBuilderFunc getModel<int *>() {
139+
return getModel<int &>();
140+
}
141+
template <>
142+
constexpr TypeBuilderFunc getModel<const int *>() {
143+
return [](mlir::MLIRContext *context) -> mlir::Type {
144+
TypeBuilderFunc f{getModel<int>()};
145+
return fir::ReferenceType::get(f(context));
146+
};
147+
}
148+
template <>
99149
constexpr TypeBuilderFunc getModel<char *>() {
100150
return [](mlir::MLIRContext *context) -> mlir::Type {
101151
return fir::ReferenceType::get(mlir::IntegerType::get(context, 8));
@@ -130,6 +180,43 @@ constexpr TypeBuilderFunc getModel<signed char>() {
130180
};
131181
}
132182
template <>
183+
constexpr TypeBuilderFunc getModel<signed char *>() {
184+
return [](mlir::MLIRContext *context) -> mlir::Type {
185+
TypeBuilderFunc f{getModel<signed char>()};
186+
return fir::ReferenceType::get(f(context));
187+
};
188+
}
189+
template <>
190+
constexpr TypeBuilderFunc getModel<const signed char *>() {
191+
return getModel<signed char *>();
192+
}
193+
template <>
194+
constexpr TypeBuilderFunc getModel<char16_t>() {
195+
return [](mlir::MLIRContext *context) -> mlir::Type {
196+
return mlir::IntegerType::get(context, 8 * sizeof(char16_t));
197+
};
198+
}
199+
template <>
200+
constexpr TypeBuilderFunc getModel<char16_t *>() {
201+
return [](mlir::MLIRContext *context) -> mlir::Type {
202+
TypeBuilderFunc f{getModel<char16_t>()};
203+
return fir::ReferenceType::get(f(context));
204+
};
205+
}
206+
template <>
207+
constexpr TypeBuilderFunc getModel<char32_t>() {
208+
return [](mlir::MLIRContext *context) -> mlir::Type {
209+
return mlir::IntegerType::get(context, 8 * sizeof(char32_t));
210+
};
211+
}
212+
template <>
213+
constexpr TypeBuilderFunc getModel<char32_t *>() {
214+
return [](mlir::MLIRContext *context) -> mlir::Type {
215+
TypeBuilderFunc f{getModel<char32_t>()};
216+
return fir::ReferenceType::get(f(context));
217+
};
218+
}
219+
template <>
133220
constexpr TypeBuilderFunc getModel<unsigned char>() {
134221
return [](mlir::MLIRContext *context) -> mlir::Type {
135222
return mlir::IntegerType::get(context, 8 * sizeof(unsigned char));
@@ -175,6 +262,10 @@ constexpr TypeBuilderFunc getModel<long *>() {
175262
return getModel<long &>();
176263
}
177264
template <>
265+
constexpr TypeBuilderFunc getModel<const long *>() {
266+
return getModel<long *>();
267+
}
268+
template <>
178269
constexpr TypeBuilderFunc getModel<long long>() {
179270
return [](mlir::MLIRContext *context) -> mlir::Type {
180271
return mlir::IntegerType::get(context, 8 * sizeof(long long));
@@ -198,6 +289,7 @@ template <>
198289
constexpr TypeBuilderFunc getModel<long long *>() {
199290
return getModel<long long &>();
200291
}
292+
201293
template <>
202294
constexpr TypeBuilderFunc getModel<unsigned long>() {
203295
return [](mlir::MLIRContext *context) -> mlir::Type {
@@ -228,6 +320,27 @@ constexpr TypeBuilderFunc getModel<double *>() {
228320
return getModel<double &>();
229321
}
230322
template <>
323+
constexpr TypeBuilderFunc getModel<const double *>() {
324+
return getModel<double *>();
325+
}
326+
template <>
327+
constexpr TypeBuilderFunc getModel<long double>() {
328+
return [](mlir::MLIRContext *context) -> mlir::Type {
329+
return mlir::FloatType::getF80(context);
330+
};
331+
}
332+
template <>
333+
constexpr TypeBuilderFunc getModel<long double *>() {
334+
return [](mlir::MLIRContext *context) -> mlir::Type {
335+
TypeBuilderFunc f{getModel<long double>()};
336+
return fir::ReferenceType::get(f(context));
337+
};
338+
}
339+
template <>
340+
constexpr TypeBuilderFunc getModel<const long double *>() {
341+
return getModel<long double *>();
342+
}
343+
template <>
231344
constexpr TypeBuilderFunc getModel<float>() {
232345
return [](mlir::MLIRContext *context) -> mlir::Type {
233346
return mlir::FloatType::getF32(context);
@@ -245,6 +358,10 @@ constexpr TypeBuilderFunc getModel<float *>() {
245358
return getModel<float &>();
246359
}
247360
template <>
361+
constexpr TypeBuilderFunc getModel<const float *>() {
362+
return getModel<float *>();
363+
}
364+
template <>
248365
constexpr TypeBuilderFunc getModel<bool>() {
249366
return [](mlir::MLIRContext *context) -> mlir::Type {
250367
return mlir::IntegerType::get(context, 1);
@@ -258,20 +375,48 @@ constexpr TypeBuilderFunc getModel<bool &>() {
258375
};
259376
}
260377
template <>
378+
constexpr TypeBuilderFunc getModel<std::complex<float>>() {
379+
return [](mlir::MLIRContext *context) -> mlir::Type {
380+
return mlir::ComplexType::get(mlir::FloatType::getF32(context));
381+
};
382+
}
383+
template <>
261384
constexpr TypeBuilderFunc getModel<std::complex<float> &>() {
262385
return [](mlir::MLIRContext *context) -> mlir::Type {
263-
auto ty = mlir::ComplexType::get(mlir::FloatType::getF32(context));
264-
return fir::ReferenceType::get(ty);
386+
TypeBuilderFunc f{getModel<std::complex<float>>()};
387+
return fir::ReferenceType::get(f(context));
388+
};
389+
}
390+
template <>
391+
constexpr TypeBuilderFunc getModel<std::complex<float> *>() {
392+
return getModel<std::complex<float> &>();
393+
}
394+
template <>
395+
constexpr TypeBuilderFunc getModel<const std::complex<float> *>() {
396+
return getModel<std::complex<float> *>();
397+
}
398+
template <>
399+
constexpr TypeBuilderFunc getModel<std::complex<double>>() {
400+
return [](mlir::MLIRContext *context) -> mlir::Type {
401+
return mlir::ComplexType::get(mlir::FloatType::getF64(context));
265402
};
266403
}
267404
template <>
268405
constexpr TypeBuilderFunc getModel<std::complex<double> &>() {
269406
return [](mlir::MLIRContext *context) -> mlir::Type {
270-
auto ty = mlir::ComplexType::get(mlir::FloatType::getF64(context));
271-
return fir::ReferenceType::get(ty);
407+
TypeBuilderFunc f{getModel<std::complex<double>>()};
408+
return fir::ReferenceType::get(f(context));
272409
};
273410
}
274411
template <>
412+
constexpr TypeBuilderFunc getModel<std::complex<double> *>() {
413+
return getModel<std::complex<double> &>();
414+
}
415+
template <>
416+
constexpr TypeBuilderFunc getModel<const std::complex<double> *>() {
417+
return getModel<std::complex<double> *>();
418+
}
419+
template <>
275420
constexpr TypeBuilderFunc getModel<c_float_complex_t>() {
276421
return [](mlir::MLIRContext *context) -> mlir::Type {
277422
return fir::ComplexType::get(context, sizeof(float));
@@ -332,6 +477,33 @@ constexpr TypeBuilderFunc getModel<void>() {
332477
};
333478
}
334479

480+
REDUCTION_OPERATION_MODEL(std::int8_t)
481+
REDUCTION_OPERATION_MODEL(std::int16_t)
482+
REDUCTION_OPERATION_MODEL(std::int32_t)
483+
REDUCTION_OPERATION_MODEL(std::int64_t)
484+
REDUCTION_OPERATION_MODEL(Fortran::common::int128_t)
485+
486+
REDUCTION_OPERATION_MODEL(float)
487+
REDUCTION_OPERATION_MODEL(double)
488+
REDUCTION_OPERATION_MODEL(long double)
489+
490+
REDUCTION_OPERATION_MODEL(std::complex<float>)
491+
REDUCTION_OPERATION_MODEL(std::complex<double>)
492+
493+
REDUCTION_CHAR_OPERATION_MODEL(char)
494+
REDUCTION_CHAR_OPERATION_MODEL(char16_t)
495+
REDUCTION_CHAR_OPERATION_MODEL(char32_t)
496+
497+
template <>
498+
constexpr TypeBuilderFunc
499+
getModel<Fortran::runtime::ReductionDerivedTypeOperation>() {
500+
return [](mlir::MLIRContext *context) -> mlir::Type {
501+
auto voidTy =
502+
fir::LLVMPointerType::get(context, mlir::IntegerType::get(context, 8));
503+
return mlir::FunctionType::get(context, {voidTy, voidTy, voidTy}, voidTy);
504+
};
505+
}
506+
335507
template <typename...>
336508
struct RuntimeTableKey;
337509
template <typename RT, typename... ATs>

flang/include/flang/Optimizer/Builder/Runtime/Reduction.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,14 @@ void genIParityDim(fir::FirOpBuilder &builder, mlir::Location loc,
224224
mlir::Value resultBox, mlir::Value arrayBox, mlir::Value dim,
225225
mlir::Value maskBox);
226226

227+
/// Generate call to `Reduce` intrinsic runtime routine. This is the version
228+
/// that does not take a dim argument.
229+
mlir::Value genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
230+
mlir::Value arrayBox, mlir::Value operation,
231+
mlir::Value dim, mlir::Value maskBox,
232+
mlir::Value identity, mlir::Value ordered,
233+
mlir::Value resultBox);
234+
227235
} // namespace fir::runtime
228236

229237
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_REDUCTION_H

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ static constexpr IntrinsicHandler handlers[]{
526526
{"operation", asAddr},
527527
{"dim", asValue},
528528
{"mask", asBox, handleDynamicOptional},
529-
{"identity", asValue},
529+
{"identity", asAddr},
530530
{"ordered", asValue}}},
531531
/*isElemental=*/false},
532532
{"repeat",
@@ -5736,7 +5736,66 @@ void IntrinsicLibrary::genRandomSeed(llvm::ArrayRef<fir::ExtendedValue> args) {
57365736
fir::ExtendedValue
57375737
IntrinsicLibrary::genReduce(mlir::Type resultType,
57385738
llvm::ArrayRef<fir::ExtendedValue> args) {
5739-
TODO(loc, "intrinsic: reduce");
5739+
assert(args.size() == 6);
5740+
5741+
fir::BoxValue arrayTmp = builder.createBox(loc, args[0]);
5742+
mlir::Value array = fir::getBase(arrayTmp);
5743+
mlir::Value operation = fir::getBase(args[1]);
5744+
int rank = arrayTmp.rank();
5745+
assert(rank >= 1);
5746+
5747+
mlir::Type ty = array.getType();
5748+
mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
5749+
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
5750+
5751+
// Handle optional mask argument
5752+
auto dim = isStaticallyAbsent(args[3])
5753+
? builder.createIntegerConstant(loc, builder.getI32Type(), 1)
5754+
: fir::getBase(args[2]);
5755+
5756+
auto mask = isStaticallyAbsent(args[3])
5757+
? builder.create<fir::AbsentOp>(
5758+
loc, fir::BoxType::get(builder.getI1Type()))
5759+
: builder.createBox(loc, args[3]);
5760+
5761+
mlir::Value identity =
5762+
isStaticallyAbsent(args[4])
5763+
? builder.create<fir::AbsentOp>(loc, fir::ReferenceType::get(eleTy))
5764+
: fir::getBase(args[4]);
5765+
5766+
mlir::Value ordered = isStaticallyAbsent(args[5])
5767+
? builder.createBool(loc, true)
5768+
: fir::getBase(args[5]);
5769+
5770+
// We call the type specific versions because the result is scalar
5771+
// in the case below.
5772+
if (rank == 1) {
5773+
if (fir::isa_complex(eleTy) || fir::isa_derived(eleTy)) {
5774+
mlir::Value result = builder.createTemporary(loc, eleTy);
5775+
fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
5776+
identity, ordered, result);
5777+
if (fir::isa_derived(eleTy))
5778+
return result;
5779+
return builder.create<fir::LoadOp>(loc, result);
5780+
}
5781+
if (fir::isa_char(eleTy)) {
5782+
// Create mutable fir.box to be passed to the runtime for the result.
5783+
fir::MutableBoxValue resultMutableBox =
5784+
fir::factory::createTempMutableBox(builder, loc, eleTy);
5785+
mlir::Value resultIrBox =
5786+
fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
5787+
fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
5788+
identity, ordered, resultIrBox);
5789+
// Handle cleanup of allocatable result descriptor and return
5790+
return readAndAddCleanUp(resultMutableBox, resultType, "REDUCE");
5791+
}
5792+
auto resultBox = builder.create<fir::AbsentOp>(
5793+
loc, fir::BoxType::get(builder.getI1Type()));
5794+
return fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
5795+
identity, ordered, resultBox);
5796+
}
5797+
5798+
TODO(loc, "intrinsic: reduce with non scalar result");
57405799
}
57415800

57425801
// REPEAT

0 commit comments

Comments
 (0)