@@ -2462,6 +2462,25 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
2462
2462
llvm_unreachable (" Unknown ClauseOrderKind kind" );
2463
2463
}
2464
2464
2465
+ static void
2466
+ appendNontemporalVars (llvm::BasicBlock *Block,
2467
+ SmallVectorImpl<llvm::Value *> &NontemporalVars) {
2468
+ for (llvm::Instruction &I : *Block) {
2469
+ if (const llvm::CallInst *CI = dyn_cast<llvm::CallInst>(&I)) {
2470
+ if (CI->getIntrinsicID () == llvm::Intrinsic::memcpy) {
2471
+ llvm::Value *DestPtr = CI->getArgOperand (0 );
2472
+ llvm::Value *SrcPtr = CI->getArgOperand (1 );
2473
+ for (const llvm::Value *Var : NontemporalVars) {
2474
+ if (Var == SrcPtr) {
2475
+ NontemporalVars.push_back (DestPtr);
2476
+ break ;
2477
+ }
2478
+ }
2479
+ }
2480
+ }
2481
+ }
2482
+ }
2483
+
2465
2484
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
2466
2485
static LogicalResult
2467
2486
convertOmpSimd (Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2523,13 +2542,71 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2523
2542
llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
2524
2543
llvm::omp::OrderKind order = convertOrderKind (simdOp.getOrder ());
2525
2544
2526
- llvm::SmallVector<llvm::Value *> nontemporalVars ;
2545
+ llvm::SmallVector<llvm::Value *> nontemporalOrigVars ;
2527
2546
mlir::OperandRange nontemporals = simdOp.getNontemporalVars ();
2528
2547
for (mlir::Value nontemporal : nontemporals) {
2529
2548
llvm::Value *nt = moduleTranslation.lookupValue (nontemporal);
2530
- nontemporalVars .push_back (nt);
2549
+ nontemporalOrigVars .push_back (nt);
2531
2550
}
2532
2551
2552
+ /* * Call back function to attach nontemporal metadata to the load/store
2553
+ * instructions of nontemporal variables of Block.
2554
+ * Nontemporal variables may be a scalar, fixed size or allocatable
2555
+ * or pointer array
2556
+ *
2557
+ * Example scenarios for nontemporal variables:
2558
+ * Case 1: Scalar variable
2559
+ * If the nontemporal variable is a scalar, it is allocated on stack.Load
2560
+ * and store instructions directly access the alloca pointer of the scalar
2561
+ * variable for fetching information about scalar variable or writing
2562
+ * into the scalar variable. Mark those load and store instructions as
2563
+ * non-temporal.
2564
+ *
2565
+ * Case 2: Fixed Size array
2566
+ * If the nontemporal variable is a fixed-size array, it is allocated
2567
+ * as a contiguous block of memory. It uses one GEP instruction, to compute
2568
+ * the address of each individual array elements and perform load or store
2569
+ * operation on it. Mark those load and store instructions as non-temporal.
2570
+ *
2571
+ * Case 3: Allocatable array
2572
+ * For an allocatable array, which might involve runtime type descriptor,
2573
+ * needs to navigate through descriptors using two or more GEP and load
2574
+ * instructions to compute the address of each individual element in an array.
2575
+ * Mark those load or store which access the individual array elements as
2576
+ * non-temporal.
2577
+ */
2578
+ auto addNonTemporalMetadataCB = [&](llvm::BasicBlock *Block,
2579
+ llvm::MDNode *Nontemporal) {
2580
+ SmallVector<llvm::Value *> NontemporalVars{nontemporalOrigVars};
2581
+ appendNontemporalVars (Block, NontemporalVars);
2582
+ for (llvm::Instruction &I : *Block) {
2583
+ llvm::Value *mem_ptr = nullptr ;
2584
+ bool MetadataFlag = true ;
2585
+ if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
2586
+ if (!(li->getType ()->isPointerTy ()))
2587
+ mem_ptr = li->getPointerOperand ();
2588
+ } else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
2589
+ mem_ptr = si->getPointerOperand ();
2590
+ if (mem_ptr) {
2591
+ while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
2592
+ if (llvm::GetElementPtrInst *gep =
2593
+ dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
2594
+ llvm::Type *sourceType = gep->getSourceElementType ();
2595
+ if (sourceType->isStructTy () && gep->getNumIndices () >= 2 &&
2596
+ !(gep->hasAllZeroIndices ())) {
2597
+ MetadataFlag = false ;
2598
+ break ;
2599
+ }
2600
+ mem_ptr = gep->getPointerOperand ();
2601
+ } else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
2602
+ mem_ptr = li->getPointerOperand ();
2603
+ }
2604
+ if (MetadataFlag && is_contained (NontemporalVars, mem_ptr))
2605
+ I.setMetadata (llvm::LLVMContext::MD_nontemporal, Nontemporal);
2606
+ }
2607
+ }
2608
+ };
2609
+
2533
2610
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock ();
2534
2611
std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments ();
2535
2612
mlir::OperandRange operands = simdOp.getAlignedVars ();
@@ -2557,11 +2634,11 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2557
2634
2558
2635
builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2559
2636
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
2560
- ompBuilder->applySimd (loopInfo, alignedVars,
2561
- simdOp. getIfExpr ()
2562
- ? moduleTranslation.lookupValue (simdOp.getIfExpr ())
2563
- : nullptr ,
2564
- order, simdlen, safelen, nontemporalVars );
2637
+ ompBuilder->applySimd (
2638
+ loopInfo, alignedVars,
2639
+ simdOp. getIfExpr () ? moduleTranslation.lookupValue (simdOp.getIfExpr ())
2640
+ : nullptr ,
2641
+ order, simdlen, safelen, addNonTemporalMetadataCB, nontemporalOrigVars );
2565
2642
2566
2643
return cleanupPrivateVars (builder, moduleTranslation, simdOp.getLoc (),
2567
2644
llvmPrivateVars, privateDecls);
0 commit comments