Skip to content

Commit 2ae6715

Browse files
PawelJurekigcbot
authored andcommitted
Builtin function for printf-ing to provided buffer
For assert implementation we need to store the assert message in the assert buffer in a similar manner that current printf implementation uses: UMD provides a buffer with initial write offset set. Offset is then atomically incremented by SIMD threads and used to write per-thread data. The layout of the buffer is a little bit different that for printf buffer. To maximize code reuse, this change adds a builtin function to write to any buffer in a printf fashion: ```llvm int __builtin_IB_printf_to_buffer(global char* buf, global char* currentOffset, int bufSize, ...); ```
1 parent 7a1db0e commit 2ae6715

File tree

3 files changed

+56
-12
lines changed

3 files changed

+56
-12
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/OpenCLPrintf/OpenCLPrintfAnalysis.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ OpenCLPrintfAnalysis::OpenCLPrintfAnalysis() : ModulePass(ID)
4242
const StringRef OpenCLPrintfAnalysis::OPENCL_PRINTF_FUNCTION_NAME = "printf";
4343
const StringRef OpenCLPrintfAnalysis::ONEAPI_PRINTF_FUNCTION_NAME =
4444
"ext::oneapi::experimental::printf";
45+
const StringRef OpenCLPrintfAnalysis::BUILTIN_PRINTF_FUNCTION_NAME =
46+
"__builtin_IB_printf_to_buffer";
4547

4648
bool OpenCLPrintfAnalysis::isOpenCLPrintf(const llvm::Function *F)
4749
{
@@ -54,6 +56,11 @@ bool OpenCLPrintfAnalysis::isOneAPIPrintf(const llvm::Function *F)
5456
return demangledName.find(ONEAPI_PRINTF_FUNCTION_NAME.data()) != std::string::npos;
5557
}
5658

59+
bool OpenCLPrintfAnalysis::isBuiltinPrintf(const llvm::Function* F)
60+
{
61+
return F->getName() == BUILTIN_PRINTF_FUNCTION_NAME;
62+
}
63+
5764
bool OpenCLPrintfAnalysis::runOnModule(Module& M)
5865
{
5966
m_pMDUtils = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
@@ -131,7 +138,8 @@ bool isPrintfOnlyStringConstantImpl(const llvm::Value *v, std::set<const llvm::U
131138
// printf call.
132139
const Function *target = call->getCalledFunction();
133140
res = OpenCLPrintfAnalysis::isOpenCLPrintf(target) ||
134-
OpenCLPrintfAnalysis::isOneAPIPrintf(target);
141+
OpenCLPrintfAnalysis::isOneAPIPrintf(target) ||
142+
OpenCLPrintfAnalysis::isBuiltinPrintf(target);
135143
}
136144
else if (llvm::dyn_cast<llvm::CastInst>(user) ||
137145
llvm::dyn_cast<llvm::SelectInst>(user) ||

IGC/Compiler/Optimizer/OpenCLPasses/OpenCLPrintf/OpenCLPrintfAnalysis.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,11 @@ namespace IGC
5151

5252
static const llvm::StringRef OPENCL_PRINTF_FUNCTION_NAME;
5353
static const llvm::StringRef ONEAPI_PRINTF_FUNCTION_NAME;
54+
static const llvm::StringRef BUILTIN_PRINTF_FUNCTION_NAME;
55+
5456
static bool isOpenCLPrintf(const llvm::Function *F);
5557
static bool isOneAPIPrintf(const llvm::Function *F);
58+
static bool isBuiltinPrintf(const llvm::Function* F);
5659

5760
// Return true if every top level user of a string literal is a printf
5861
// call. Note that the function is expected to work only before printf

IGC/Compiler/Optimizer/OpenCLPasses/OpenCLPrintf/OpenCLPrintfResolution.cpp

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ void OpenCLPrintfResolution::visitCallInst(CallInst& callInst)
156156
}
157157

158158
StringRef funcName = callInst.getCalledFunction()->getName();
159-
if (funcName == OpenCLPrintfAnalysis::OPENCL_PRINTF_FUNCTION_NAME)
160-
{
159+
if (funcName == OpenCLPrintfAnalysis::OPENCL_PRINTF_FUNCTION_NAME ||
160+
funcName == OpenCLPrintfAnalysis::BUILTIN_PRINTF_FUNCTION_NAME) {
161161
m_printfCalls.push_back(&callInst);
162162
}
163163
}
@@ -358,12 +358,18 @@ std::string OpenCLPrintfResolution::getPrintfStringsMDNodeName(Function& F)
358358
return "printf.strings";
359359
}
360360

