Skip to content

Commit 53cc229

Browse files
committed
[MLIR][OpenMP] Created callback function for adding the metadata of nontemporal and handled the translation in OpenMPToLLVMIRTranslation.cpp
1 parent a7a4859 commit 53cc229

File tree

3 files changed

+104
-94
lines changed

3 files changed

+104
-94
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,9 @@ class OpenMPIRBuilder {
12291229
void unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop, int32_t Factor,
12301230
CanonicalLoopInfo **UnrolledCLI);
12311231

1232+
using NonTemporalBodyGenCallbackTy =
1233+
function_ref<void(llvm::BasicBlock *BB, MDNode *NontemporalNode)>;
1234+
12321235
/// Add metadata to simd-ize a loop. If IfCond is not nullptr, the loop
12331236
/// is cloned. The metadata which prevents vectorization is added to
12341237
/// the cloned loop. The cloned loop is executed when ifCond is evaluated
@@ -1242,10 +1245,15 @@ class OpenMPIRBuilder {
12421245
/// \param Order The enum to map order clause.
12431246
/// \param Simdlen The Simdlen length to apply to the simd loop.
12441247
/// \param Safelen The Safelen length to apply to the simd loop.
1245-
void applySimd(CanonicalLoopInfo *Loop,
1246-
MapVector<Value *, Value *> AlignedVars, Value *IfCond,
1247-
omp::OrderKind Order, ConstantInt *Simdlen,
1248-
ConstantInt *Safelen, ArrayRef<Value *> NontempralVars = {});
1248+
/// \param NontemporalCBFunc Callback invoked per basic block with the nontemporal metadata node.
1249+
/// \param NontemporalVars Array of nontemporal vars.
1250+
void applySimd(
1251+
CanonicalLoopInfo *Loop, MapVector<Value *, Value *> AlignedVars,
1252+
Value *IfCond, omp::OrderKind Order, ConstantInt *Simdlen,
1253+
ConstantInt *Safelen,
1254+
NonTemporalBodyGenCallbackTy NontemporalCBFunc = [](BasicBlock *,
1255+
MDNode *) {},
1256+
ArrayRef<Value *> NontempralVars = {});
12491257

12501258
/// Generator for '#omp flush'
12511259
///

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 8 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -5385,86 +5385,11 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
53855385
return 0;
53865386
}
53875387

