Skip to content

Commit a7a4859

Browse files
committed
[MLIR][OpenMP] Lowering nontemporal clause to LLVM IR for SIMD directive
1 parent c7dbf20 commit a7a4859

File tree

5 files changed

+190
-21
lines changed

5 files changed

+190
-21
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1245,7 +1245,7 @@ class OpenMPIRBuilder {
12451245
void applySimd(CanonicalLoopInfo *Loop,
12461246
MapVector<Value *, Value *> AlignedVars, Value *IfCond,
12471247
omp::OrderKind Order, ConstantInt *Simdlen,
1248-
ConstantInt *Safelen);
1248+
ConstantInt *Safelen, ArrayRef<Value *> NontempralVars = {});
12491249

12501250
/// Generator for '#omp flush'
12511251
///

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5385,10 +5385,86 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
53855385
return 0;
53865386
}
53875387

5388+
/// Extend \p NontemporalVars with the destination pointer of every
/// llvm.memcpy call in \p Block whose source pointer is already in the list.
/// One entry is appended per matching memcpy; no de-duplication is done.
static void appendNontemporalVars(BasicBlock *Block,
                                  SmallVectorImpl<Value *> &NontemporalVars) {
  for (Instruction &Inst : *Block) {
    const auto *Call = dyn_cast<CallInst>(&Inst);
    if (!Call || Call->getIntrinsicID() != Intrinsic::memcpy)
      continue;

    // llvm.memcpy operands: 0 is the destination, 1 is the source.
    Value *CopyDst = Call->getArgOperand(0);
    Value *CopySrc = Call->getArgOperand(1);
    if (is_contained(NontemporalVars, CopySrc))
      NontemporalVars.push_back(CopyDst);
  }
}
5405+
5406+
/** Attach nontemporal metadata to the load/store instructions of nontemporal
5407+
* variables of \p Block
5408+
* Nontemporal variables may be a scalar, fixed size or allocatable
5409+
* or pointer array
5410+
*
5411+
* Example scenarios for nontemporal variables:
5412+
* Case 1: Scalar variable
5413+
* If the nontemporal variable is a scalar, it is allocated on stack.Load and
5414+
* store instructions directly access the alloca pointer of the scalar
5415+
* variable for fetching information about scalar variable or writing
5416+
* into the scalar variable. Mark those load and store instructions as
5417+
* non-temporal.
5418+
*
5419+
* Case 2: Fixed Size array
5420+
* If the nontemporal variable is a fixed-size array, it is allocated
5421+
* as a contiguous block of memory. It uses one GEP instruction, to compute the
5422+
* address of each individual array elements and perform load or store
5423+
* operation on it. Mark those load and store instructions as non-temporal.
5424+
*
5425+
* Case 3: Allocatable array
5426+
* For an allocatable array, which might involve runtime type descriptor,
5427+
* needs to navigate through descriptors using two or more GEP and load
5428+
* instructions to compute the address of each individual element in an array.
5429+
* Mark those load or store which access the individual array elements as
5430+
* non-temporal.
5431+
*/
5432+
static void addNonTemporalMetadata(BasicBlock *Block, MDNode *Nontemporal,
5433+
SmallVectorImpl<Value *> &NontemporalVars) {
5434+
appendNontemporalVars(Block, NontemporalVars);
5435+
for (Instruction &I : *Block) {
5436+
llvm::Value *mem_ptr = nullptr;
5437+
bool MetadataFlag = true;
5438+
if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
5439+
if (!(li->getType()->isPointerTy()))
5440+
mem_ptr = li->getPointerOperand();
5441+
} else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
5442+
mem_ptr = si->getPointerOperand();
5443+
if (mem_ptr) {
5444+
while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
5445+
if (llvm::GetElementPtrInst *gep =
5446+
dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
5447+
llvm::Type *sourceType = gep->getSourceElementType();
5448+
if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
5449+
!(gep->hasAllZeroIndices())) {
5450+
MetadataFlag = false;
5451+
break;
5452+
}
5453+
mem_ptr = gep->getPointerOperand();
5454+
} else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
5455+
mem_ptr = li->getPointerOperand();
5456+
}
5457+
if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
5458+
I.setMetadata(LLVMContext::MD_nontemporal, Nontemporal);
5459+
}
5460+
}
5461+
}
5462+
53885463
void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
53895464
MapVector<Value *, Value *> AlignedVars,
53905465
Value *IfCond, OrderKind Order,
5391-
ConstantInt *Simdlen, ConstantInt *Safelen) {
5466+
ConstantInt *Simdlen, ConstantInt *Safelen,
5467+
ArrayRef<Value *> NontemporalVarsIn) {
53925468
LLVMContext &Ctx = Builder.getContext();
53935469

53945470
Function *F = CanonicalLoop->getFunction();
@@ -5486,6 +5562,13 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
54865562
}
54875563

