14
14
15
15
#include " flang/Lower/AbstractConverter.h"
16
16
#include " flang/Optimizer/Builder/Todo.h"
17
+ #include " flang/Optimizer/Dialect/FIRType.h"
17
18
#include " flang/Optimizer/HLFIR/HLFIROps.h"
18
19
#include " flang/Parser/tools.h"
19
20
#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
21
+ #include " llvm/Support/CommandLine.h"
22
+
23
+ static llvm::cl::opt<bool > forceByrefReduction (
24
+ " force-byref-reduction" ,
25
+ llvm::cl::desc (" Pass all reduction arguments by reference" ),
26
+ llvm::cl::Hidden);
20
27
21
28
namespace Fortran {
22
29
namespace lower {
@@ -76,16 +83,24 @@ bool ReductionProcessor::supportedIntrinsicProcReduction(
76
83
}
77
84
78
85
std::string ReductionProcessor::getReductionName (llvm::StringRef name,
79
- mlir::Type ty) {
86
+ mlir::Type ty, bool isByRef) {
87
+ ty = fir::unwrapRefType (ty);
88
+
89
+ // extra string to distinguish reduction functions for variables passed by
90
+ // reference
91
+ llvm::StringRef byrefAddition{" " };
92
+ if (isByRef)
93
+ byrefAddition = " _byref" ;
94
+
80
95
return (llvm::Twine (name) +
81
96
(ty.isIntOrIndex () ? llvm::Twine (" _i_" ) : llvm::Twine (" _f_" )) +
82
- llvm::Twine (ty.getIntOrFloatBitWidth ()))
97
+ llvm::Twine (ty.getIntOrFloatBitWidth ()) + byrefAddition )
83
98
.str ();
84
99
}
85
100
86
101
std::string ReductionProcessor::getReductionName (
87
102
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
88
- mlir::Type ty) {
103
+ mlir::Type ty, bool isByRef ) {
89
104
std::string reductionName;
90
105
91
106
switch (intrinsicOp) {
@@ -108,13 +123,14 @@ std::string ReductionProcessor::getReductionName(
108
123
break ;
109
124
}
110
125
111
- return getReductionName (reductionName, ty);
126
+ return getReductionName (reductionName, ty, isByRef );
112
127
}
113
128
114
129
mlir::Value
115
130
ReductionProcessor::getReductionInitValue (mlir::Location loc, mlir::Type type,
116
131
ReductionIdentifier redId,
117
132
fir::FirOpBuilder &builder) {
133
+ type = fir::unwrapRefType (type);
118
134
assert ((fir::isa_integer (type) || fir::isa_real (type) ||
119
135
type.isa <fir::LogicalType>()) &&
120
136
" only integer, logical and real types are currently supported" );
@@ -188,6 +204,7 @@ mlir::Value ReductionProcessor::createScalarCombiner(
188
204
fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
189
205
mlir::Type type, mlir::Value op1, mlir::Value op2) {
190
206
mlir::Value reductionOp;
207
+ type = fir::unwrapRefType (type);
191
208
switch (redId) {
192
209
case ReductionIdentifier::MAX:
193
210
reductionOp =
@@ -268,7 +285,8 @@ mlir::Value ReductionProcessor::createScalarCombiner(
268
285
269
286
mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl (
270
287
fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
271
- const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
288
+ const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
289
+ bool isByRef) {
272
290
mlir::OpBuilder::InsertionGuard guard (builder);
273
291
mlir::ModuleOp module = builder.getModule ();
274
292
@@ -278,14 +296,24 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
278
296
return decl;
279
297
280
298
mlir::OpBuilder modBuilder (module .getBodyRegion ());
299
+ mlir::Type valTy = fir::unwrapRefType (type);
300
+ if (!isByRef)
301
+ type = valTy;
281
302
282
303
decl = modBuilder.create <mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
283
304
type);
284
305
builder.createBlock (&decl.getInitializerRegion (),
285
306
decl.getInitializerRegion ().end (), {type}, {loc});
286
307
builder.setInsertionPointToEnd (&decl.getInitializerRegion ().back ());
308
+
287
309
mlir::Value init = getReductionInitValue (loc, type, redId, builder);
288
- builder.create <mlir::omp::YieldOp>(loc, init);
310
+ if (isByRef) {
311
+ mlir::Value alloca = builder.create <fir::AllocaOp>(loc, valTy);
312
+ builder.createStoreWithConvert (loc, init, alloca);
313
+ builder.create <mlir::omp::YieldOp>(loc, alloca);
314
+ } else {
315
+ builder.create <mlir::omp::YieldOp>(loc, init);
316
+ }
289
317
290
318
builder.createBlock (&decl.getReductionRegion (),
291
319
decl.getReductionRegion ().end (), {type, type},
@@ -294,14 +322,41 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
294
322
builder.setInsertionPointToEnd (&decl.getReductionRegion ().back ());
295
323
mlir::Value op1 = decl.getReductionRegion ().front ().getArgument (0 );
296
324
mlir::Value op2 = decl.getReductionRegion ().front ().getArgument (1 );
325
+ mlir::Value outAddr = op1;
326
+
327
+ op1 = builder.loadIfRef (loc, op1);
328
+ op2 = builder.loadIfRef (loc, op2);
297
329
298
330
mlir::Value reductionOp =
299
331
createScalarCombiner (builder, loc, redId, type, op1, op2);
300
- builder.create <mlir::omp::YieldOp>(loc, reductionOp);
332
+ if (isByRef) {
333
+ builder.create <fir::StoreOp>(loc, reductionOp, outAddr);
334
+ builder.create <mlir::omp::YieldOp>(loc, outAddr);
335
+ } else {
336
+ builder.create <mlir::omp::YieldOp>(loc, reductionOp);
337
+ }
301
338
302
339
return decl;
303
340
}
304
341
342
+ bool ReductionProcessor::doReductionByRef (
343
+ const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
344
+ if (reductionVars.empty ())
345
+ return false ;
346
+ if (forceByrefReduction)
347
+ return true ;
348
+
349
+ for (mlir::Value reductionVar : reductionVars) {
350
+ if (auto declare =
351
+ mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp ()))
352
+ reductionVar = declare.getMemref ();
353
+
354
+ if (!fir::isa_trivial (fir::unwrapRefType (reductionVar.getType ())))
355
+ return true ;
356
+ }
357
+ return false ;
358
+ }
359
+
305
360
void ReductionProcessor::addReductionDecl (
306
361
mlir::Location currentLocation,
307
362
Fortran::lower::AbstractConverter &converter,
@@ -315,6 +370,24 @@ void ReductionProcessor::addReductionDecl(
315
370
const auto &redOperator{
316
371
std::get<Fortran::parser::OmpReductionOperator>(reduction.t )};
317
372
const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t )};
373
+
374
+ // initial pass to collect all recuction vars so we can figure out if this
375
+ // should happen byref
376
+ for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
377
+ if (const auto *name{
378
+ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
379
+ if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
380
+ if (reductionSymbols)
381
+ reductionSymbols->push_back (symbol);
382
+ mlir::Value symVal = converter.getSymbolAddress (*symbol);
383
+ if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
384
+ symVal = declOp.getBase ();
385
+ reductionVars.push_back (symVal);
386
+ }
387
+ }
388
+ }
389
+ const bool isByRef = doReductionByRef (reductionVars);
390
+
318
391
if (const auto &redDefinedOp =
319
392
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u )) {
320
393
const auto &intrinsicOp{
@@ -338,23 +411,20 @@ void ReductionProcessor::addReductionDecl(
338
411
if (const auto *name{
339
412
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
340
413
if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
341
- if (reductionSymbols)
342
- reductionSymbols->push_back (symbol);
343
414
mlir::Value symVal = converter.getSymbolAddress (*symbol);
344
415
if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
345
416
symVal = declOp.getBase ();
346
- mlir::Type redType =
347
- symVal.getType ().cast <fir::ReferenceType>().getEleTy ();
348
- reductionVars.push_back (symVal);
349
- if (redType.isa <fir::LogicalType>())
417
+ auto redType = symVal.getType ().cast <fir::ReferenceType>();
418
+ if (redType.getEleTy ().isa <fir::LogicalType>())
350
419
decl = createReductionDecl (
351
420
firOpBuilder,
352
- getReductionName (intrinsicOp, firOpBuilder.getI1Type ()), redId,
353
- redType, currentLocation);
354
- else if (redType.isIntOrIndexOrFloat ()) {
355
- decl = createReductionDecl (firOpBuilder,
356
- getReductionName (intrinsicOp, redType),
357
- redId, redType, currentLocation);
421
+ getReductionName (intrinsicOp, firOpBuilder.getI1Type (),
422
+ isByRef),
423
+ redId, redType, currentLocation, isByRef);
424
+ else if (redType.getEleTy ().isIntOrIndexOrFloat ()) {
425
+ decl = createReductionDecl (
426
+ firOpBuilder, getReductionName (intrinsicOp, redType, isByRef),
427
+ redId, redType, currentLocation, isByRef);
358
428
} else {
359
429
TODO (currentLocation, " Reduction of some types is not supported" );
360
430
}
@@ -374,21 +444,17 @@ void ReductionProcessor::addReductionDecl(
374
444
if (const auto *name{
375
445
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
376
446
if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
377
- if (reductionSymbols)
378
- reductionSymbols->push_back (symbol);
379
447
mlir::Value symVal = converter.getSymbolAddress (*symbol);
380
448
if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
381
449
symVal = declOp.getBase ();
382
- mlir::Type redType =
383
- symVal.getType ().cast <fir::ReferenceType>().getEleTy ();
384
- reductionVars.push_back (symVal);
385
- assert (redType.isIntOrIndexOrFloat () &&
450
+ auto redType = symVal.getType ().cast <fir::ReferenceType>();
451
+ assert (redType.getEleTy ().isIntOrIndexOrFloat () &&
386
452
" Unsupported reduction type" );
387
453
decl = createReductionDecl (
388
454
firOpBuilder,
389
455
getReductionName (getRealName (*reductionIntrinsic).ToString (),
390
- redType),
391
- redId, redType, currentLocation);
456
+ redType, isByRef ),
457
+ redId, redType, currentLocation, isByRef );
392
458
reductionDeclSymbols.push_back (mlir::SymbolRefAttr::get (
393
459
firOpBuilder.getContext (), decl.getSymName ()));
394
460
}
0 commit comments