Skip to content

Commit 792b974

Browse files
committed
[MLIR][OpenMP] Lowering nontemporal clause to LLVM IR for SIMD directive
1 parent 72aefbb commit 792b974

File tree

5 files changed

+192
-21
lines changed

5 files changed

+192
-21
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,7 @@ class OpenMPIRBuilder {
12201220
void applySimd(CanonicalLoopInfo *Loop,
12211221
MapVector<Value *, Value *> AlignedVars, Value *IfCond,
12221222
omp::OrderKind Order, ConstantInt *Simdlen,
1223-
ConstantInt *Safelen);
1223+
ConstantInt *Safelen, ArrayRef<Value *> NontemporalVars = {});
12241224

12251225
/// Generator for '#omp flush'
12261226
///

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5265,10 +5265,87 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
52655265
return 0;
52665266
}
52675267

5268+
/// Extend \p NontemporalVars with the destination pointer of every
/// `llvm.memcpy` call in \p Block whose source pointer is already in the list.
///
/// Instructions are visited in block order, so a destination appended by an
/// earlier memcpy is itself recognised as a source by a later memcpy in the
/// same block.
///
/// \param Block           Basic block to scan for memcpy intrinsic calls.
/// \param NontemporalVars In/out list of pointers treated as nontemporal;
///                        only ever grows.
static void appendNontemporalVars(BasicBlock *Block,
                                  SmallVectorImpl<Value *> &NontemporalVars) {
  for (Instruction &I : *Block) {
    const auto *CI = dyn_cast<CallInst>(&I);
    if (!CI || CI->getIntrinsicID() != Intrinsic::memcpy)
      continue;

    Value *DestPtr = CI->getArgOperand(0);
    Value *SrcPtr = CI->getArgOperand(1);
    // Finish the membership test before growing the vector so the container
    // is never mutated while it is being traversed.
    if (is_contained(NontemporalVars, SrcPtr))
      NontemporalVars.push_back(DestPtr);
  }
}
5285+
5286+
/** Attach nontemporal metadata to the load/store instructions of nontemporal
5287+
* variables of \p Block
5288+
* Nontemporal variables may be a scalar, fixed size or allocatable
5289+
* or pointer array
5290+
*
5291+
* !$omp simd nontemporal(a,b) ;; where a is scalar
5292+
* %a = alloca i32, i64 1 ;; (allocate a)
5293+
* %1 = load i32, ptr %a ;; (mark LOAD as nontemporal)
5294+
* store i32 11, ptr %1 ;; (mark STORE as nontemporal)
5295+
*
5296+
* !$omp simd nontemporal(a) ;; where a is an fixed size array
5297+
* %a = alloca [20 x i32], i64 1 ;; (allocate a)
5298+
* %2 = getelementptr i32, ptr %a, i64 %1 ;; (compute the address of arr ele)
5299+
* %3 = load i32, ptr %2 ;; (mark LOAD as nontemporal)
5300+
*
5301+
* !$omp simd nontemporal(a), ;; where a is an allocatable
5302+
* %struct.a = { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }
5303+
* %a = alloca %struct.a ;; (allocate a)
5304+
* %a_copy = alloca %struct.a
5305+
* call void @llvm.memcpy.p0.p0.i32(ptr %a_copy, ptr %a, i32 48, i1 false)
5306+
* %1 = getelementptr %struct.a, ptr %a_copy, i32 0, i32 0
5307+
* %2 = load ptr, ptr %1, align 8
5308+
* %3 = getelementptr i32, ptr %2, i64 %52
5309+
* %4 = load i32, ptr %3 ;; (mark LOAD as nontemporal)
5310+
*
5311+
* It works the same way for store
5312+
*/
5313+
static void addNonTemporalMetadata(BasicBlock *Block, MDNode *Nontemporal,
5314+
SmallVectorImpl<Value *> &NontemporalVars) {
5315+
appendNontemporalVars(Block, NontemporalVars);
5316+
for (Instruction &I : *Block) {
5317+
llvm::Value *mem_ptr = nullptr;
5318+
bool MetadataFlag = true;
5319+
if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
5320+
if (!(li->getType()->isPointerTy()))
5321+
mem_ptr = li->getPointerOperand();
5322+
} else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
5323+
mem_ptr = si->getPointerOperand();
5324+
if (mem_ptr) {
5325+
while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
5326+
if (llvm::GetElementPtrInst *gep =
5327+
dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
5328+
llvm::Type *sourceType = gep->getSourceElementType();
5329+
if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
5330+
!(gep->hasAllZeroIndices())) {
5331+
MetadataFlag = false;
5332+
break;
5333+
}
5334+
mem_ptr = gep->getPointerOperand();
5335+
} else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
5336+
mem_ptr = li->getPointerOperand();
5337+
}
5338+
if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
5339+
I.setMetadata(LLVMContext::MD_nontemporal, Nontemporal);
5340+
}
5341+
}
5342+
}
5343+
52685344
void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
52695345
MapVector<Value *, Value *> AlignedVars,
52705346
Value *IfCond, OrderKind Order,
5271-
ConstantInt *Simdlen, ConstantInt *Safelen) {
5347+
ConstantInt *Simdlen, ConstantInt *Safelen,
5348+
ArrayRef<Value *> NontemporalVarsIn) {
52725349
LLVMContext &Ctx = Builder.getContext();
52735350

52745351
Function *F = CanonicalLoop->getFunction();
@@ -5365,6 +5442,13 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
53655442
}
53665443

