Skip to content

Commit 37c9e8b

Browse files
dmitryryinteligcbot
authored andcommitted
support printf with args in GenXPrintfResolution
1 parent f1d9caf commit 37c9e8b

File tree

2 files changed

+200
-10
lines changed

2 files changed

+200
-10
lines changed

IGC/VectorCompiler/lib/BiF/printf_ocl_genx.cpp

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,26 +116,88 @@ printf_init_impl(vector<int, ArgsInfoVector::Size> ArgsInfo) {
116116
return generateTransferData(BufferPtr + Offset, BufferSize);
117117
}
118118

119+
// Writes \p Data to printf buffer via \p CurAddress pointer.
120+
// Returns promoted pointer.
121+
static uintptr_t writeElementToBuffer(uintptr_t CurAddress,
122+
BufferElementTy Data) {
123+
vector<uintptr_t, 1> CurAddressVec = CurAddress;
124+
vector<BufferElementTy, 1> DataVec = Data;
125+
svm::scatter(CurAddressVec, DataVec);
126+
return CurAddress + sizeof(Data);
127+
}
128+
119129
// Format string handling. Just writing format string index to buffer and
120130
// promoting the pointer to buffer.
121131
template <typename T>
122132
vector<BufferElementTy, TransferDataSize>
123133
printf_fmt_impl(vector<BufferElementTy, TransferDataSize> TransferData,
124134
T *FormatString) {
125-
vector<uintptr_t, 1> CurAddress = getCurAddress(TransferData);
126-
vector<BufferElementTy, 1> Index = detail::printf_format_index(FormatString);
127-
svm::scatter(CurAddress, Index);
128-
CurAddress += FormatStringAnnotationSize;
129-
setCurAddress(TransferData, CurAddress[0]);
135+
uintptr_t CurAddress = getCurAddress(TransferData);
136+
BufferElementTy Index = detail::printf_format_index(FormatString);
137+
CurAddress = writeElementToBuffer(CurAddress, Index);
138+
setCurAddress(TransferData, CurAddress);
130139
return TransferData;
131140
}
132141

142+
// ArgCode is written into printf buffer before every argument.
// The numeric value of every code is spelled out explicitly so that the
// values stored in the buffer are visible at a glance.
namespace ArgCode {
enum Enum {
  Invalid = 0,
  Byte = 1,
  Short = 2,
  Int = 3,
  Float = 4,
  String = 5,
  Long = 6,
  Pointer = 7,
  Double = 8,
  VectorByte = 9,
  VectorShort = 10,
  VectorInt = 11,
  VectorLong = 12,
  VectorFloat = 13,
  VectorDouble = 14,
  Size = 15 // Number of codes listed above.
};
} // namespace ArgCode
163+
164+
// Indices into the vector returned by getArgInfo.
namespace ArgInfo {
enum Enum {
  Code = 0,      // Position of the ArgCode value.
  NumDWords = 1, // Position of the number of elements to write for the arg.
  Size = 2       // Number of fields in the info vector.
};
} // namespace ArgInfo
167+
168+
// Maps argument kind \p Kind to a pair {ArgCode, number of buffer elements
// occupied by the argument value}, laid out according to ArgInfo indices.
// Unknown kinds produce {ArgCode::Invalid, 0}.
static vector<BufferElementTy, ArgInfo::Size> getArgInfo(ArgKind::Enum Kind) {
  using RetInitT = cl_vector<BufferElementTy, ArgInfo::Size>;
  BufferElementTy Code = ArgCode::Invalid;
  BufferElementTy NumDWords = 0;
  switch (Kind) {
  case ArgKind::Char:
  case ArgKind::Short:
  case ArgKind::Int:
    // All integers not wider than int are reported with the Int code.
    Code = ArgCode::Int;
    NumDWords = 1;
    break;
  case ArgKind::Long:
    Code = ArgCode::Long;
    NumDWords = 2;
    break;
  case ArgKind::Float:
    Code = ArgCode::Float;
    NumDWords = 1;
    break;
  case ArgKind::Double:
    Code = ArgCode::Double;
    NumDWords = 2;
    break;
  case ArgKind::Pointer:
    Code = ArgCode::Pointer;
    NumDWords = 2;
    break;
  case ArgKind::String:
    Code = ArgCode::String;
    NumDWords = 1;
    break;
  default:
    // Keep the {Invalid, 0} defaults.
    break;
  }
  return RetInitT{Code, NumDWords};
}
189+
133190
// Single printf arg handling (those that are after format string).
134-
// FIXME: yet unsupported.
135191
static vector<BufferElementTy, TransferDataSize>
136192
printf_arg_impl(vector<BufferElementTy, TransferDataSize> TransferData,
137193
ArgKind::Enum Kind,
138194
vector<BufferElementTy, ArgData::Size> Arg) {
195+
vector<BufferElementTy, ArgInfo::Size> Info = getArgInfo(Kind);
196+
uintptr_t CurAddress = getCurAddress(TransferData);
197+
CurAddress = writeElementToBuffer(CurAddress, Info[ArgInfo::Code]);
198+
for (int Idx = 0; Idx != Info[ArgInfo::NumDWords]; ++Idx)
199+
CurAddress = writeElementToBuffer(CurAddress, Arg[Idx]);
200+
setCurAddress(TransferData, CurAddress);
139201
return TransferData;
140202
}
141203