5388-
static void appendNontemporalVars(BasicBlock *Block,
5389-
SmallVectorImpl<Value *> &NontemporalVars) {
5390-
for (Instruction &I : *Block) {
5391-
if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
5392-
if (CI->getIntrinsicID() == Intrinsic::memcpy) {
5393-
llvm::Value *DestPtr = CI->getArgOperand(0);
5394-
llvm::Value *SrcPtr = CI->getArgOperand(1);
5395-
for (const llvm::Value *Var : NontemporalVars) {
5396-
if (Var == SrcPtr) {
5397-
NontemporalVars.push_back(DestPtr);
5398-
break;
5399-
}
5400-
}
5401-
}
5402-
}
5403-
}
5404-
}
5405-
5406-
/** Attach nontemporal metadata to the load/store instructions of nontemporal
5407-
* variables of \p Block
5408-
* Nontemporal variables may be a scalar, fixed size or allocatable
5409-
* or pointer array
5410-
*
5411-
* Example scenarios for nontemporal variables:
5412-
* Case 1: Scalar variable
5413-
* If the nontemporal variable is a scalar, it is allocated on stack.Load and
5414-
* store instructions directly access the alloca pointer of the scalar
5415-
* variable for fetching information about scalar variable or writing
5416-
* into the scalar variable. Mark those load and store instructions as
5417-
* non-temporal.
5418-
*
5419-
* Case 2: Fixed Size array
5420-
* If the nontemporal variable is a fixed-size array, it is allocated
5421-
* as a contiguous block of memory. It uses one GEP instruction, to compute the
5422-
* address of each individual array elements and perform load or store
5423-
* operation on it. Mark those load and store instructions as non-temporal.
5424-
*
5425-
* Case 3: Allocatable array
5426-
* For an allocatable array, which might involve runtime type descriptor,
5427-
* needs to navigate through descriptors using two or more GEP and load
5428-
* instructions to compute the address of each individual element in an array.
5429-
* Mark those load or store which access the individual array elements as
5430-
* non-temporal.
5431-
*/
5432-
static void addNonTemporalMetadata(BasicBlock *Block, MDNode *Nontemporal,
5433-
SmallVectorImpl<Value *> &NontemporalVars) {
5434-
appendNontemporalVars(Block, NontemporalVars);
5435-
for (Instruction &I : *Block) {
5436-
llvm::Value *mem_ptr = nullptr;
5437-
bool MetadataFlag = true;
5438-
if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
5439-
if (!(li->getType()->isPointerTy()))
5440-
mem_ptr = li->getPointerOperand();
5441-
} else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
5442-
mem_ptr = si->getPointerOperand();
5443-
if (mem_ptr) {
5444-
while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
5445-
if (llvm::GetElementPtrInst *gep =
5446-
dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
5447-
llvm::Type *sourceType = gep->getSourceElementType();
5448-
if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
5449-
!(gep->hasAllZeroIndices())) {
5450-
MetadataFlag = false;
5451-
break;
5452-
}
5453-
mem_ptr = gep->getPointerOperand();
5454-
} else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
5455-
mem_ptr = li->getPointerOperand();
5456-
}
5457-
if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
5458-
I.setMetadata(LLVMContext::MD_nontemporal, Nontemporal);
5459-
}
5460-
}
5461-
}
5462-
5463-
void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
5464-
MapVector<Value *, Value *> AlignedVars,
5465-
Value *IfCond, OrderKind Order,
5466-
ConstantInt *Simdlen, ConstantInt *Safelen,
5467-
ArrayRef<Value *> NontemporalVarsIn) {
5388+
void OpenMPIRBuilder::applySimd(
5389+
CanonicalLoopInfo *CanonicalLoop, MapVector<Value *, Value *> AlignedVars,
5390+
Value *IfCond, OrderKind Order, ConstantInt *Simdlen, ConstantInt *Safelen,
5391+
OpenMPIRBuilder::NonTemporalBodyGenCallbackTy NontemporalCBFunc,
5392+
ArrayRef<Value *> NontemporalVarsIn) {
54685393
LLVMContext &Ctx = Builder.getContext();
54695394

54705395
Function *F = CanonicalLoop->getFunction();
@@ -5562,12 +5487,12 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
55625487
}
55635488

55645489
addLoopMetadata(CanonicalLoop, LoopMDList);
5565-
SmallVector<Value *> NontemporalVars{NontemporalVarsIn};
5490+
55665491
// Set nontemporal metadata to load and stores of nontemporal values
5567-
if (NontemporalVars.size()) {
5492+
if (NontemporalVarsIn.size()) {
55685493
MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
55695494
for (BasicBlock *BB : Reachable)
5570-
addNonTemporalMetadata(BB, NontemporalNode, NontemporalVars);
5495+
NontemporalCBFunc(BB, NontemporalNode);
55715496
}
55725497
}
55735498

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2462,6 +2462,25 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
24622462
llvm_unreachable("Unknown ClauseOrderKind kind");
24632463
}
24642464

2465+
static void
2466+
appendNontemporalVars(llvm::BasicBlock *Block,
2467+
SmallVectorImpl<llvm::Value *> &NontemporalVars) {
2468+
for (llvm::Instruction &I : *Block) {
2469+
if (const llvm::CallInst *CI = dyn_cast<llvm::CallInst>(&I)) {
2470+
if (CI->getIntrinsicID() == llvm::Intrinsic::memcpy) {
2471+
llvm::Value *DestPtr = CI->getArgOperand(0);
2472+
llvm::Value *SrcPtr = CI->getArgOperand(1);
2473+
for (const llvm::Value *Var : NontemporalVars) {
2474+
if (Var == SrcPtr) {
2475+
NontemporalVars.push_back(DestPtr);
2476+
break;
2477+
}
2478+
}
2479+
}
2480+
}
2481+
}
2482+
}
2483+
24652484
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
24662485
static LogicalResult
24672486
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2523,13 +2542,71 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
25232542
llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
25242543
llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
25252544