53675444
addLoopMetadata(CanonicalLoop, LoopMDList);
5445+
SmallVector<Value *> NontemporalVars{NontemporalVarsIn};
5446+
// Attach nontemporal metadata to the loads and stores of nontemporal variables
5447+
if (NontemporalVars.size()) {
5448+
MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
5449+
for (BasicBlock *BB : Reachable)
5450+
addNonTemporalMetadata(BB, NontemporalNode, NontemporalVars);
5451+
}
53685452
}
53695453

53705454
/// Create the TargetMachine object to query the backend for optimization

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
191191
if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
192192
result = todo("linear");
193193
};
194-
auto checkNontemporal = [&todo](auto op, LogicalResult &result) {
195-
if (!op.getNontemporalVars().empty())
196-
result = todo("nontemporal");
197-
};
194+
198195
auto checkNowait = [&todo](auto op, LogicalResult &result) {
199196
if (op.getNowait())
200197
result = todo("nowait");
@@ -274,7 +271,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
274271
.Case([&](omp::SimdOp op) {
275272
checkAligned(op, result);
276273
checkLinear(op, result);
277-
checkNontemporal(op, result);
278274
checkPrivate(op, result);
279275
checkReduction(op, result);
280276
})
@@ -2230,11 +2226,19 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
22302226

22312227
llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
22322228
llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
2229+
2230+
llvm::SmallVector<llvm::Value *> nontemporalVars;
2231+
mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
2232+
for (mlir::Value nontemporal : nontemporals) {
2233+
llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
2234+
nontemporalVars.push_back(nt);
2235+
}
2236+
22332237
ompBuilder->applySimd(loopInfo, alignedVars,
22342238
simdOp.getIfExpr()
22352239
? moduleTranslation.lookupValue(simdOp.getIfExpr())
22362240
: nullptr,
2237-
order, simdlen, safelen);
2241+
order, simdlen, safelen, nontemporalVars);
22382242