IGC/VectorCompiler/lib/GenXOpts/CMTrans/GenXPrintfResolution.cpp

Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
5252
#include <llvm/ADT/STLExtras.h>
5353
#include <llvm/ADT/iterator_range.h>
5454
#include <llvm/IR/Constants.h>
55+
#include <llvm/IR/DataLayout.h>
5556
#include <llvm/IR/IRBuilder.h>
5657
#include <llvm/IR/InstIterator.h>
5758
#include <llvm/IR/Instructions.h>
@@ -60,6 +61,8 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
6061
#include <llvm/Pass.h>
6162
#include <llvm/Support/ErrorHandling.h>
6263

64+
#include "llvmWrapper/IR/DerivedTypes.h"
65+
6366
#include <algorithm>
6467
#include <functional>
6568
#include <numeric>
@@ -99,7 +102,8 @@ class GenXPrintfResolution final : public ModulePass {
99102
void setAlwaysInlineForPrintfImpl();
100103
CallInst &createPrintfInitCall(CallInst &OrigPrintf);
101104
CallInst &createPrintfFmtCall(CallInst &OrigPrintf, CallInst &InitCall);
102-
CallInst &createPrintfArgCall(CallInst &OrigPrintf, CallInst &PrevCall);
105+
CallInst &createPrintfArgCall(CallInst &OrigPrintf, CallInst &PrevCall,
106+
Value &Arg);
103107
CallInst &createPrintfRetCall(CallInst &OrigPrintf, CallInst &PrevCall);
104108
};
105109
} // namespace
@@ -203,7 +207,7 @@ void GenXPrintfResolution::handlePrintfCall(CallInst &OrigPrintf) {
203207
auto &LastArgCall = *std::accumulate(
204208
std::next(OrigPrintf.arg_begin()), OrigPrintf.arg_end(), &FmtCall,
205209
[&OrigPrintf, this](CallInst *PrevCall, Value *Arg) {
206-
return &createPrintfArgCall(OrigPrintf, *PrevCall);
210+
return &createPrintfArgCall(OrigPrintf, *PrevCall, *Arg);
207211
});
208212
auto &RetCall = createPrintfRetCall(OrigPrintf, LastArgCall);
209213
RetCall.takeName(&OrigPrintf);
@@ -320,10 +324,134 @@ CallInst &GenXPrintfResolution::createPrintfFmtCall(CallInst &OrigPrintf,
320324
OrigPrintf.getName() + ".printf.fmt");
321325
}
322326

