Skip to content

Commit 4be9289

Browse files
committed
[MLIR][OpenMP] Lowering nontemporal clause to LLVM IR for SIMD directive
1 parent 00b50c9 commit 4be9289

File tree

5 files changed

+191
-21
lines changed

5 files changed

+191
-21
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,7 @@ class OpenMPIRBuilder {
12201220
void applySimd(CanonicalLoopInfo *Loop,
12211221
MapVector<Value *, Value *> AlignedVars, Value *IfCond,
12221222
omp::OrderKind Order, ConstantInt *Simdlen,
1223-
ConstantInt *Safelen);
1223+
ConstantInt *Safelen, ArrayRef<Value *> NontempralVars = {});
12241224

12251225
/// Generator for '#omp flush'
12261226
///

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5279,10 +5279,86 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
52795279
return 0;
52805280
}
52815281

5282+
static void appendNontemporalVars(BasicBlock *Block,
5283+
SmallVectorImpl<Value *> &NontemporalVars) {
5284+
for (Instruction &I : *Block) {
5285+
if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
5286+
if (CI->getIntrinsicID() == Intrinsic::memcpy) {
5287+
llvm::Value *DestPtr = CI->getArgOperand(0);
5288+
llvm::Value *SrcPtr = CI->getArgOperand(1);
5289+
for (const llvm::Value *Var : NontemporalVars) {
5290+
if (Var == SrcPtr) {
5291+
NontemporalVars.push_back(DestPtr);
5292+
break;
5293+
}
5294+
}
5295+
}
5296+
}
5297+
}
5298+
}
5299+
5300+
/** Attach nontemporal metadata to the load/store instructions of nontemporal
5301+
* variables of \p Block
5302+
* Nontemporal variables may be a scalar, fixed size or allocatable
5303+
* or pointer array
5304+
*
5305+
* Example scenarios for nontemporal variables:
5306+
* Case 1: Scalar variable
5307+
* If the nontemporal variable is a scalar, it is allocated on stack.Load and
5308+
* store instructions directly access the alloca pointer of the scalar
5309+
* variable for fetching information about scalar variable or writing
5310+
* into the scalar variable. Mark those load and store instructions as
5311+
* non-temporal.
5312+
*
5313+
* Case 2: Fixed Size array
5314+
* If the nontemporal variable is a fixed-size array, it is allocated
5315+
* as a contiguous block of memory. It uses one GEP instruction, to compute the
5316+
* address of each individual array elements and perform load or store
5317+
* operation on it. Mark those load and store instructions as non-temporal.
5318+
*
5319+
* Case 3: Allocatable array
5320+
* For an allocatable array, which might involve runtime type descriptor,
5321+
* needs to navigate through descriptors using two or more GEP and load
5322+
* instructions to compute the address of each individual element in an array.
5323+
* Mark those load or store which access the individual array elements as
5324+
* non-temporal.
5325+
*/
5326+
static void addNonTemporalMetadata(BasicBlock *Block, MDNode *Nontemporal,
5327+
SmallVectorImpl<Value *> &NontemporalVars) {
5328+
appendNontemporalVars(Block, NontemporalVars);
5329+
for (Instruction &I : *Block) {
5330+
llvm::Value *mem_ptr = nullptr;
5331+
bool MetadataFlag = true;
5332+
if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
5333+
if (!(li->getType()->isPointerTy()))
5334+
mem_ptr = li->getPointerOperand();
5335+
} else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
5336+
mem_ptr = si->getPointerOperand();
5337+
if (mem_ptr) {
5338+
while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
5339+
if (llvm::GetElementPtrInst *gep =
5340+
dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
5341+
llvm::Type *sourceType = gep->getSourceElementType();
5342+
if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
5343+
!(gep->hasAllZeroIndices())) {
5344+
MetadataFlag = false;
5345+
break;
5346+
}
5347+
mem_ptr = gep->getPointerOperand();
5348+
} else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
5349+
mem_ptr = li->getPointerOperand();
5350+
}
5351+
if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
5352+
I.setMetadata(LLVMContext::MD_nontemporal, Nontemporal);
5353+
}
5354+
}
5355+
}
5356+
52825357
void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
52835358
MapVector<Value *, Value *> AlignedVars,
52845359
Value *IfCond, OrderKind Order,
5285-
ConstantInt *Simdlen, ConstantInt *Safelen) {
5360+
ConstantInt *Simdlen, ConstantInt *Safelen,
5361+
ArrayRef<Value *> NontemporalVarsIn) {
52865362
LLVMContext &Ctx = Builder.getContext();
52875363

52885364
Function *F = CanonicalLoop->getFunction();
@@ -5379,6 +5455,13 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
53795455
}
53805456

