13
13
#include " ReductionProcessor.h"
14
14
15
15
#include " flang/Lower/AbstractConverter.h"
16
+ #include " flang/Optimizer/Builder/HLFIRTools.h"
16
17
#include " flang/Optimizer/Builder/Todo.h"
17
18
#include " flang/Optimizer/Dialect/FIRType.h"
18
19
#include " flang/Optimizer/HLFIR/HLFIROps.h"
@@ -90,10 +91,42 @@ std::string ReductionProcessor::getReductionName(llvm::StringRef name,
90
91
if (isByRef)
91
92
byrefAddition = " _byref" ;
92
93
93
- return (llvm::Twine (name) +
94
- (ty.isIntOrIndex () ? llvm::Twine (" _i_" ) : llvm::Twine (" _f_" )) +
95
- llvm::Twine (ty.getIntOrFloatBitWidth ()) + byrefAddition)
96
- .str ();
94
+ if (fir::isa_trivial (ty))
95
+ return (llvm::Twine (name) +
96
+ (ty.isIntOrIndex () ? llvm::Twine (" _i_" ) : llvm::Twine (" _f_" )) +
97
+ llvm::Twine (ty.getIntOrFloatBitWidth ()) + byrefAddition)
98
+ .str ();
99
+
100
+ // creates a name like reduction_i_64_box_ux4x3
101
+ if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
102
+ // TODO: support for allocatable boxes:
103
+ // !fir.box<!fir.heap<!fir.array<...>>>
104
+ fir::SequenceType seqTy = fir::unwrapRefType (boxTy.getEleTy ())
105
+ .dyn_cast_or_null <fir::SequenceType>();
106
+ if (!seqTy)
107
+ return {};
108
+
109
+ std::string prefix = getReductionName (
110
+ name, fir::unwrapSeqOrBoxedSeqType (ty), /* isByRef=*/ false );
111
+ if (prefix.empty ())
112
+ return {};
113
+ std::stringstream tyStr;
114
+ tyStr << prefix << " _box_" ;
115
+ bool first = true ;
116
+ for (std::int64_t extent : seqTy.getShape ()) {
117
+ if (first)
118
+ first = false ;
119
+ else
120
+ tyStr << " x" ;
121
+ if (extent == seqTy.getUnknownExtent ())
122
+ tyStr << ' u' ; // I'm not sure that '?' is safe in symbol names
123
+ else
124
+ tyStr << extent;
125
+ }
126
+ return (tyStr.str () + byrefAddition).str ();
127
+ }
128
+
129
+ return {};
97
130
}
98
131
99
132
std::string ReductionProcessor::getReductionName (
@@ -281,13 +314,158 @@ mlir::Value ReductionProcessor::createScalarCombiner(
281
314
return reductionOp;
282
315
}
283
316
317
+ // / Create reduction combiner region for reduction variables which are boxed
318
+ // / arrays
319
+ static void genBoxCombiner (fir::FirOpBuilder &builder, mlir::Location loc,
320
+ ReductionProcessor::ReductionIdentifier redId,
321
+ fir::BaseBoxType boxTy, mlir::Value lhs,
322
+ mlir::Value rhs) {
323
+ fir::SequenceType seqTy =
324
+ mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy ());
325
+ // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
326
+ if (!seqTy || seqTy.hasUnknownShape ())
327
+ TODO (loc, " Unsupported boxed type in OpenMP reduction" );
328
+
329
+ // load fir.ref<fir.box<...>>
330
+ mlir::Value lhsAddr = lhs;
331
+ lhs = builder.create <fir::LoadOp>(loc, lhs);
332
+ rhs = builder.create <fir::LoadOp>(loc, rhs);
333
+
334
+ const unsigned rank = seqTy.getDimension ();
335
+ llvm::SmallVector<mlir::Value> extents;
336
+ extents.reserve (rank);
337
+ llvm::SmallVector<mlir::Value> lbAndExtents;
338
+ lbAndExtents.reserve (rank * 2 );
339
+
340
+ // Get box lowerbounds and extents:
341
+ mlir::Type idxTy = builder.getIndexType ();
342
+ for (unsigned i = 0 ; i < rank; ++i) {
343
+ // TODO: ideally we want to hoist box reads out of the critical section.
344
+ // We could do this by having box dimensions in block arguments like
345
+ // OpenACC does
346
+ mlir::Value dim = builder.createIntegerConstant (loc, idxTy, i);
347
+ auto dimInfo =
348
+ builder.create <fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim);
349
+ extents.push_back (dimInfo.getExtent ());
350
+ lbAndExtents.push_back (dimInfo.getLowerBound ());
351
+ lbAndExtents.push_back (dimInfo.getExtent ());
352
+ }
353
+
354
+ auto shapeShiftTy = fir::ShapeShiftType::get (builder.getContext (), rank);
355
+ auto shapeShift =
356
+ builder.create <fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
357
+
358
+ // Iterate over array elements, applying the equivalent scalar reduction:
359
+
360
+ // A hlfir::elemental here gets inlined with a temporary so create the
361
+ // loop nest directly.
362
+ // This function already controls all of the code in this region so we
363
+ // know this won't miss any opportuinties for clever elemental inlining
364
+ hlfir::LoopNest nest =
365
+ hlfir::genLoopNest (loc, builder, extents, /* isUnordered=*/ true );
366
+ builder.setInsertionPointToStart (nest.innerLoop .getBody ());
367
+ mlir::Type refTy = fir::ReferenceType::get (seqTy.getEleTy ());
368
+ auto lhsEleAddr = builder.create <fir::ArrayCoorOp>(
369
+ loc, refTy, lhs, shapeShift, /* slice=*/ mlir::Value{},
370
+ nest.oneBasedIndices , /* typeparms=*/ mlir::ValueRange{});
371
+ auto rhsEleAddr = builder.create <fir::ArrayCoorOp>(
372
+ loc, refTy, rhs, shapeShift, /* slice=*/ mlir::Value{},
373
+ nest.oneBasedIndices , /* typeparms=*/ mlir::ValueRange{});
374
+ auto lhsEle = builder.create <fir::LoadOp>(loc, lhsEleAddr);
375
+ auto rhsEle = builder.create <fir::LoadOp>(loc, rhsEleAddr);
376
+ mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner (
377
+ builder, loc, redId, refTy, lhsEle, rhsEle);
378
+ builder.create <fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
379
+
380
+ builder.setInsertionPointAfter (nest.outerLoop );
381
+ builder.create <mlir::omp::YieldOp>(loc, lhsAddr);
382
+ }
383
+
384
+ // generate combiner region for reduction operations
385
+ static void genCombiner (fir::FirOpBuilder &builder, mlir::Location loc,
386
+ ReductionProcessor::ReductionIdentifier redId,
387
+ mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
388
+ bool isByRef) {
389
+ ty = fir::unwrapRefType (ty);
390
+
391
+ if (fir::isa_trivial (ty)) {
392
+ mlir::Value lhsLoaded = builder.loadIfRef (loc, lhs);
393
+ mlir::Value rhsLoaded = builder.loadIfRef (loc, rhs);
394
+
395
+ mlir::Value result = ReductionProcessor::createScalarCombiner (
396
+ builder, loc, redId, ty, lhsLoaded, rhsLoaded);
397
+ if (isByRef) {
398
+ builder.create <fir::StoreOp>(loc, result, lhs);
399
+ builder.create <mlir::omp::YieldOp>(loc, lhs);
400
+ } else {
401
+ builder.create <mlir::omp::YieldOp>(loc, result);
402
+ }
403
+ return ;
404
+ }
405
+ // all arrays should have been boxed
406
+ if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
407
+ genBoxCombiner (builder, loc, redId, boxTy, lhs, rhs);
408
+ return ;
409
+ }
410
+
411
+ TODO (loc, " OpenMP genCombiner for unsupported reduction variable type" );
412
+ }
413
+
414
+ static mlir::Value
415
+ createReductionInitRegion (fir::FirOpBuilder &builder, mlir::Location loc,
416
+ const ReductionProcessor::ReductionIdentifier redId,
417
+ mlir::Type type, bool isByRef) {
418
+ mlir::Type ty = fir::unwrapRefType (type);
419
+ mlir::Value initValue = ReductionProcessor::getReductionInitValue (
420
+ loc, fir::unwrapSeqOrBoxedSeqType (ty), redId, builder);
421
+
422
+ if (fir::isa_trivial (ty)) {
423
+ if (isByRef) {
424
+ mlir::Value alloca = builder.create <fir::AllocaOp>(loc, ty);
425
+ builder.createStoreWithConvert (loc, initValue, alloca);
426
+ return alloca;
427
+ }
428
+ // by val
429
+ return initValue;
430
+ }
431
+
432
+ // all arrays are boxed
433
+ if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
434
+ assert (isByRef && " passing arrays by value is unsupported" );
435
+ // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
436
+ mlir::Type innerTy = fir::extractSequenceType (boxTy);
437
+ if (!mlir::isa<fir::SequenceType>(innerTy))
438
+ TODO (loc, " Unsupported boxed type for reduction" );
439
+ // Create the private copy from the initial fir.box:
440
+ hlfir::Entity source = hlfir::Entity{builder.getBlock ()->getArgument (0 )};
441
+
442
+ // TODO: if the whole reduction is nested inside of a loop, this alloca
443
+ // could lead to a stack overflow (the memory is only freed at the end of
444
+ // the stack frame). The reduction declare operation needs a deallocation
445
+ // region to undo the init region.
446
+ hlfir::Entity temp = createStackTempFromMold (loc, builder, source);
447
+
448
+ // Put the temporary inside of a box:
449
+ hlfir::Entity box = hlfir::genVariableBox (loc, builder, temp);
450
+ builder.create <hlfir::AssignOp>(loc, initValue, box);
451
+ mlir::Value boxAlloca = builder.create <fir::AllocaOp>(loc, ty);
452
+ builder.create <fir::StoreOp>(loc, box, boxAlloca);
453
+ return boxAlloca;
454
+ }
455
+
456
+ TODO (loc, " createReductionInitRegion for unsupported type" );
457
+ }
458
+
284
459
mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl (
285
460
fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
286
461
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
287
462
bool isByRef) {
288
463
mlir::OpBuilder::InsertionGuard guard (builder);
289
464
mlir::ModuleOp module = builder.getModule ();
290
465
466
+ if (reductionOpName.empty ())
467
+ TODO (loc, " Reduction of some types is not supported" );
468
+
291
469
auto decl =
292
470
module .lookupSymbol <mlir::omp::ReductionDeclareOp>(reductionOpName);
293
471
if (decl)
@@ -304,14 +482,9 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
304
482
decl.getInitializerRegion ().end (), {type}, {loc});
305
483
builder.setInsertionPointToEnd (&decl.getInitializerRegion ().back ());
306
484
307
- mlir::Value init = getReductionInitValue (loc, type, redId, builder);
308
- if (isByRef) {
309
- mlir::Value alloca = builder.create <fir::AllocaOp>(loc, valTy);
310
- builder.createStoreWithConvert (loc, init, alloca);
311
- builder.create <mlir::omp::YieldOp>(loc, alloca);
312
- } else {
313
- builder.create <mlir::omp::YieldOp>(loc, init);
314
- }
485
+ mlir::Value init =
486
+ createReductionInitRegion (builder, loc, redId, type, isByRef);
487
+ builder.create <mlir::omp::YieldOp>(loc, init);
315
488
316
489
builder.createBlock (&decl.getReductionRegion (),
317
490
decl.getReductionRegion ().end (), {type, type},
@@ -320,19 +493,7 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
320
493
builder.setInsertionPointToEnd (&decl.getReductionRegion ().back ());
321
494
mlir::Value op1 = decl.getReductionRegion ().front ().getArgument (0 );
322
495
mlir::Value op2 = decl.getReductionRegion ().front ().getArgument (1 );
323
- mlir::Value outAddr = op1;
324
-
325
- op1 = builder.loadIfRef (loc, op1);
326
- op2 = builder.loadIfRef (loc, op2);
327
-
328
- mlir::Value reductionOp =
329
- createScalarCombiner (builder, loc, redId, type, op1, op2);
330
- if (isByRef) {
331
- builder.create <fir::StoreOp>(loc, reductionOp, outAddr);
332
- builder.create <mlir::omp::YieldOp>(loc, outAddr);
333
- } else {
334
- builder.create <mlir::omp::YieldOp>(loc, reductionOp);
335
- }
496
+ genCombiner (builder, loc, redId, type, op1, op2, isByRef);
336
497
337
498
return decl;
338
499
}
@@ -387,13 +548,33 @@ void ReductionProcessor::addReductionDecl(
387
548
388
549
// initial pass to collect all reduction vars so we can figure out if this
389
550
// should happen byref
551
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
390
552
for (const Object &object : objectList) {
391
553
const Fortran::semantics::Symbol *symbol = object.id ();
392
554
if (reductionSymbols)
393
555
reductionSymbols->push_back (symbol);
394
556
mlir::Value symVal = converter.getSymbolAddress (*symbol);
395
- if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
557
+ auto redType = mlir::cast<fir::ReferenceType>(symVal.getType ());
558
+
559
+ // all arrays must be boxed so that we have convenient access to all the
560
+ // information needed to iterate over the array
561
+ if (mlir::isa<fir::SequenceType>(redType.getEleTy ())) {
562
+ hlfir::Entity entity{symVal};
563
+ entity = genVariableBox (currentLocation, builder, entity);
564
+ mlir::Value box = entity.getBase ();
565
+
566
+ // Always pass the box by reference so that the OpenMP dialect
567
+ // verifiers don't need to know anything about fir.box
568
+ auto alloca =
569
+ builder.create <fir::AllocaOp>(currentLocation, box.getType ());
570
+ builder.create <fir::StoreOp>(currentLocation, box, alloca);
571
+
572
+ symVal = alloca;
573
+ redType = mlir::cast<fir::ReferenceType>(symVal.getType ());
574
+ } else if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>()) {
396
575
symVal = declOp.getBase ();
576
+ }
577
+
397
578
reductionVars.push_back (symVal);
398
579
}
399
580
const bool isByRef = doReductionByRef (reductionVars);
@@ -418,24 +599,17 @@ void ReductionProcessor::addReductionDecl(
418
599
break ;
419
600
}
420
601
421
- for (const Object &object : objectList) {
422
- const Fortran::semantics::Symbol *symbol = object.id ();
423
- mlir::Value symVal = converter.getSymbolAddress (*symbol);
424
- if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
425
- symVal = declOp.getBase ();
426
- auto redType = symVal.getType ().cast <fir::ReferenceType>();
602
+ for (mlir::Value symVal : reductionVars) {
603
+ auto redType = mlir::cast<fir::ReferenceType>(symVal.getType ());
427
604
if (redType.getEleTy ().isa <fir::LogicalType>())
428
605
decl = createReductionDecl (
429
606
firOpBuilder,
430
607
getReductionName (intrinsicOp, firOpBuilder.getI1Type (), isByRef),
431
608
redId, redType, currentLocation, isByRef);
432
- else if (redType. getEleTy (). isIntOrIndexOrFloat ()) {
609
+ else
433
610
decl = createReductionDecl (
434
611
firOpBuilder, getReductionName (intrinsicOp, redType, isByRef),
435
612
redId, redType, currentLocation, isByRef);
436
- } else {
437
- TODO (currentLocation, " Reduction of some types is not supported" );
438
- }
439
613
reductionDeclSymbols.push_back (mlir::SymbolRefAttr::get (
440
614
firOpBuilder.getContext (), decl.getSymName ()));
441
615
}
@@ -452,8 +626,8 @@ void ReductionProcessor::addReductionDecl(
452
626
if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
453
627
symVal = declOp.getBase ();
454
628
auto redType = symVal.getType ().cast <fir::ReferenceType>();
455
- assert ( redType.getEleTy ().isIntOrIndexOrFloat () &&
456
- " Unsupported reduction type" );
629
+ if (! redType.getEleTy ().isIntOrIndexOrFloat ())
630
+ TODO (currentLocation, " User Defined Reduction on non-trivial type" );
457
631
decl = createReductionDecl (
458
632
firOpBuilder,
459
633
getReductionName (getRealName (*reductionIntrinsic).ToString (),
0 commit comments