327+
// Selects the argument kind for integer type \p ArgTy by its bit width:
// 64 -> Long, 32 -> Int, 16 -> Short, anything else is expected to be 8 bits
// and is reported as Char.
static ArgKind::Enum getIntegerArgKind(Type &ArgTy) {
  IGC_ASSERT_MESSAGE(ArgTy.isIntegerTy(),
                     "wrong argument: integer type was expected");
  const auto BitWidth = ArgTy.getIntegerBitWidth();
  if (BitWidth == 64)
    return ArgKind::Long;
  if (BitWidth == 32)
    return ArgKind::Int;
  if (BitWidth == 16)
    return ArgKind::Short;
  IGC_ASSERT_MESSAGE(BitWidth == 8, "unexpected integer type");
  return ArgKind::Char;
}
343+
344+
// Selects the argument kind for floating point type \p ArgTy: Double for
// double, Float otherwise (the type is then expected to be float).
static ArgKind::Enum getFloatingPointArgKind(Type &ArgTy) {
  IGC_ASSERT_MESSAGE(ArgTy.isFloatingPointTy(),
                     "wrong argument: floating point type was expected");
  if (!ArgTy.isDoubleTy()) {
    // FIXME: what about half?
    IGC_ASSERT_MESSAGE(ArgTy.isFloatTy(), "unexpected floating point type");
    return ArgKind::Float;
  }
  return ArgKind::Double;
}
353+
354+
// Selects the argument kind for pointer type \p ArgTy: a pointer to i8 is
// treated as a string, any other pointer as a plain pointer.
static ArgKind::Enum getPointerArgKind(Type &ArgTy) {
  IGC_ASSERT_MESSAGE(ArgTy.isPointerTy(),
                     "wrong argument: pointer type was expected");
  // FIXME: what if we want to print a pointer to a string?
  //        Seems like it cannot be handled without parsing the format string.
  const bool PointsToI8 = ArgTy.getPointerElementType()->isIntegerTy(8);
  return PointsToI8 ? ArgKind::String : ArgKind::Pointer;
}
363+
364+
// Deduces the printf argument kind from the LLVM type \p ArgTy. Floating
// point and integer types are dispatched to their dedicated helpers; every
// other type is handed to the pointer helper.
static ArgKind::Enum getArgKind(Type &ArgTy) {
  if (ArgTy.isFloatingPointTy())
    return getFloatingPointArgKind(ArgTy);
  if (ArgTy.isIntegerTy())
    return getIntegerArgKind(ArgTy);
  return getPointerArgKind(ArgTy);
}
371+
372+
// Size of <2 x i32> in bits is 64 (the sizes below are in bits, not bytes).
373+
static constexpr unsigned VecArgSize = 64;
374+
static constexpr auto VecArgElementSize = VecArgSize / ArgData::Size;
375+
376+
// Casts Arg to <2 x i32> vector. For pointers ptrtoint i64 should be generated
377+
// first.
378+
Value &get64BitArgAsVector(Value &Arg, IRBuilder<> &IRB, const DataLayout &DL) {
379+
IGC_ASSERT_MESSAGE(DL.getTypeSizeInBits(Arg.getType()) == 64,
380+
"64-bit argument was expected");
381+
auto *VecArgTy =
382+
IGCLLVM::FixedVectorType::get(IRB.getInt32Ty(), ArgData::Size);
383+
Value *ArgToBitCast = &Arg;
384+
if (Arg.getType()->isPointerTy())
385+
ArgToBitCast =
386+
IRB.CreatePtrToInt(&Arg, IRB.getInt64Ty(), Arg.getName() + ".arg.p2i");
387+
return *IRB.CreateBitCast(ArgToBitCast, VecArgTy, Arg.getName() + ".arg.bc");
388+
}
389+
390+
// Just creates this instruction:
391+
// insertelement <2 x i32> zeroinitializer, i32 %arg, i32 0
392+
// \p Arg must be i32 type.
393+
Value &get32BitIntArgAsVector(Value &Arg, IRBuilder<> &IRB,
394+
const DataLayout &DL) {
395+
IGC_ASSERT_MESSAGE(Arg.getType()->isIntegerTy(32),
396+
"i32 argument was expected");
397+
auto *VecArgTy =
398+
IGCLLVM::FixedVectorType::get(IRB.getInt32Ty(), ArgData::Size);
399+
auto *BlankVec = ConstantAggregateZero::get(VecArgTy);
400+
return *IRB.CreateInsertElement(BlankVec, &Arg, IRB.getInt32(0),
401+
Arg.getName() + ".arg.insert");
402+
}
403+
404+
// Takes arg that is not greater than 32 bit and casts it to i32 with possible
405+
// zero extension.
406+
static Value &getArgAs32BitInt(Value &Arg, IRBuilder<> &IRB,
407+
const DataLayout &DL) {
408+
auto ArgSize = DL.getTypeSizeInBits(Arg.getType());
409+
IGC_ASSERT_MESSAGE(ArgSize <= VecArgElementSize,
410+
"argument isn't expected to be greater than 32 bit");
411+
if (ArgSize < VecArgElementSize) {
412+
// FIXME: seems like there may be some problems with signed types, depending
413+
// on our BiF and runtime implementation.
414+
// FIXME: What about half?
415+
IGC_ASSERT_MESSAGE(Arg.getType()->isIntegerTy(),
416+
"only integers are expected to be less than 32 bits");
417+
return *IRB.CreateZExt(&Arg, IRB.getInt32Ty(), Arg.getName() + ".arg.zext");
418+
}
419+
if (Arg.getType()->isPointerTy())
420+
return *IRB.CreatePtrToInt(&Arg, IRB.getInt32Ty(),
421+
Arg.getName() + ".arg.p2i");
422+
if (!Arg.getType()->isIntegerTy())
423+
return *IRB.CreateBitCast(&Arg, IRB.getInt32Ty(),
424+
Arg.getName() + ".arg.bc");
425+
return Arg;
426+
}
427+
428+
// Args are passed via <2 x i32> vector. This function casts \p Arg to this
429+
// vector type. \p Arg is zext if necessary (zext in common sense - writing
430+
// top element of a vector with zeros is zero extending too).
431+
static Value &getArgAsVector(Value &Arg, IRBuilder<> &IRB,
432+
const DataLayout &DL) {
433+
IGC_ASSERT_MESSAGE(!isa<IGCLLVM::FixedVectorType>(Arg.getType()),
434+
"scalar type is expected");
435+
auto ArgSize = DL.getTypeSizeInBits(Arg.getType());
436+
437+
if (ArgSize == VecArgSize)
438+
return get64BitArgAsVector(Arg, IRB, DL);
439+
IGC_ASSERT_MESSAGE(ArgSize < VecArgSize,
440+
"arg is expected to be not greater than 64 bit");
441+
Value &Arg32Bit = getArgAs32BitInt(Arg, IRB, DL);
442+
return get32BitIntArgAsVector(Arg32Bit, IRB, DL);
443+
}
444+
323445
// Creates a call to the printf argument implementation function for \p Arg
// right before \p OrigPrintf. The call takes the result of \p PrevCall, the
// argument kind as an i32 constant and the argument converted to a vector.
// Returns the created call.
CallInst &GenXPrintfResolution::createPrintfArgCall(CallInst &OrigPrintf,
                                                    CallInst &PrevCall,
                                                    Value &Arg) {
  assertPrintfCall(OrigPrintf);
  const ArgKind::Enum Kind = getArgKind(*Arg.getType());
  IRBuilder<> IRB{&OrigPrintf};
  Value &PackedArg = getArgAsVector(Arg, IRB, *DL);
  Value *KindConst = IRB.getInt32(Kind);
  return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::Arg],
                         {&PrevCall, KindConst, &PackedArg},
                         OrigPrintf.getName() + ".printf.arg");
}
328456

329457
CallInst &GenXPrintfResolution::createPrintfRetCall(CallInst &OrigPrintf,

0 commit comments

Comments
 (0)