33
33
#include " llvm/Transforms/Utils/ModuleUtils.h"
34
34
35
35
#include < any>
36
+ #include < iterator>
36
37
#include < optional>
37
38
#include < utility>
38
39
@@ -878,36 +879,40 @@ static void collectReductionInfo(
878
879
}
879
880
880
881
/// Handling of DeclareReductionOp's cleanup region.
881
- static LogicalResult inlineReductionCleanup (
882
- llvm::SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
883
- llvm::ArrayRef<llvm::Value *> privateReductionVariables,
884
- LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder) {
885
- for (auto [i, reductionDecl] : llvm::enumerate (reductionDecls)) {
886
- Region &cleanupRegion = reductionDecl.getCleanupRegion ();
887
- if (cleanupRegion.empty ())
882
+ static LogicalResult
883
+ inlineOmpRegionCleanup (llvm::SmallVectorImpl<Region *> &cleanupRegions,
884
+ llvm::ArrayRef<llvm::Value *> privateVariables,
885
+ LLVM::ModuleTranslation &moduleTranslation,
886
+ llvm::IRBuilderBase &builder, StringRef regionName,
887
+ bool shouldLoadCleanupRegionArg = true ) {
888
+ for (auto [i, cleanupRegion] : llvm::enumerate (cleanupRegions)) {
889
+ if (cleanupRegion->empty ())
888
890
continue ;
889
891
890
892
// map the argument to the cleanup region
891
- Block &entry = cleanupRegion. front ();
893
+ Block &entry = cleanupRegion-> front ();
892
894
893
895
llvm::Instruction *potentialTerminator =
894
896
builder.GetInsertBlock ()->empty () ? nullptr
895
897
: &builder.GetInsertBlock ()->back ();
896
898
if (potentialTerminator && potentialTerminator->isTerminator ())
897
899
builder.SetInsertPoint (potentialTerminator);
898
- llvm::Value *reductionVar = builder.CreateLoad (
899
- moduleTranslation.convertType (entry.getArgument (0 ).getType ()),
900
- privateReductionVariables[i]);
900
+ llvm::Value *prviateVarValue =
901
+ shouldLoadCleanupRegionArg
902
+ ? builder.CreateLoad (
903
+ moduleTranslation.convertType (entry.getArgument (0 ).getType ()),
904
+ privateVariables[i])
905
+ : privateVariables[i];
901
906
902
- moduleTranslation.mapValue (entry.getArgument (0 ), reductionVar );
907
+ moduleTranslation.mapValue (entry.getArgument (0 ), prviateVarValue );
903
908
904
- if (failed (inlineConvertOmpRegions (cleanupRegion, " omp.reduction.cleanup " ,
905
- builder, moduleTranslation)))
909
+ if (failed (inlineConvertOmpRegions (* cleanupRegion, regionName, builder ,
910
+ moduleTranslation)))
906
911
return failure ();
907
912
908
913
// clear block argument mapping in case it needs to be re-created with a
909
914
// different source for another use of the same reduction decl
910
- moduleTranslation.forgetMapping (cleanupRegion);
915
+ moduleTranslation.forgetMapping (* cleanupRegion);
911
916
}
912
917
return success ();
913
918
}
@@ -1110,8 +1115,14 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
1110
1115
builder.restoreIP (nextInsertionPoint);
1111
1116
1112
1117
// after the workshare loop, deallocate private reduction variables
1113
- return inlineReductionCleanup (reductionDecls, privateReductionVariables,
1114
- moduleTranslation, builder);
1118
+ SmallVector<Region *> reductionRegions;
1119
+ llvm::transform (reductionDecls, std::back_inserter (reductionRegions),
1120
+ [](omp::DeclareReductionOp reductionDecl) {
1121
+ return &reductionDecl.getCleanupRegion ();
1122
+ });
1123
+ return inlineOmpRegionCleanup (reductionRegions, privateReductionVariables,
1124
+ moduleTranslation, builder,
1125
+ " omp.reduction.cleanup" );
1115
1126
}
1116
1127
1117
1128
/// A RAII class that on construction replaces the region arguments of the
@@ -1267,6 +1278,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1267
1278
}
1268
1279
};
1269
1280
1281
+ SmallVector<omp::PrivateClauseOp> privatizerClones;
1282
+ SmallVector<llvm::Value *> privateVariables;
1283
+
1270
1284
// TODO: Perform appropriate actions according to the data-sharing
1271
1285
// attribute (shared, private, firstprivate, ...) of variables.
1272
1286
// Currently shared and private are supported.
@@ -1356,12 +1370,17 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1356
1370
opInst.emitError (" failed to inline `alloc` region of an `omp.private` "
1357
1371
" op in the parallel region" );
1358
1372
bodyGenStatus = failure ();
1373
+ privatizerClone.erase ();
1359
1374
} else {
1360
1375
assert (yieldedValues.size () == 1 );
1361
1376
replacementValue = yieldedValues.front ();
1377
+
1378
+ // Keep the LLVM replacement value and the op clone in case we need to
1379
+ // emit cleanup (i.e. deallocation) logic.
1380
+ privateVariables.push_back (replacementValue);
1381
+ privatizerClones.push_back (privatizerClone);
1362
1382
}
1363
1383
1364
- privatizerClone.erase ();
1365
1384
builder.restoreIP (oldIP);
1366
1385
}
1367
1386
@@ -1376,8 +1395,25 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1376
1395
1377
1396
// if the reduction has a cleanup region, inline it here to finalize the
1378
1397
// reduction variables
1379
- if (failed (inlineReductionCleanup (reductionDecls, privateReductionVariables,
1380
- moduleTranslation, builder)))
1398
+ SmallVector<Region *> reductionCleanupRegions;
1399
+ llvm::transform (reductionDecls, std::back_inserter (reductionCleanupRegions),
1400
+ [](omp::DeclareReductionOp reductionDecl) {
1401
+ return &reductionDecl.getCleanupRegion ();
1402
+ });
1403
+ if (failed (inlineOmpRegionCleanup (
1404
+ reductionCleanupRegions, privateReductionVariables,
1405
+ moduleTranslation, builder, " omp.reduction.cleanup" )))
1406
+ bodyGenStatus = failure ();
1407
+
1408
+ SmallVector<Region *> privateCleanupRegions;
1409
+ llvm::transform (privatizerClones, std::back_inserter (privateCleanupRegions),
1410
+ [](omp::PrivateClauseOp privatizer) {
1411
+ return &privatizer.getDeallocRegion ();
1412
+ });
1413
+
1414
+ if (failed (inlineOmpRegionCleanup (
1415
+ privateCleanupRegions, privateVariables, moduleTranslation, builder,
1416
+ " omp.private.dealloc" , /* shouldLoadCleanupRegionArg=*/ false )))
1381
1417
bodyGenStatus = failure ();
1382
1418
1383
1419
builder.restoreIP (oldIP);
@@ -1403,6 +1439,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1403
1439
ompBuilder->createParallel (ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
1404
1440
ifCond, numThreads, pbKind, isCancellable));
1405
1441
1442
+ for (mlir::omp::PrivateClauseOp privatizerClone : privatizerClones)
1443
+ privatizerClone.erase ();
1444
+
1406
1445
return bodyGenStatus;
1407
1446
}
1408
1447
0 commit comments