2526-
llvm::SmallVector<llvm::Value *> nontemporalVars;
2545+
llvm::SmallVector<llvm::Value *> nontemporalOrigVars;
25272546
mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
25282547
for (mlir::Value nontemporal : nontemporals) {
25292548
llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
2530-
nontemporalVars.push_back(nt);
2549+
nontemporalOrigVars.push_back(nt);
25312550
}
25322551

2552+
/** Callback function to attach nontemporal metadata to the load/store
2553+
* instructions of nontemporal variables of Block.
2554+
* Nontemporal variables may be a scalar, fixed size or allocatable
2555+
* or pointer array
2556+
*
2557+
* Example scenarios for nontemporal variables:
2558+
* Case 1: Scalar variable
2559+
* If the nontemporal variable is a scalar, it is allocated on the stack. Load
2560+
* and store instructions directly access the alloca pointer of the scalar
2561+
* variable for fetching information about scalar variable or writing
2562+
* into the scalar variable. Mark those load and store instructions as
2563+
* non-temporal.
2564+
*
2565+
* Case 2: Fixed Size array
2566+
* If the nontemporal variable is a fixed-size array, it is allocated
2567+
* as a contiguous block of memory. It uses one GEP instruction to compute
2568+
* the address of each individual array element and perform a load or store
2569+
* operation on it. Mark those load and store instructions as non-temporal.
2570+
*
2571+
* Case 3: Allocatable array
2572+
* For an allocatable array, which might involve a runtime type descriptor,
2573+
* the code needs to navigate through descriptors using two or more GEP and load
2574+
* instructions to compute the address of each individual element in an array.
2575+
* Mark those load or store which access the individual array elements as
2576+
* non-temporal.
2577+
*/
2578+
auto addNonTemporalMetadataCB = [&](llvm::BasicBlock *Block,
2579+
llvm::MDNode *Nontemporal) {
2580+
SmallVector<llvm::Value *> NontemporalVars{nontemporalOrigVars};
2581+
appendNontemporalVars(Block, NontemporalVars);
2582+
for (llvm::Instruction &I : *Block) {
2583+
llvm::Value *mem_ptr = nullptr;
2584+
bool MetadataFlag = true;
2585+
if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
2586+
if (!(li->getType()->isPointerTy()))
2587+
mem_ptr = li->getPointerOperand();
2588+
} else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
2589+
mem_ptr = si->getPointerOperand();
2590+
if (mem_ptr) {
2591+
while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
2592+
if (llvm::GetElementPtrInst *gep =
2593+
dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
2594+
llvm::Type *sourceType = gep->getSourceElementType();
2595+
if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
2596+
!(gep->hasAllZeroIndices())) {
2597+
MetadataFlag = false;
2598+
break;
2599+
}
2600+
mem_ptr = gep->getPointerOperand();
2601+
} else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
2602+
mem_ptr = li->getPointerOperand();
2603+
}
2604+
if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
2605+
I.setMetadata(llvm::LLVMContext::MD_nontemporal, Nontemporal);
2606+
}
2607+
}
2608+
};
2609+
25332610
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
25342611
std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
25352612
mlir::OperandRange operands = simdOp.getAlignedVars();
@@ -2557,11 +2634,11 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
25572634

25582635
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
25592636
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2560-
ompBuilder->applySimd(loopInfo, alignedVars,
2561-
simdOp.getIfExpr()
2562-
? moduleTranslation.lookupValue(simdOp.getIfExpr())
2563-
: nullptr,
2564-
order, simdlen, safelen, nontemporalVars);
2637+
ompBuilder->applySimd(
2638+
loopInfo, alignedVars,
2639+
simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
2640+
: nullptr,
2641+
order, simdlen, safelen, addNonTemporalMetadataCB, nontemporalOrigVars);
25652642

25662643
return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
25672644
llvmPrivateVars, privateDecls);

0 commit comments

Comments
 (0)