361-
static StoreInst* genStoreInternal(Value* Val, Value* Ptr, BasicBlock* InsertAtEnd, DebugLoc DL)
361+
static StoreInst* genStoreInternal(Value* Val, Value* Ptr, BasicBlock* InsertAtEnd, DebugLoc DL, bool isNontemporal)
362362
{
363363
bool isVolatile = false;
364364
unsigned Align = 4;
365365
auto SI = new llvm::StoreInst(Val, Ptr, isVolatile, IGCLLVM::getCorrectAlign(Align), InsertAtEnd);
366366
SI->setDebugLoc(DL);
367+
if (isNontemporal) {
368+
Constant *One = ConstantInt::get(Type::getInt32Ty(SI->getContext()), 1);
369+
MDNode *Node =
370+
MDNode::get(SI->getContext(), ConstantAsMetadata::get(One));
371+
SI->setMetadata(LLVMContext::MD_nontemporal, Node);
372+
}
367373
return SI;
368374
}
369375

@@ -419,10 +425,23 @@ void OpenCLPrintfResolution::expandPrintfCall(CallInst& printfCall, Function& F)
419425
// printf returns -1 if failed |
420426
return_val = -1; /
421427
}
428+
429+
We also support printf to any provided buffer.
430+
This is done with special builtin with following signature:
431+
int __builtin_IB_printf_to_buffer(global char* buf, global char* currentOffset, int bufSize, ...);
432+
buf - pointer to the begging of the buffer.
433+
currentOffset - pointer to the location with the current offset that will be atomically incremented.
434+
In the case of regular printf this offset is on the first DWORD of printfBuffer.
435+
E.g. in assert buffer it is on the third DWORD.
436+
bufSize - total size of the buffer.
437+
Note: in the case of builtin printf, all the stores will be nontemporal.
438+
439+
422440
----------------------------------------------------------------------
423441
*/
424442
MetaDataUtils* MdUtils = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
425443
ImplicitArgs implicitArgs(F, MdUtils);
444+
bool isPrintfBuiltin = OpenCLPrintfAnalysis::isBuiltinPrintf(printfCall.getCalledFunction());
426445

427446
BasicBlock* currentBBlock = printfCall.getParent();
428447

@@ -431,10 +450,14 @@ void OpenCLPrintfResolution::expandPrintfCall(CallInst& printfCall, Function& F)
431450
preprocessPrintfArgs(printfCall);
432451

433452
// writeOffset = atomic_add(bufferPtr, dataSize)
434-
Value* basebufferPtr = implicitArgs.getImplicitArgValue(F, ImplicitArg::PRINTF_BUFFER, MdUtils);
453+
Value *basebufferPtr = isPrintfBuiltin
454+
? printfCall.getArgOperand(0)
455+
: implicitArgs.getImplicitArgValue(
456+
F, ImplicitArg::PRINTF_BUFFER, MdUtils);
435457

436458
Value* dataSizeVal = ConstantInt::get(m_int32Type, getTotalDataSize());
437-
Instruction* writeOffsetStart = genAtomicAdd(basebufferPtr, dataSizeVal, printfCall, "write_offset");
459+
Value* currentOffsetPtr = isPrintfBuiltin ? printfCall.getArgOperand(1) : basebufferPtr;
460+
Instruction* writeOffsetStart = genAtomicAdd(currentOffsetPtr, dataSizeVal, printfCall, "write_offset");
438461
writeOffsetStart->setDebugLoc(m_DL);
439462

440463
Instruction* writeOffset = writeOffsetStart;
@@ -444,7 +467,11 @@ void OpenCLPrintfResolution::expandPrintfCall(CallInst& printfCall, Function& F)
444467
Instruction* endOffset = BinaryOperator::CreateAdd(writeOffset, dataSizeVal, "end_offset", &printfCall);
445468
endOffset->setDebugLoc(m_DL);
446469