53815457
addLoopMetadata(CanonicalLoop, LoopMDList);
5458+
SmallVector<Value *> NontemporalVars{NontemporalVarsIn};
5459+
// Set nontemporal metadata to load and stores of nontemporal values
5460+
if (NontemporalVars.size()) {
5461+
MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
5462+
for (BasicBlock *BB : Reachable)
5463+
addNonTemporalMetadata(BB, NontemporalNode, NontemporalVars);
5464+
}
53825465
}
53835466

53845467
/// Create the TargetMachine object to query the backend for optimization

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
191191
if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
192192
result = todo("linear");
193193
};
194-
auto checkNontemporal = [&todo](auto op, LogicalResult &result) {
195-
if (!op.getNontemporalVars().empty())
196-
result = todo("nontemporal");
197-
};
194+
198195
auto checkNowait = [&todo](auto op, LogicalResult &result) {
199196
if (op.getNowait())
200197
result = todo("nowait");
@@ -274,7 +271,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
274271
.Case([&](omp::SimdOp op) {
275272
checkAligned(op, result);
276273
checkLinear(op, result);
277-
checkNontemporal(op, result);
278274
checkPrivate(op, result);
279275
checkReduction(op, result);
280276
})
@@ -2231,11 +2227,19 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
22312227

22322228
llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
22332229
llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
2230+
2231+
llvm::SmallVector<llvm::Value *> nontemporalVars;
2232+
mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
2233+
for (mlir::Value nontemporal : nontemporals) {
2234+
llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
2235+
nontemporalVars.push_back(nt);
2236+
}
2237+
22342238
ompBuilder->applySimd(loopInfo, alignedVars,
22352239
simdOp.getIfExpr()
22362240
? moduleTranslation.lookupValue(simdOp.getIfExpr())
22372241
: nullptr,
2238-
order, simdlen, safelen);
2242+
order, simdlen, safelen, nontemporalVars);
22392243

