Skip to content

Commit 2cf914a

Browse files
AlexeySotkin authored and svenvh committed
Handle @llvm.memset.* with non-constant arguments (intel#696)
* Handle @llvm.memset.* with non-constant arguments. There is no SPIR-V counterpart for the @llvm.memset.* intrinsic. Cases with constant value and length arguments are emulated via "storing" a constant array to the destination. For other cases we wrap the intrinsic in a @spirv.llvm_memset_* function and expand the intrinsic to a loop via expandMemSetAsLoop() from llvm/Transforms/Utils/LowerMemIntrinsics.h. During reverse translation from SPIR-V to LLVM IR we can detect @spirv.llvm_memset_* and replace it with @llvm.memset. Signed-off-by: Alexey Sotkin <[email protected]> Co-authored-by: Sven van Haastregt <[email protected]>
1 parent 86649a8 commit 2cf914a

File tree

3 files changed

+134
-13
lines changed

3 files changed

+134
-13
lines changed

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 23 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2786,21 +2786,33 @@ Function *SPIRVToLLVM::transFunction(SPIRVFunction *BF) {
27862786
auto IsKernel = isKernel(BF);
27872787
auto Linkage = IsKernel ? GlobalValue::ExternalLinkage : transLinkageType(BF);
27882788
FunctionType *FT = dyn_cast<FunctionType>(transType(BF->getFunctionType()));
2789-
Function *F = cast<Function>(
2790-
mapValue(BF, Function::Create(FT, Linkage, BF->getName(), M)));
2789+
std::string FuncName = BF->getName();
2790+
StringRef FuncNameRef(FuncName);
2791+
// Transform "@spirv.llvm_memset_p0i8_i32.volatile" to @llvm.memset.p0i8.i32
2792+
// assuming llvm.memset is supported by the device compiler. If this
2793+
// assumption is not safe, we should have a command line option to control
2794+
// this behavior.
2795+
if (FuncNameRef.consume_front("spirv.")) {
2796+
FuncNameRef.consume_back(".volatile");
2797+
FuncName = FuncNameRef.str();
2798+
std::replace(FuncName.begin(), FuncName.end(), '_', '.');
2799+
}
2800+
Function *F = M->getFunction(FuncName);
2801+
if (!F)
2802+
F = Function::Create(FT, Linkage, FuncName, M);
2803+
F = cast<Function>(mapValue(BF, F));
27912804
mapFunction(BF, F);
27922805

2806+
if (F->isIntrinsic())
2807+
return F;
2808+
2809+
F->setCallingConv(IsKernel ? CallingConv::SPIR_KERNEL
2810+
: CallingConv::SPIR_FUNC);
27932811
if (BF->hasDecorate(DecorationReferencedIndirectlyINTEL))
27942812
F->addFnAttr("referenced-indirectly");
2795-
2796-
if (!F->isIntrinsic()) {
2797-
F->setCallingConv(IsKernel ? CallingConv::SPIR_KERNEL
2798-
: CallingConv::SPIR_FUNC);
2799-
if (isFuncNoUnwind())
2800-
F->addFnAttr(Attribute::NoUnwind);
2801-
foreachFuncCtlMask(BF,
2802-
[&](Attribute::AttrKind Attr) { F->addFnAttr(Attr); });
2803-
}
2813+
if (isFuncNoUnwind())
2814+
F->addFnAttr(Attribute::NoUnwind);
2815+
foreachFuncCtlMask(BF, [&](Attribute::AttrKind Attr) { F->addFnAttr(Attr); });
28042816

28052817
for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
28062818
++I) {

llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 57 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
#include "llvm/IR/Operator.h"
4747
#include "llvm/Pass.h"
4848
#include "llvm/Support/Debug.h"
49+
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h" // expandMemSetAsLoop()
4950

5051
#include <set>
5152
#include <vector>
@@ -75,6 +76,15 @@ class SPIRVRegularizeLLVM : public ModulePass {
7576
void lowerFuncPtr(Function *F, Op OC);
7677
void lowerFuncPtr(Module *M);
7778

79+
/// There is no SPIR-V counterpart for @llvm.memset.* intrinsic. Cases with
80+
/// constant value and length arguments are emulated via "storing" a constant
81+
/// array to the destination. For other cases we wrap the intrinsic in
82+
/// @spirv.llvm_memset_* function and expand the intrinsic to a loop via
83+
/// expandMemSetAsLoop() from llvm/Transforms/Utils/LowerMemIntrinsics.h.
84+
/// During reverse translation from SPIR-V to LLVM IR we can detect
85+
/// @spirv.llvm_memset_* and replace it with @llvm.memset.
86+
void lowerMemset(MemSetInst *MSI);
87+
7888
static char ID;
7989

8090
private:
@@ -84,6 +94,49 @@ class SPIRVRegularizeLLVM : public ModulePass {
8494

8595
char SPIRVRegularizeLLVM::ID = 0;
8696

97+
/// Redirect a @llvm.memset.* call whose value or length is not constant to a
/// wrapper function, creating that wrapper on first use.
///
/// The wrapper is named "spirv." followed by the intrinsic's name with every
/// '.' replaced by '_', plus a ".volatile" suffix when the call is volatile.
/// If a function with that name already exists in the module, the call is
/// simply retargeted to it. Otherwise a new function with the call's function
/// type is inserted, and its body is produced by emitting a memset over the
/// wrapper's own arguments and expanding it with expandMemSetAsLoop().
///
/// Calls with a constant value and a constant-integer length are left
/// untouched.
///
/// \param MSI the memset call to lower; its callee is replaced in place.
void SPIRVRegularizeLLVM::lowerMemset(MemSetInst *MSI) {
98+
if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
99+
return; // To be handled in LLVMToSPIRV::transIntrinsicInst
100+
Function *IntrinsicFunc = MSI->getCalledFunction();
101+
assert(IntrinsicFunc && "Missing function");
102+
std::string FuncName = IntrinsicFunc->getName().str();
103+
std::replace(FuncName.begin(), FuncName.end(), '.', '_'); // e.g. llvm.memset.p0i8.i32 -> llvm_memset_p0i8_i32
104+
FuncName = "spirv." + FuncName;
105+
if (MSI->isVolatile())
106+
FuncName += ".volatile"; // volatile and non-volatile calls get differently named wrappers
107+
108+
// Redirect @llvm.memset.* call to @spirv.llvm_memset_*
109+
Function *F = M->getFunction(FuncName);
110+
if (F) {
111+
// This function is already linked in.
112+
MSI->setCalledFunction(F);
113+
return;
114+
}
115+
// TODO copy arguments attributes: nocapture writeonly.
116+
FunctionCallee FC = M->getOrInsertFunction(FuncName, MSI->getFunctionType());
117+
MSI->setCalledFunction(FC);
118+
119+
F = dyn_cast<Function>(FC.getCallee());
120+
assert(F && "must be a function!");
// Name the wrapper's four parameters after the memset operands they carry.
121+
Argument *Dest = F->getArg(0);
122+
Argument *Val = F->getArg(1);
123+
Argument *Len = F->getArg(2);
124+
Argument *IsVolatile = F->getArg(3);
125+
Dest->setName("dest");
126+
Val->setName("val");
127+
Len->setName("len");
128+
IsVolatile->setName("isvolatile");
129+
IsVolatile->addAttr(Attribute::ImmArg); // NOTE(review): immarg is placed on a non-intrinsic function here — confirm the IR verifier accepts this
130+
BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
131+
IRBuilder<> IRB(EntryBB);
// The wrapper body is built from THIS call's destination alignment and
// volatility. NOTE(review): alignment is not encoded in FuncName, so a later
// call with a different alignment reuses this wrapper — confirm that is
// intended.
132+
auto *MemSet =
133+
IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(), MSI->isVolatile());
134+
IRB.CreateRetVoid();
// Expand the temporary memset into a loop, then erase the memset call itself.
135+
expandMemSetAsLoop(cast<MemSetInst>(MemSet));
136+
MemSet->eraseFromParent();
137+
return;
138+
}
139+
87140
bool SPIRVRegularizeLLVM::runOnModule(Module &Module) {
88141
M = &Module;
89142
Ctx = &M->getContext();
@@ -115,8 +168,11 @@ bool SPIRVRegularizeLLVM::regularize() {
115168
if (auto Call = dyn_cast<CallInst>(&II)) {
116169
Call->setTailCall(false);
117170
Function *CF = Call->getCalledFunction();
118-
if (CF && CF->isIntrinsic())
171+
if (CF && CF->isIntrinsic()) {
119172
removeFnAttr(Call, Attribute::NoUnwind);
173+
if (auto *MSI = dyn_cast<MemSetInst>(Call))
174+
lowerMemset(MSI);
175+
}
120176
}
121177

122178
// Remove optimization info not supported by SPIRV

llvm-spirv/test/transcoding/llvm.memset.ll

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
; RUN: spirv-val %t.spv
77
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
88

9+
; CHECK-SPIRV: Decorate [[#NonConstMemset:]] LinkageAttributes "spirv.llvm_memset_p3i8_i32"
910
; CHECK-SPIRV: TypeInt [[Int8:[0-9]+]] 8 0
1011
; CHECK-SPIRV: Constant {{[0-9]+}} [[Lenmemset21:[0-9]+]] 4
1112
; CHECK-SPIRV: Constant {{[0-9]+}} [[Lenmemset0:[0-9]+]] 12
@@ -19,6 +20,7 @@
1920
; CHECK-SPIRV: Variable {{[0-9]+}} [[Val:[0-9]+]] 0 [[Init]]
2021
; CHECK-SPIRV: 7 ConstantComposite [[Int8x4]] [[InitComp:[0-9]+]] [[Const21]] [[Const21]] [[Const21]] [[Const21]]
2122
; CHECK-SPIRV: Variable {{[0-9]+}} [[ValComp:[0-9]+]] 0 [[InitComp]]
23+
; CHECK-SPIRV: ConstantFalse [[#]] [[#False:]]
2224

2325
; CHECK-SPIRV: Bitcast [[Int8Ptr]] [[Target:[0-9]+]] {{[0-9]+}}
2426
; CHECK-SPIRV: Bitcast [[Int8PtrConst]] [[Source:[0-9]+]] [[Val]]
@@ -27,6 +29,31 @@
2729
; CHECK-SPIRV: Bitcast [[Int8PtrConst]] [[SourceComp:[0-9]+]] [[ValComp]]
2830
; CHECK-SPIRV: CopyMemorySized {{[0-9]+}} [[SourceComp]] [[Lenmemset21]] 2 4
2931

32+
; CHECK-SPIRV: FunctionCall [[#]] [[#]] [[#NonConstMemset]] [[#]] [[#]] [[#]] [[#False]]
33+
34+
; CHECK-SPIRV: Function [[#]] [[#NonConstMemset]]
35+
; CHECK-SPIRV: FunctionParameter [[#]] [[#Dest:]]
36+
; CHECK-SPIRV: FunctionParameter [[#]] [[#Value:]]
37+
; CHECK-SPIRV: FunctionParameter [[#]] [[#Len:]]
38+
; CHECK-SPIRV: FunctionParameter [[#]] [[#Volatile:]]
39+
40+
; CHECK-SPIRV: Label [[#Entry:]]
41+
; CHECK-SPIRV: IEqual [[#]] [[#IsZeroLen:]] [[#Zero:]] [[#Len]]
42+
; CHECK-SPIRV: BranchConditional [[#IsZeroLen]] [[#End:]] [[#WhileBody:]]
43+
44+
; CHECK-SPIRV: Label [[#WhileBody]]
45+
; CHECK-SPIRV: Phi [[#]] [[#Offset:]] [[#Zero]] [[#Entry]] [[#OffsetInc:]] [[#WhileBody]]
46+
; CHECK-SPIRV: InBoundsPtrAccessChain [[#]] [[#Ptr:]] [[#Dest]] [[#Offset]]
47+
; CHECK-SPIRV: Store [[#Ptr]] [[#Value]] 2 1
48+
; CHECK-SPIRV: IAdd [[#]] [[#OffsetInc]] [[#Offset]] [[#One:]]
49+
; CHECK-SPIRV: ULessThan [[#]] [[#NotEnd:]] [[#OffsetInc]] [[#Len]]
50+
; CHECK-SPIRV: BranchConditional [[#NotEnd]] [[#WhileBody]] [[#End]]
51+
52+
; CHECK-SPIRV: Label [[#End]]
53+
; CHECK-SPIRV: Return
54+
55+
; CHECK-SPIRV: FunctionEnd
56+
3057
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
3158
target triple = "spir"
3259

@@ -36,14 +63,34 @@ target triple = "spir"
3663
; CHECK-LLVM: internal unnamed_addr addrspace(2) constant [4 x i8] c"\15\15\15\15"
3764

3865
; Function Attrs: nounwind
39-
define spir_func void @_Z5foo11v(%struct.S1 addrspace(4)* noalias nocapture sret %agg.result) #0 {
66+
define spir_func void @_Z5foo11v(%struct.S1 addrspace(4)* noalias nocapture sret %agg.result, i32 %s1, i64 %s2, i8 %v) #0 {
4067
%x = alloca [4 x i8]
4168
%x.bc = bitcast [4 x i8]* %x to i8*
4269
%1 = bitcast %struct.S1 addrspace(4)* %agg.result to i8 addrspace(4)*
4370
tail call void @llvm.memset.p4i8.i32(i8 addrspace(4)* align 4 %1, i8 0, i32 12, i1 false)
4471
; CHECK-LLVM: call void @llvm.memset.p4i8.i32(i8 addrspace(4)* align 4 %1, i8 0, i32 12, i1 false)
4572
tail call void @llvm.memset.p0i8.i32(i8* align 4 %x.bc, i8 21, i32 4, i1 false)
4673
; CHECK-LLVM: call void @llvm.memcpy.p0i8.p2i8.i32(i8* align 4 %x.bc, i8 addrspace(2)* align 4 %3, i32 4, i1 false)
74+
75+
; non-const value
76+
tail call void @llvm.memset.p0i8.i32(i8* align 4 %x.bc, i8 %v, i32 3, i1 false)
77+
; CHECK-LLVM: call void @llvm.memset.p0i8.i32(i8* %x.bc, i8 %v, i32 3, i1 false)
78+
79+
; non-const value and size
80+
tail call void @llvm.memset.p0i8.i32(i8* align 4 %x.bc, i8 %v, i32 %s1, i1 false)
81+
; CHECK-LLVM: call void @llvm.memset.p0i8.i32(i8* %x.bc, i8 %v, i32 %s1, i1 false)
82+
83+
; Address spaces, non-const value and size
84+
%a = addrspacecast i8 addrspace(4)* %1 to i8 addrspace(3)*
85+
tail call void @llvm.memset.p3i8.i32(i8 addrspace(3)* align 4 %a, i8 %v, i32 %s1, i1 false)
86+
; CHECK-LLVM: call void @llvm.memset.p3i8.i32(i8 addrspace(3)* %a, i8 %v, i32 %s1, i1 false)
87+
%b = addrspacecast i8 addrspace(4)* %1 to i8 addrspace(1)*
88+
tail call void @llvm.memset.p1i8.i64(i8 addrspace(1)* align 4 %b, i8 %v, i64 %s2, i1 false)
89+
; CHECK-LLVM: call void @llvm.memset.p1i8.i64(i8 addrspace(1)* %b, i8 %v, i64 %s2, i1 false)
90+
91+
; Volatile
92+
tail call void @llvm.memset.p1i8.i64(i8 addrspace(1)* align 4 %b, i8 %v, i64 %s2, i1 true)
93+
; CHECK-LLVM: call void @llvm.memset.p1i8.i64(i8 addrspace(1)* %b, i8 %v, i64 %s2, i1 true)
4794
ret void
4895
}
4996

@@ -53,6 +100,12 @@ declare void @llvm.memset.p4i8.i32(i8 addrspace(4)* nocapture, i8, i32, i1) #1
53100
; Function Attrs: nounwind
54101
declare void @llvm.memset.p0i8.i32(i8* nocapture, i8, i32, i1) #1
55102

103+
; Function Attrs: nounwind
104+
declare void @llvm.memset.p3i8.i32(i8 addrspace(3)*, i8, i32, i1) #1
105+
106+
; Function Attrs: nounwind
107+
declare void @llvm.memset.p1i8.i64(i8 addrspace(1)*, i8, i64, i1) #1
108+
56109
attributes #0 = { nounwind "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-realign-stack" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
57110
attributes #1 = { nounwind }
58111

0 commit comments

Comments
 (0)