22392243
builder.restoreIP(afterIP);
22402244
return success();
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// -----
4+
// CHECK-LABEL: @simd_nontemporal
5+
// Scalar case: both allocas (%2 and %3) are listed in the nontemporal clause,
// so the load from %3 and the store to %2 inside the loop body are expected
// to carry !nontemporal metadata in the translated LLVM IR.
llvm.func @simd_nontemporal() {
6+
%0 = llvm.mlir.constant(10 : i64) : i64
7+
%1 = llvm.mlir.constant(1 : i64) : i64
8+
%2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
9+
%3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
10+
//CHECK: %[[A_ADDR:.*]] = alloca i64, i64 1, align 8
11+
//CHECK: %[[B_ADDR:.*]] = alloca i64, i64 1, align 8
12+
//CHECK: %[[B:.*]] = load i64, ptr %[[B_ADDR]], align 4, !nontemporal !1, !llvm.access.group !2
13+
//CHECK: store i64 %[[B]], ptr %[[A_ADDR]], align 4, !nontemporal !1, !llvm.access.group !2
14+
omp.simd nontemporal(%2, %3 : !llvm.ptr, !llvm.ptr) {
15+
omp.loop_nest (%arg0) : i64 = (%1) to (%0) inclusive step (%1) {
16+
%4 = llvm.load %3 : !llvm.ptr -> i64
17+
llvm.store %4, %2 : i64, !llvm.ptr
18+
omp.yield
19+
}
20+
}
21+
llvm.return
22+
}
23+
24+
// -----
25+
26+
//CHECK-LABEL: define void @_QPtest(ptr %0, ptr %1) {
27+
// Array-through-descriptor case. %arg1 is the nontemporal variable and points
// to a 48-byte struct whose field 0 is a data pointer. Each iteration copies
// that struct into a local alloca (%2, later %1) with llvm.intr.memcpy, loads
// the data pointer from field 0 of the copy, and then loads/stores an f32
// element through it. Only those element accesses are expected to be tagged
// !nontemporal; the i64 loads from field 7 of the copy are not matched as such.
// NOTE(review): the fir.bindc_name attributes suggest this was reduced from
// Flang output for an allocatable/pointer array -- confirm.
llvm.func @_QPtest(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fir.bindc_name = "a"}) {
28+
%0 = llvm.mlir.constant(1 : i32) : i32
29+
%1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
30+
%2 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
31+
%3 = llvm.mlir.constant(1 : i64) : i64
32+
%4 = llvm.alloca %3 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
33+
%6 = llvm.load %arg0 : !llvm.ptr -> i32
34+
//CHECK: %[[A_VAL1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
35+
//CHECK: %[[A_VAL2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
36+
omp.simd nontemporal(%arg1 : !llvm.ptr) {
37+
omp.loop_nest (%arg2) : i32 = (%0) to (%6) inclusive step (%0) {
38+
llvm.store %arg2, %4 : i32, !llvm.ptr
39+
//CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL2]], ptr %1, i32 48, i1 false)
40+
%7 = llvm.mlir.constant(48 : i32) : i32
41+
"llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
42+
%8 = llvm.load %4 : !llvm.ptr -> i32
43+
%9 = llvm.sext %8 : i32 to i64
44+
%10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
45+
%11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
46+
%12 = llvm.mlir.constant(0 : index) : i64
47+
%13 = llvm.getelementptr %2[0, 7, %12, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
48+
%14 = llvm.load %13 : !llvm.ptr -> i64
49+
%15 = llvm.getelementptr %2[0, 7, %12, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
50+
%16 = llvm.load %15 : !llvm.ptr -> i64
51+
%17 = llvm.getelementptr %2[0, 7, %12, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
52+
%18 = llvm.load %17 : !llvm.ptr -> i64
53+
%19 = llvm.mlir.constant(0 : i64) : i64
54+
%20 = llvm.sub %9, %14 overflow<nsw> : i64
55+
%21 = llvm.mul %20, %3 overflow<nsw> : i64
56+
%22 = llvm.mul %21, %3 overflow<nsw> : i64
57+
%23 = llvm.add %22,%19 overflow<nsw> : i64
58+
%24 = llvm.mul %3, %16 overflow<nsw> : i64
59+
//CHECK: %[[VAL1:.*]] = getelementptr float, ptr {{.*}}, i64 %{{.*}}
60+
//CHECK: %[[LOAD_A:.*]] = load float, ptr %[[VAL1]], align 4, !nontemporal
61+
//CHECK: %[[RES:.*]] = fadd contract float %[[LOAD_A]], 2.000000e+01
62+
%25 = llvm.getelementptr %11[%23] : (!llvm.ptr, i64) -> !llvm.ptr, f32
63+
%26 = llvm.load %25 : !llvm.ptr -> f32
64+
%27 = llvm.mlir.constant(2.000000e+01 : f32) : f32
65+
%28 = llvm.fadd %26, %27 {fastmathFlags = #llvm.fastmath<contract>} : f32
66+
//CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL1]], ptr %1, i32 48, i1 false)
67+
%29 = llvm.mlir.constant(48 : i32) : i32
68+
"llvm.intr.memcpy"(%1, %arg1, %29) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
69+
%30 = llvm.load %4 : !llvm.ptr -> i32
70+
%31 = llvm.sext %30 : i32 to i64
71+
%32 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
72+
%33 = llvm.load %32 : !llvm.ptr -> !llvm.ptr
73+
%34 = llvm.mlir.constant(0 : index) : i64
74+
%35 = llvm.getelementptr %1[0, 7, %34, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
75+
%36 = llvm.load %35 : !llvm.ptr -> i64
76+
%37 = llvm.getelementptr %1[0, 7, %34, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
77+
%38 = llvm.load %37 : !llvm.ptr -> i64
78+
%39 = llvm.getelementptr %1[0, 7, %34, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
79+
%40 = llvm.load %39 : !llvm.ptr -> i64
80+
%41 = llvm.sub %31, %36 overflow<nsw> : i64
81+
%42 = llvm.mul %41, %3 overflow<nsw> : i64
82+
%43 = llvm.mul %42, %3 overflow<nsw> : i64
83+
%44 = llvm.add %43,%19 overflow<nsw> : i64
84+
%45 = llvm.mul %3, %38 overflow<nsw> : i64
85+
//CHECK: %[[VAL2:.*]] = getelementptr float, ptr %{{.*}}, i64 %{{.*}}
86+
//CHECK: store float %[[RES]], ptr %[[VAL2]], align 4, !nontemporal
87+
%46 = llvm.getelementptr %33[%44] : (!llvm.ptr, i64) -> !llvm.ptr, f32
88+
llvm.store %28, %46 : f32, !llvm.ptr
89+
omp.yield
90+
}
91+
}
92+
llvm.return
93+
}
94+
// -----
95+
96+

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,6 @@ llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
155155

156156
// -----
157157

158-
llvm.func @simd_nontemporal(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
159-
// expected-error@below {{not yet implemented: Unhandled clause nontemporal in omp.simd operation}}
160-
// expected-error@below {{LLVM Translation failed for operation: omp.simd}}
161-
omp.simd nontemporal(%x : !llvm.ptr) {
162-
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
163-
omp.yield
164-
}
165-
}
166-
llvm.return
167-
}
168-
169-
// -----
170-
171158
omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
172159
^bb0(%arg0: !llvm.ptr):
173160
%0 = llvm.mlir.constant(1 : i32) : i32

0 commit comments

Comments
 (0)