54885564
addLoopMetadata(CanonicalLoop, LoopMDList);
5565+
SmallVector<Value *> NontemporalVars{NontemporalVarsIn};
5566+
// Set nontemporal metadata to load and stores of nontemporal values
5567+
if (NontemporalVars.size()) {
5568+
MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
5569+
for (BasicBlock *BB : Reachable)
5570+
addNonTemporalMetadata(BB, NontemporalNode, NontemporalVars);
5571+
}
54895572
}
54905573

54915574
/// Create the TargetMachine object to query the backend for optimization

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
189189
if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
190190
result = todo("linear");
191191
};
192-
auto checkNontemporal = [&todo](auto op, LogicalResult &result) {
193-
if (!op.getNontemporalVars().empty())
194-
result = todo("nontemporal");
195-
};
196192
auto checkNowait = [&todo](auto op, LogicalResult &result) {
197193
if (op.getNowait())
198194
result = todo("nowait");
@@ -300,7 +296,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
300296
})
301297
.Case([&](omp::SimdOp op) {
302298
checkLinear(op, result);
303-
checkNontemporal(op, result);
304299
checkReduction(op, result);
305300
})
306301
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -2527,6 +2522,14 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
25272522

25282523
llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
25292524
llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
2525+
2526+
llvm::SmallVector<llvm::Value *> nontemporalVars;
2527+
mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
2528+
for (mlir::Value nontemporal : nontemporals) {
2529+
llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
2530+
nontemporalVars.push_back(nt);
2531+
}
2532+
25302533
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
25312534
std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
25322535
mlir::OperandRange operands = simdOp.getAlignedVars();
@@ -2558,7 +2561,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
25582561
simdOp.getIfExpr()
25592562
? moduleTranslation.lookupValue(simdOp.getIfExpr())
25602563
: nullptr,
2561-
order, simdlen, safelen);
2564+
order, simdlen, safelen, nontemporalVars);
25622565