447-
Value* bufferMaxSize = ConstantInt::get(m_int32Type, m_CGContext->m_DriverInfo.getPrintfBufferSize());
470+
Value *bufferMaxSize =
471+
isPrintfBuiltin
472+
? printfCall.getArgOperand(2)
473+
: ConstantInt::get(m_int32Type,
474+
m_CGContext->m_DriverInfo.getPrintfBufferSize());
448475

449476
// write_ptr = buffer_ptr + write_offset;
450477
if (m_ptrSizeIntType != writeOffset->getType())
@@ -509,7 +536,7 @@ void OpenCLPrintfResolution::expandPrintfCall(CallInst& printfCall, Function& F)
509536
writeOffsetPtr = CastInst::Create(Instruction::CastOps::IntToPtr, writeOffset,
510537
m_int32Type->getPointerTo(ADDRESS_SPACE_GLOBAL), "write_offset_ptr", bblockTrue);
511538
writeOffsetPtr->setDebugLoc(m_DL);
512-
genStoreInternal(argTypeVal, writeOffsetPtr, bblockTrue, m_DL);
539+
genStoreInternal(argTypeVal, writeOffsetPtr, bblockTrue, m_DL, isPrintfBuiltin);
513540

514541
// write_offset += 4
515542
writeOffset = BinaryOperator::CreateAdd(writeOffset, constVal4, "write_offset", bblockTrue);
@@ -521,7 +548,7 @@ void OpenCLPrintfResolution::expandPrintfCall(CallInst& printfCall, Function& F)
521548
writeOffsetPtr = CastInst::Create(Instruction::CastOps::IntToPtr, writeOffset,
522549
m_int32Type->getPointerTo(ADDRESS_SPACE_GLOBAL), "write_offset_ptr", bblockTrue);
523550
writeOffsetPtr->setDebugLoc(m_DL);
524-
genStoreInternal(vecSizeVal, writeOffsetPtr, bblockTrue, m_DL);
551+
genStoreInternal(vecSizeVal, writeOffsetPtr, bblockTrue, m_DL, isPrintfBuiltin);
525552

526553
// write_offset += 4
527554
writeOffset = BinaryOperator::CreateAdd(writeOffset, constVal4, "write_offset", bblockTrue);
@@ -542,7 +569,7 @@ void OpenCLPrintfResolution::expandPrintfCall(CallInst& printfCall, Function& F)
542569
}
543570

544571
// *write_offset = argument[i].value
545-
genStoreInternal(printfArg, writeOffsetPtr, bblockTrue, m_DL);
572+
genStoreInternal(printfArg, writeOffsetPtr, bblockTrue, m_DL, isPrintfBuiltin);
546573

547574
// write_offset += argument[i].size
548575
Value* offsetInc = ConstantInt::get(m_ptrSizeIntType, getArgTypeSize(dataType, argDesc->vecSize));
@@ -582,7 +609,7 @@ void OpenCLPrintfResolution::expandPrintfCall(CallInst& printfCall, Function& F)
582609
"write_offset_ptr",
583610
bblockErrorString);
584611
writeOffsetPtr->setDebugLoc(m_DL);
585-
genStoreInternal(constValErrStringIdx, writeOffsetPtr, bblockErrorString, m_DL);
612+
genStoreInternal(constValErrStringIdx, writeOffsetPtr, bblockErrorString, m_DL, isPrintfBuiltin);
586613
brInst = BranchInst::Create(bblockFalseJoin, bblockErrorString);
587614
brInst->setDebugLoc(m_DL);
588615

@@ -680,7 +707,13 @@ Value* OpenCLPrintfResolution::fixupPrintfArg(CallInst& printfCall, Value* arg,
680707

681708
void OpenCLPrintfResolution::preprocessPrintfArgs(CallInst& printfCall)
682709
{
683-
for (int i = 0, numArgs = IGCLLVM::getNumArgOperands(&printfCall); i < numArgs; ++i)
710+
int i = 0;
711+
if (OpenCLPrintfAnalysis::isBuiltinPrintf(printfCall.getCalledFunction())) {
712+
// printf builtin function has buffer pointer, current offset pointer and buffer size as first three arguments.
713+
// Skip them here, as we want to collect the arguments starting from format string.
714+
i = 3;
715+
}
716+
for (int numArgs = IGCLLVM::getNumArgOperands(&printfCall); i < numArgs; ++i)
684717
{
685718
Value* arg = printfCall.getOperand(i);
686719
Type* argType = arg->getType();

0 commit comments

Comments
 (0)