Skip to content

Commit 2d62e51

Browse files
author
Artem Gindinson
authored
[SYCL] Initial printf support for non-constant AS format strings (#5069)
Allow generic address space strings into `experimental::printf` and, consequently, into `__spirv_ocl_printf` declarations. To mitigate the lack of device BE support for generic address space literals, we implement a pass to move the literal argument into constant address space whenever possible. This is a temporary solution; long-term, we should replace this with proper support of generic address-spaced format strings at the level of SPIR-V translation and device backends. Tested in intel/llvm-test-suite#569. Signed-off-by: Artem Gindinson <[email protected]>
1 parent 8c5b701 commit 2d62e51

File tree

17 files changed

+861
-5
lines changed

17 files changed

+861
-5
lines changed

.github/CODEOWNERS

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,8 @@ sycl/doc/extensions/ExplicitSIMD/ @kbobrovs @v-klochkov @kychendev
109109
llvm/lib/Transforms/Instrumentation/SPIRITTAnnotations.cpp @MrSidims @vzakhari
110110
llvm/include/llvm/Transforms/Instrumentation/SPIRITTAnnotations.h @MrSidims @vzakhari
111111
llvm/test/Transforms/SPIRITTAnnotations/* @MrSidims @vzakhari
112+
113+
# Generic address space support for printf
114+
llvm/lib/SYCLLowerIR/MutatePrintfAddrspace.cpp @AGindinson @AlexeySachkov @mlychkov
115+
llvm/include/llvm/SYCLLowerIR/MutatePrintfAddrspace.h @AGindinson @AlexeySachkov @mlychkov
116+
llvm/test/SYCLLowerIR/printf_addrspace/* @AGindinson @AlexeySachkov @mlychkov

clang/lib/CodeGen/BackendUtil.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "llvm/Passes/StandardInstrumentations.h"
4545
#include "llvm/SYCLLowerIR/ESIMDVerifier.h"
4646
#include "llvm/SYCLLowerIR/LowerWGLocalMemory.h"
47+
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
4748
#include "llvm/Support/BuryPointer.h"
4849
#include "llvm/Support/CommandLine.h"
4950
#include "llvm/Support/MemoryBuffer.h"
@@ -1059,6 +1060,7 @@ void EmitAssemblyHelper::EmitAssemblyWithLegacyPassManager(
10591060
if (CodeGenOpts.DisableLLVMPasses)
10601061
PerModulePasses.add(createAlwaysInlinerLegacyPass(false));
10611062
PerModulePasses.add(createSYCLLowerWGLocalMemoryLegacyPass());
1063+
PerModulePasses.add(createSYCLMutatePrintfAddrspaceLegacyPass());
10621064
}
10631065

10641066
switch (Action) {
@@ -1476,6 +1478,9 @@ void EmitAssemblyHelper::RunOptimizationPipeline(
14761478
MPM.addPass(ModuleMemProfilerPass());
14771479
}
14781480
}
1481+
if (LangOpts.SYCLIsDevice) {
1482+
MPM.addPass(SYCLMutatePrintfAddrspacePass());
1483+
}
14791484

14801485
// Add a verifier pass if requested. We don't have to do this if the action
14811486
// requires code generation because there will already be a verifier pass in

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ void initializeStripSymbolsPass(PassRegistry&);
440440
void initializeStructurizeCFGLegacyPassPass(PassRegistry &);
441441
void initializeSYCLLowerWGScopeLegacyPassPass(PassRegistry &);
442442
void initializeSYCLLowerESIMDLegacyPassPass(PassRegistry &);
443+
void initializeSYCLMutatePrintfAddrspaceLegacyPassPass(PassRegistry &);
443444
void initializeSPIRITTAnnotationsLegacyPassPass(PassRegistry &);
444445
void initializeESIMDLowerLoadStorePass(PassRegistry &);
445446
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry &);

llvm/include/llvm/LinkAllPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "llvm/IR/Function.h"
3939
#include "llvm/IR/IRPrintingPasses.h"
4040
#include "llvm/SYCLLowerIR/ESIMDVerifier.h"
41+
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
4142
#include "llvm/Support/Valgrind.h"
4243
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
4344
#include "llvm/Transforms/IPO.h"
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===------- MutatePrintfAddrspace.h - SYCL printf AS mutation Pass -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// A transformation pass which detects non-constant address space
10+
// literals usage for the first argument of SYCL experimental printf
11+
// function, and moves the string literal to constant address
12+
// space. This is a temporary solution for printf's support of generic
13+
// address space literals; the pass should be dropped once SYCL device
14+
// backends learn to handle the generic address-spaced argument properly.
15+
//===----------------------------------------------------------------------===//
16+
17+
#pragma once
18+
19+
#include "llvm/IR/Module.h"
20+
#include "llvm/IR/PassManager.h"
21+
22+
namespace llvm {
23+
24+
class SYCLMutatePrintfAddrspacePass
25+
: public PassInfoMixin<SYCLMutatePrintfAddrspacePass> {
26+
public:
27+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
28+
};
29+
30+
ModulePass *createSYCLMutatePrintfAddrspaceLegacyPass();
31+
32+
} // namespace llvm

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
#include "llvm/SYCLLowerIR/ESIMDVerifier.h"
8080
#include "llvm/SYCLLowerIR/LowerESIMD.h"
8181
#include "llvm/SYCLLowerIR/LowerWGScope.h"
82+
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
8283
#include "llvm/Support/CommandLine.h"
8384
#include "llvm/Support/Debug.h"
8485
#include "llvm/Support/ErrorHandling.h"

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ MODULE_PASS("pseudo-probe-update", PseudoProbeUpdatePass())
119119
MODULE_PASS("LowerESIMD", SYCLLowerESIMDPass())
120120
MODULE_PASS("ESIMDLowerVecArg", ESIMDLowerVecArgPass())
121121
MODULE_PASS("esimd-verifier", ESIMDVerifierPass())
122+
MODULE_PASS("SYCLMutatePrintfAddrspace", SYCLMutatePrintfAddrspacePass())
122123
MODULE_PASS("SPIRITTAnnotations", SPIRITTAnnotationsPass())
123124
MODULE_PASS("deadargelim-sycl", DeadArgumentEliminationSYCLPass())
124125
#undef MODULE_PASS

llvm/lib/SYCLLowerIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
5454
LowerESIMDVecArg.cpp
5555
LowerWGLocalMemory.cpp
5656
ESIMDVerifier.cpp
57+
MutatePrintfAddrspace.cpp
5758

5859
ADDITIONAL_HEADER_DIRS
5960
${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLLowerIR
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
//===------ MutatePrintfAddrspace.cpp - SYCL printf AS mutation Pass ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// A transformation pass which detects non-constant address space
10+
// literals usage for the first argument of SYCL experimental printf
11+
// function, and moves the string literal to constant address
12+
// space. This is a temporary solution for printf's support of generic
13+
// address space literals; the pass should be dropped once SYCL device
14+
// backends learn to handle the generic address-spaced argument properly.
15+
//===----------------------------------------------------------------------===//
16+
17+
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
18+
19+
#include "llvm/Analysis/ValueTracking.h"
20+
#include "llvm/IR/IRBuilder.h"
21+
#include "llvm/IR/Instructions.h"
22+
#include "llvm/InitializePasses.h"
23+
24+
using namespace llvm;
25+
26+
namespace {
27+
// Wrapper that makes the pass usable with the legacy (old) pass manager
28+
class SYCLMutatePrintfAddrspaceLegacyPass : public ModulePass {
29+
public:
30+
static char ID;
31+
SYCLMutatePrintfAddrspaceLegacyPass() : ModulePass(ID) {
32+
initializeSYCLMutatePrintfAddrspaceLegacyPassPass(
33+
*PassRegistry::getPassRegistry());
34+
}
35+
36+
// run the SYCLMutatePrintfAddrspace pass on the specified module
37+
bool runOnModule(Module &M) override {
38+
ModuleAnalysisManager MAM;
39+
auto PA = Impl.run(M, MAM);
40+
return !PA.areAllPreserved();
41+
}
42+
43+
private:
44+
SYCLMutatePrintfAddrspacePass Impl;
45+
};
46+
47+
static constexpr unsigned ConstantAddrspaceID = 2;
48+
// If the variadic version gets picked during FE compilation, we'll only have
49+
// 1 function to replace. However, unique declarations are emitted for each
50+
// of the non-variadic (variadic template) calls.
51+
using FunctionVecTy = SmallVector<Function *, 8>;
52+
53+
Function *getCASPrintfFunction(Module &M, PointerType *CASLiteralType);
54+
size_t setFuncCallsOntoCASPrintf(Function *F, Function *CASPrintfFunc,
55+
FunctionVecTy &FunctionsToDrop);
56+
} // namespace
57+
58+
char SYCLMutatePrintfAddrspaceLegacyPass::ID = 0;
59+
INITIALIZE_PASS(SYCLMutatePrintfAddrspaceLegacyPass,
60+
"SYCLMutatePrintfAddrspace",
61+
"Move SYCL printf literal arguments to constant address space",
62+
false, false)
63+
64+
// Public interface to the SYCLMutatePrintfAddrspacePass.
65+
ModulePass *llvm::createSYCLMutatePrintfAddrspaceLegacyPass() {
66+
return new SYCLMutatePrintfAddrspaceLegacyPass();
67+
}
68+
69+
PreservedAnalyses
70+
SYCLMutatePrintfAddrspacePass::run(Module &M, ModuleAnalysisManager &MAM) {
71+
Type *Int8Type = Type::getInt8Ty(M.getContext());
72+
auto *CASLiteralType = PointerType::get(Int8Type, ConstantAddrspaceID);
73+
Function *CASPrintfFunc = getCASPrintfFunction(M, CASLiteralType);
74+
75+
FunctionVecTy FunctionsToDrop;
76+
bool ModuleChanged = false;
77+
for (Function &F : M) {
78+
if (!F.isDeclaration())
79+
continue;
80+
if (!F.getName().startswith("_Z18__spirv_ocl_printf"))
81+
continue;
82+
if (F.getArg(0)->getType() == CASLiteralType)
83+
// No need to replace the literal type and its printf users
84+
continue;
85+
ModuleChanged |=
86+
setFuncCallsOntoCASPrintf(&F, CASPrintfFunc, FunctionsToDrop);
87+
}
88+
for (Function *F : FunctionsToDrop)
89+
F->eraseFromParent();
90+
91+
return ModuleChanged ? PreservedAnalyses::all() : PreservedAnalyses::none();
92+
}
93+
94+
/// Helper implementations
95+
namespace {
96+
97+
/// Get the constant addrspace version of the __spirv_ocl_printf declaration,
98+
/// or generate it if the IR module doesn't have it yet. Also make it
99+
/// variadic so that it could replace all non-variadic generic AS versions.
100+
Function *getCASPrintfFunction(Module &M, PointerType *CASLiteralType) {
101+
Type *Int32Type = Type::getInt32Ty(M.getContext());
102+
auto *CASPrintfFuncTy = FunctionType::get(Int32Type, CASLiteralType,
103+
/*isVarArg=*/true);
104+
// extern int __spirv_ocl_printf(
105+
// const __attribute__((opencl_constant)) char *Format, ...)
106+
FunctionCallee CASPrintfFuncCallee =
107+
M.getOrInsertFunction("_Z18__spirv_ocl_printfPU3AS2Kcz", CASPrintfFuncTy);
108+
auto *CASPrintfFunc = cast<Function>(CASPrintfFuncCallee.getCallee());
109+
CASPrintfFunc->setCallingConv(CallingConv::SPIR_FUNC);
110+
CASPrintfFunc->setDSOLocal(true);
111+
return CASPrintfFunc;
112+
}
113+
114+
/// Generate the constant addrspace version of the generic addrspace-residing
115+
/// global string. If one exists already, get it from the module.
116+
Constant *getCASLiteral(GlobalVariable *GenericASLiteral) {
117+
Module *M = GenericASLiteral->getParent();
118+
// Appending the stable suffix ensures that only one CAS copy is made for each
119+
// string. In case of the matching name, llvm::Module APIs will ensure that
120+
// the existing global is returned.
121+
std::string CASLiteralName = GenericASLiteral->getName().str() + "._AS2";
122+
if (GlobalVariable *ExistingGlobal =
123+
M->getGlobalVariable(CASLiteralName, /*AllowInternal=*/true))
124+
return ExistingGlobal;
125+
126+
StringRef LiteralValue;
127+
getConstantStringInfo(GenericASLiteral, LiteralValue);
128+
IRBuilder<> Builder(M->getContext());
129+
GlobalVariable *Res = Builder.CreateGlobalString(LiteralValue, CASLiteralName,
130+
ConstantAddrspaceID, M);
131+
Res->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
132+
Res->setUnnamedAddr(GlobalValue::UnnamedAddr::None);
133+
return Res;
134+
}
135+
136+
/// Encapsulates the update of CallInst's literal argument.
137+
void setCallArgOntoCASPrintf(CallInst *CI, Constant *CASArg,
138+
Function *CASPrintfFunc) {
139+
CI->setCalledFunction(CASPrintfFunc);
140+
auto *Const = CASArg;
141+
// In case there's a misalignment between the updated function type and
142+
// the constant literal type, create a constant pointer cast so as to
143+
// duck module verifier complaints.
144+
Type *ParamType = CASPrintfFunc->getFunctionType()->getParamType(0);
145+
if (Const->getType() != ParamType)
146+
Const = ConstantExpr::getPointerCast(Const, ParamType);
147+
CI->setArgOperand(0, Const);
148+
}
149+
150+
/// The function's effect is similar to V->stripPointerCastsAndAliases(), but
151+
/// also strips load/store aliases.
152+
/// NB: This function can only operate on simple CFG, where load/store pairs
153+
/// leading to the global variable are merely a consequence of low optimization
154+
/// level. Re-using it for complex CFG with arbitrary memory paths is definitely
155+
/// not recommended.
156+
Value *stripToMemorySource(Value *V) {
157+
Value *MemoryAccess = V;
158+
if (auto *LI = dyn_cast<LoadInst>(MemoryAccess)) {
159+
Value *LoadSource = LI->getPointerOperand();
160+
auto *Store = cast<StoreInst>(*llvm::find_if(
161+
LoadSource->users(), [](User *U) { return isa<StoreInst>(U); }));
162+
MemoryAccess = Store->getValueOperand();
163+
}
164+
return MemoryAccess->stripPointerCastsAndAliases();
165+
}
166+
167+
/// Report, through the LLVMContext of \p PrintfInstance, that the format
/// string of \p PrintfCall could not be moved to constant address space.
/// \p RecommendationToUser is appended to the message on its own line.
void emitError(Function *PrintfInstance, CallInst *PrintfCall,
               StringRef RecommendationToUser = "") {
  std::string Diagnostic(
      "experimental::printf requires format string to reside in constant "
      "address space. The compiler wasn't able to automatically convert "
      "your format string into constant address space when processing "
      "builtin ");
  Diagnostic += PrintfInstance->getName().str();
  Diagnostic += " called in function ";
  Diagnostic += PrintfCall->getFunction()->getName().str();
  Diagnostic += ".\n";
  Diagnostic += RecommendationToUser.str();

  LLVMContext &Ctx = PrintfInstance->getContext();
  Ctx.emitError(PrintfCall, Diagnostic);
}
181+
182+
/// This routine goes over CallInst users of F, resetting the called function
183+
/// to CASPrintfFunc and generating/retracting constant addrspace format
184+
/// strings to use as operands of the mutated calls.
185+
size_t setFuncCallsOntoCASPrintf(Function *F, Function *CASPrintfFunc,
186+
FunctionVecTy &FunctionsToDrop) {
187+
size_t MutatedCallsCount = 0;
188+
SmallVector<std::pair<CallInst *, Constant *>, 16> CallsToMutate;
189+
for (User *U : F->users()) {
190+
if (!isa<CallInst>(U))
191+
continue;
192+
auto *CI = cast<CallInst>(U);
193+
194+
// This key algorithm reaches the global string used as an argument to a
195+
// __spirv_ocl_printf call. It then generates a constant AS copy of that
196+
// global (or gets an existing one). For the return value, the call
197+
// instruction is paired with its future constant addrspace string
198+
// argument.
199+
Value *Stripped = stripToMemorySource(CI->getArgOperand(0));
200+
if (auto *Literal = dyn_cast<GlobalVariable>(Stripped))
201+
CallsToMutate.emplace_back(CI, getCASLiteral(Literal));
202+
else if (auto *Arg = dyn_cast<Argument>(Stripped)) {
203+
// The global literal is passed to __spirv_ocl_printf via a wrapper
204+
// function argument. We'll update the wrapper calls to use the builtin
205+
// function directly instead.
206+
Function *WrapperFunc = Arg->getParent();
207+
std::string BadWrapperErrorMsg =
208+
"Consider simplifying the code by "
209+
"passing format strings directly into experimental::printf calls, "
210+
"avoiding indirection via wrapper function arguments.";
211+
if (!WrapperFunc->getName().contains("6oneapi12experimental6printf")) {
212+
emitError(WrapperFunc, CI, BadWrapperErrorMsg);
213+
return 0;
214+
}
215+
for (User *WrapperU : WrapperFunc->users()) {
216+
auto *WrapperCI = cast<CallInst>(WrapperU);
217+
Value *StrippedArg = stripToMemorySource(WrapperCI->getArgOperand(0));
218+
auto *Literal = dyn_cast<GlobalVariable>(StrippedArg);
219+
// We only expect 1 level of wrappers
220+
if (!Literal) {
221+
emitError(WrapperFunc, WrapperCI, BadWrapperErrorMsg);
222+
return 0;
223+
}
224+
CallsToMutate.emplace_back(WrapperCI, getCASLiteral(Literal));
225+
}
226+
// We're certain that the wrapper won't have any uses, since we've just
227+
// marked all its calls for replacement with __spirv_ocl_printf.
228+
FunctionsToDrop.emplace_back(WrapperFunc);
229+
// Similar certainty for the generic AS version of __spirv_ocl_printf
230+
// itself - we've determined it only gets called inside the
231+
// soon-to-be-removed wrapper.
232+
assert(F->hasOneUse() && "Unexpected __spirv_ocl_printf call outside of "
233+
"SYCL wrapper function");
234+
FunctionsToDrop.emplace_back(F);
235+
} else {
236+
emitError(
237+
F, CI,
238+
"Make sure each format string literal is "
239+
"known at compile time or use OpenCL constant address space literals "
240+
"for device-side printf calls");
241+
return 0;
242+
}
243+
}
244+
for (auto &CallConstantPair : CallsToMutate) {
245+
setCallArgOntoCASPrintf(CallConstantPair.first, CallConstantPair.second,
246+
CASPrintfFunc);
247+
++MutatedCallsCount;
248+
}
249+
if (F->hasNUses(0))
250+
FunctionsToDrop.emplace_back(F);
251+
return MutatedCallsCount;
252+
}
253+
} // namespace
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include <CL/sycl.hpp>
2+
3+
using namespace sycl;
4+
5+
int main() {
6+
queue q;
7+
q.submit([&](handler &cgh) {
8+
cgh.single_task([=]() {
9+
ext::oneapi::experimental::printf("String No. %f\n", 1.0f);
10+
const char *IntFormatString = "String No. %i\n";
11+
ext::oneapi::experimental::printf(IntFormatString, 2);
12+
ext::oneapi::experimental::printf(IntFormatString, 3);
13+
});
14+
});
15+
16+
return 0;
17+
}

0 commit comments

Comments
 (0)