25632566
return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
25642567
llvmPrivateVars, privateDecls);
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// -----
4+
// CHECK-LABEL: @simd_nontemporal
5+
llvm.func @simd_nontemporal() {
6+
%0 = llvm.mlir.constant(10 : i64) : i64
7+
%1 = llvm.mlir.constant(1 : i64) : i64
8+
%2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
9+
%3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
10+
//CHECK: %[[A_ADDR:.*]] = alloca i64, i64 1, align 8
11+
//CHECK: %[[B_ADDR:.*]] = alloca i64, i64 1, align 8
12+
//CHECK: %[[B:.*]] = load i64, ptr %[[B_ADDR]], align 4, !nontemporal !1, !llvm.access.group !2
13+
//CHECK: store i64 %[[B]], ptr %[[A_ADDR]], align 4, !nontemporal !1, !llvm.access.group !2
14+
omp.simd nontemporal(%2, %3 : !llvm.ptr, !llvm.ptr) {
15+
omp.loop_nest (%arg0) : i64 = (%1) to (%0) inclusive step (%1) {
16+
%4 = llvm.load %3 : !llvm.ptr -> i64
17+
llvm.store %4, %2 : i64, !llvm.ptr
18+
omp.yield
19+
}
20+
}
21+
llvm.return
22+
}
23+
24+
// -----
25+
26+
//CHECK-LABEL: define void @_QPtest(ptr %0, ptr %1) {
27+
llvm.func @_QPtest(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fir.bindc_name = "a"}) {
28+
%0 = llvm.mlir.constant(1 : i32) : i32
29+
%1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
30+
%2 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
31+
%3 = llvm.mlir.constant(1 : i64) : i64
32+
%4 = llvm.alloca %3 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
33+
%6 = llvm.load %arg0 : !llvm.ptr -> i32
34+
//CHECK: %[[A_VAL1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
35+
//CHECK: %[[A_VAL2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
36+
omp.simd nontemporal(%arg1 : !llvm.ptr) {
37+
omp.loop_nest (%arg2) : i32 = (%0) to (%6) inclusive step (%0) {
38+
llvm.store %arg2, %4 : i32, !llvm.ptr
39+
//CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL2]], ptr %1, i32 48, i1 false)
40+
%7 = llvm.mlir.constant(48 : i32) : i32
41+
"llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
42+
%8 = llvm.load %4 : !llvm.ptr -> i32
43+
%9 = llvm.sext %8 : i32 to i64
44+
%10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
45+
%11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
46+
%12 = llvm.mlir.constant(0 : index) : i64
47+
%13 = llvm.getelementptr %2[0, 7, %12, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
48+
%14 = llvm.load %13 : !llvm.ptr -> i64
49+
%15 = llvm.getelementptr %2[0, 7, %12, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
50+
%16 = llvm.load %15 : !llvm.ptr -> i64
51+
%17 = llvm.getelementptr %2[0, 7, %12, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
52+
%18 = llvm.load %17 : !llvm.ptr -> i64
53+
%19 = llvm.mlir.constant(0 : i64) : i64
54+
%20 = llvm.sub %9, %14 overflow<nsw> : i64
55+
%21 = llvm.mul %20, %3 overflow<nsw> : i64
56+
%22 = llvm.mul %21, %3 overflow<nsw> : i64
57+
%23 = llvm.add %22,%19 overflow<nsw> : i64
58+
%24 = llvm.mul %3, %16 overflow<nsw> : i64
59+
//CHECK: %[[VAL1:.*]] = getelementptr float, ptr {{.*}}, i64 %{{.*}}
60+
//CHECK: %[[LOAD_A:.*]] = load float, ptr %[[VAL1]], align 4, !nontemporal
61+
//CHECK: %[[RES:.*]] = fadd contract float %[[LOAD_A]], 2.000000e+01
62+
%25 = llvm.getelementptr %11[%23] : (!llvm.ptr, i64) -> !llvm.ptr, f32
63+
%26 = llvm.load %25 : !llvm.ptr -> f32
64+
%27 = llvm.mlir.constant(2.000000e+01 : f32) : f32
65+
%28 = llvm.fadd %26, %27 {fastmathFlags = #llvm.fastmath<contract>} : f32
66+
//CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL1]], ptr %1, i32 48, i1 false)
67+
%29 = llvm.mlir.constant(48 : i32) : i32
68+
"llvm.intr.memcpy"(%1, %arg1, %29) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
69+
%30 = llvm.load %4 : !llvm.ptr -> i32
70+
%31 = llvm.sext %30 : i32 to i64
71+
%32 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
72+
%33 = llvm.load %32 : !llvm.ptr -> !llvm.ptr
73+
%34 = llvm.mlir.constant(0 : index) : i64
74+
%35 = llvm.getelementptr %1[0, 7, %34, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
75+
%36 = llvm.load %35 : !llvm.ptr -> i64
76+
%37 = llvm.getelementptr %1[0, 7, %34, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
77+
%38 = llvm.load %37 : !llvm.ptr -> i64
78+
%39 = llvm.getelementptr %1[0, 7, %34, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
79+
%40 = llvm.load %39 : !llvm.ptr -> i64
80+
%41 = llvm.sub %31, %36 overflow<nsw> : i64
81+
%42 = llvm.mul %41, %3 overflow<nsw> : i64
82+
%43 = llvm.mul %42, %3 overflow<nsw> : i64
83+
%44 = llvm.add %43,%19 overflow<nsw> : i64
84+
%45 = llvm.mul %3, %38 overflow<nsw> : i64
85+
//CHECK: %[[VAL2:.*]] = getelementptr float, ptr %{{.*}}, i64 %{{.*}}
86+
//CHECK: store float %[[RES]], ptr %[[VAL2]], align 4, !nontemporal
87+
%46 = llvm.getelementptr %33[%44] : (!llvm.ptr, i64) -> !llvm.ptr, f32
88+
llvm.store %28, %46 : f32, !llvm.ptr
89+
omp.yield
90+
}
91+
}
92+
llvm.return
93+
}
94+
// -----
95+
96+

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -184,19 +184,6 @@ llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
184184

185185
// -----
186186

187-
llvm.func @simd_nontemporal(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
188-
// expected-error@below {{not yet implemented: Unhandled clause nontemporal in omp.simd operation}}
189-
// expected-error@below {{LLVM Translation failed for operation: omp.simd}}
190-
omp.simd nontemporal(%x : !llvm.ptr) {
191-
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
192-
omp.yield
193-
}
194-
}
195-
llvm.return
196-
}
197-
198-
// -----
199-
200187
omp.declare_reduction @add_f32 : f32
201188
init {
202189
^bb0(%arg: f32):

0 commit comments

Comments
 (0)