22402244
builder.restoreIP(afterIP);
22412245
return success();
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// -----
4+
// CHECK-LABEL: @simd_nontemporal
5+
llvm.func @simd_nontemporal() {
6+
%0 = llvm.mlir.constant(10 : i64) : i64
7+
%1 = llvm.mlir.constant(1 : i64) : i64
8+
%2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
9+
%3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
10+
//CHECK: %[[A_ADDR:.*]] = alloca i64, i64 1, align 8
11+
//CHECK: %[[B_ADDR:.*]] = alloca i64, i64 1, align 8
12+
//CHECK: %[[B:.*]] = load i64, ptr %[[B_ADDR]], align 4, !nontemporal !1, !llvm.access.group !2
13+
//CHECK: store i64 %[[B]], ptr %[[A_ADDR]], align 4, !nontemporal !1, !llvm.access.group !2
14+
omp.simd nontemporal(%2, %3 : !llvm.ptr, !llvm.ptr) {
15+
omp.loop_nest (%arg0) : i64 = (%1) to (%0) inclusive step (%1) {
16+
%4 = llvm.load %3 : !llvm.ptr -> i64
17+
llvm.store %4, %2 : i64, !llvm.ptr
18+
omp.yield
19+
}
20+
}
21+
llvm.return
22+
}
23+
24+
// -----
25+
26+
//CHECK-LABEL: define void @_QPtest(ptr %0, ptr %1) {
27+
llvm.func @_QPtest(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fir.bindc_name = "a"}) {
28+
%0 = llvm.mlir.constant(1 : i32) : i32
29+
%1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
30+
%2 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
31+
%3 = llvm.mlir.constant(1 : i64) : i64
32+
%4 = llvm.alloca %3 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
33+
%6 = llvm.load %arg0 : !llvm.ptr -> i32
34+
//CHECK: %[[A_VAL1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
35+
//CHECK: %[[A_VAL2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
36+
omp.simd nontemporal(%arg1 : !llvm.ptr) {
37+
omp.loop_nest (%arg2) : i32 = (%0) to (%6) inclusive step (%0) {
38+
llvm.store %arg2, %4 : i32, !llvm.ptr
39+
//CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL2]], ptr %1, i32 48, i1 false)
40+
%7 = llvm.mlir.constant(48 : i32) : i32
41+
"llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
42+
%8 = llvm.load %4 : !llvm.ptr -> i32
43+
%9 = llvm.sext %8 : i32 to i64
44+
%10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
45+
%11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
46+
%12 = llvm.mlir.constant(0 : index) : i64
47+
%13 = llvm.getelementptr %2[0, 7, %12, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
48+
%14 = llvm.load %13 : !llvm.ptr -> i64
49+
%15 = llvm.getelementptr %2[0, 7, %12, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
50+
%16 = llvm.load %15 : !llvm.ptr -> i64
51+
%17 = llvm.getelementptr %2[0, 7, %12, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
52+
%18 = llvm.load %17 : !llvm.ptr -> i64
53+
%19 = llvm.mlir.constant(0 : i64) : i64
54+
%20 = llvm.sub %9, %14 overflow<nsw> : i64
55+
%21 = llvm.mul %20, %3 overflow<nsw> : i64
56+
%22 = llvm.mul %21, %3 overflow<nsw> : i64
57+
%23 = llvm.add %22,%19 overflow<nsw> : i64
58+
%24 = llvm.mul %3, %16 overflow<nsw> : i64
59+
//CHECK: %[[VAL1:.*]] = getelementptr float, ptr {{.*}}, i64 %{{.*}}
60+
//CHECK: %[[LOAD_A:.*]] = load float, ptr %[[VAL1]], align 4, !nontemporal
61+
//CHECK: %[[RES:.*]] = fadd contract float %[[LOAD_A]], 2.000000e+01
62+
%25 = llvm.getelementptr %11[%23] : (!llvm.ptr, i64) -> !llvm.ptr, f32
63+
%26 = llvm.load %25 : !llvm.ptr -> f32
64+
%27 = llvm.mlir.constant(2.000000e+01 : f32) : f32
65+
%28 = llvm.fadd %26, %27 {fastmathFlags = #llvm.fastmath<contract>} : f32
66+
//CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL1]], ptr %1, i32 48, i1 false)
67+
%29 = llvm.mlir.constant(48 : i32) : i32
68+
"llvm.intr.memcpy"(%1, %arg1, %29) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
69+
%30 = llvm.load %4 : !llvm.ptr -> i32
70+
%31 = llvm.sext %30 : i32 to i64
71+
%32 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
72+
%33 = llvm.load %32 : !llvm.ptr -> !llvm.ptr
73+
%34 = llvm.mlir.constant(0 : index) : i64
74+
%35 = llvm.getelementptr %1[0, 7, %34, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
75+
%36 = llvm.load %35 : !llvm.ptr -> i64
76+
%37 = llvm.getelementptr %1[0, 7, %34, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
77+
%38 = llvm.load %37 : !llvm.ptr -> i64
78+
%39 = llvm.getelementptr %1[0, 7, %34, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
79+
%40 = llvm.load %39 : !llvm.ptr -> i64
80+
%41 = llvm.sub %31, %36 overflow<nsw> : i64
81+
%42 = llvm.mul %41, %3 overflow<nsw> : i64
82+
%43 = llvm.mul %42, %3 overflow<nsw> : i64
83+
%44 = llvm.add %43,%19 overflow<nsw> : i64
84+
%45 = llvm.mul %3, %38 overflow<nsw> : i64
85+
//CHECK: %[[VAL2:.*]] = getelementptr float, ptr %{{.*}}, i64 %{{.*}}
86+
//CHECK: store float %[[RES]], ptr %[[VAL2]], align 4, !nontemporal
87+
%46 = llvm.getelementptr %33[%44] : (!llvm.ptr, i64) -> !llvm.ptr, f32
88+
llvm.store %28, %46 : f32, !llvm.ptr
89+
omp.yield
90+
}
91+
}
92+
llvm.return
93+
}
94+
// -----
95+
96+

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,6 @@ llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
155155

156156
// -----
157157

158-
llvm.func @simd_nontemporal(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
159-
// expected-error@below {{not yet implemented: Unhandled clause nontemporal in omp.simd operation}}
160-
// expected-error@below {{LLVM Translation failed for operation: omp.simd}}
161-
omp.simd nontemporal(%x : !llvm.ptr) {
162-
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
163-
omp.yield
164-
}
165-
}
166-
llvm.return
167-
}
168-
169-
// -----
170-
171158
omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
172159
^bb0(%arg0: !llvm.ptr):
173160
%0 = llvm.mlir.constant(1 : i32) : i32

0 commit comments

Comments
 (0)