Skip to content

Commit a4a0d28

Browse files
[WIP] [SPIR-V] Do not rely on type metadata for ptr element type resolution
1 parent 03203b7 commit a4a0d28

File tree

6 files changed

+176
-32
lines changed

6 files changed

+176
-32
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2043,7 +2043,8 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
20432043
.addImm(Builtin->Number);
20442044
for (auto Argument : Call->Arguments)
20452045
MIB.addUse(Argument);
2046-
MIB.addImm(Builtin->ElementCount);
2046+
if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
2047+
MIB.addImm(Builtin->ElementCount);
20472048

20482049
// Rounding mode should be passed as a last argument in the MI for builtins
20492050
// like "vstorea_halfn_r".
@@ -2179,6 +2180,74 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
21792180
return false;
21802181
}
21812182

2183+
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
2184+
unsigned ArgIdx, LLVMContext &Ctx) {
2185+
SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
2186+
StringRef BuiltinArgs =
2187+
DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
2188+
BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
2189+
StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
2190+
assert(ArgIdx < BuiltinArgsTypeStrs.size() && "Out of bounds argument index");
2191+
bool IsBaseType = TypeStr.ends_with("*") || TypeStr.ends_with("_t") ||
2192+
TypeStr.starts_with("ocl_");
2193+
assert(IsBaseType && "Parsing only ptr element type/builtin base type");
2194+
2195+
// Parse type name in either "typeN" or "type vector[N]" format, where
2196+
// N is the number of elements of the vector.
2197+
Type *BaseType;
2198+
unsigned VecElts = 0;
2199+
2200+
TypeStr.consume_front("atomic_");
2201+
2202+
if (TypeStr.starts_with("void")) {
2203+
BaseType = Type::getVoidTy(Ctx);
2204+
TypeStr = TypeStr.substr(strlen("void"));
2205+
} else if (TypeStr.starts_with("bool")) {
2206+
BaseType = Type::getIntNTy(Ctx, 1);
2207+
TypeStr = TypeStr.substr(strlen("bool"));
2208+
} else if (TypeStr.starts_with("char") || TypeStr.starts_with("uchar")) {
2209+
BaseType = Type::getInt8Ty(Ctx);
2210+
TypeStr = TypeStr.starts_with("char") ? TypeStr.substr(strlen("char"))
2211+
: TypeStr.substr(strlen("uchar"));
2212+
} else if (TypeStr.starts_with("short") || TypeStr.starts_with("ushort")) {
2213+
BaseType = Type::getInt16Ty(Ctx);
2214+
TypeStr = TypeStr.starts_with("short") ? TypeStr.substr(strlen("short"))
2215+
: TypeStr.substr(strlen("ushort"));
2216+
} else if (TypeStr.starts_with("int") || TypeStr.starts_with("uint")) {
2217+
BaseType = Type::getInt32Ty(Ctx);
2218+
TypeStr = TypeStr.starts_with("int") ? TypeStr.substr(strlen("int"))
2219+
: TypeStr.substr(strlen("uint"));
2220+
} else if (TypeStr.starts_with("long") || TypeStr.starts_with("ulong")) {
2221+
BaseType = Type::getInt64Ty(Ctx);
2222+
TypeStr = TypeStr.starts_with("long") ? TypeStr.substr(strlen("long"))
2223+
: TypeStr.substr(strlen("ulong"));
2224+
} else if (TypeStr.starts_with("half")) {
2225+
BaseType = Type::getHalfTy(Ctx);
2226+
TypeStr = TypeStr.substr(strlen("half"));
2227+
} else if (TypeStr.starts_with("float")) {
2228+
BaseType = Type::getFloatTy(Ctx);
2229+
TypeStr = TypeStr.substr(strlen("float"));
2230+
} else if (TypeStr.starts_with("double")) {
2231+
BaseType = Type::getDoubleTy(Ctx);
2232+
TypeStr = TypeStr.substr(strlen("double"));
2233+
} else {
2234+
// Unable to recognize SPIRV type name
2235+
return nullptr;
2236+
}
2237+
2238+
// Handle "typeN*" or "type vector[N]*".
2239+
bool IsPtrToVec = TypeStr.consume_back("*");
2240+
2241+
if (TypeStr.consume_front(" vector["))
2242+
TypeStr = TypeStr.substr(0, TypeStr.find(']'));
2243+
2244+
TypeStr.getAsInteger(10, VecElts);
2245+
if (VecElts > 0)
2246+
BaseType = VectorType::get(BaseType, VecElts, false);
2247+
2248+
return BaseType;
2249+
}
2250+
21822251
struct BuiltinType {
21832252
StringRef Name;
21842253
uint32_t Opcode;

llvm/lib/Target/SPIRV/SPIRVBuiltins.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
3838
const SmallVectorImpl<Register> &Args,
3939
SPIRVGlobalRegistry *GR);
4040

41+
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
42+
unsigned ArgIdx, LLVMContext &Ctx);
43+
4144
/// Translates a string representing a SPIR-V or OpenCL builtin type to a
4245
/// TargetExtType that can be further lowered with lowerBuiltinType().
4346
///

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "SPIRV.h"
15+
#include "SPIRVBuiltins.h"
1516
#include "SPIRVMetadata.h"
1617
#include "SPIRVTargetMachine.h"
1718
#include "SPIRVUtils.h"
@@ -75,6 +76,9 @@ class SPIRVEmitIntrinsics
7576
void processInstrAfterVisit(Instruction *I);
7677
void insertAssignPtrTypeIntrs(Instruction *I);
7778
void insertAssignTypeIntrs(Instruction *I);
79+
void replacePointerOperandWithPtrCast(Instruction *I, Value *Pointer,
80+
Type *ExpectedElementType,
81+
unsigned OperandToReplace);
7882
void insertPtrCastInstr(Instruction *I);
7983
void processGlobalValue(GlobalVariable &GV);
8084

@@ -286,34 +290,9 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
286290
return NewI;
287291
}
288292

289-
void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
290-
Value *Pointer;
291-
Type *ExpectedElementType;
292-
unsigned OperandToReplace;
293-
294-
StoreInst *SI = dyn_cast<StoreInst>(I);
295-
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
296-
SI->getValueOperand()->getType()->isPointerTy() &&
297-
isa<Argument>(SI->getValueOperand())) {
298-
Pointer = SI->getValueOperand();
299-
ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
300-
OperandToReplace = 0;
301-
} else if (SI) {
302-
Pointer = SI->getPointerOperand();
303-
ExpectedElementType = SI->getValueOperand()->getType();
304-
OperandToReplace = 1;
305-
} else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
306-
Pointer = LI->getPointerOperand();
307-
ExpectedElementType = LI->getType();
308-
OperandToReplace = 0;
309-
} else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
310-
Pointer = GEPI->getPointerOperand();
311-
ExpectedElementType = GEPI->getSourceElementType();
312-
OperandToReplace = 0;
313-
} else {
314-
return;
315-
}
316-
293+
void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
294+
Instruction *I, Value *Pointer, Type *ExpectedElementType,
295+
unsigned OperandToReplace) {
317296
// If Pointer is the result of nop BitCastInst (ptr -> ptr), use the source
318297
// pointer instead. The BitCastInst should be later removed when visited.
319298
while (BitCastInst *BC = dyn_cast<BitCastInst>(Pointer))
@@ -413,6 +392,45 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
413392
}
414393
}
415394

395+
void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
396+
// Handle basic instructions:
397+
StoreInst *SI = dyn_cast<StoreInst>(I);
398+
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
399+
SI->getValueOperand()->getType()->isPointerTy() &&
400+
isa<Argument>(SI->getValueOperand())) {
401+
return replacePointerOperandWithPtrCast(
402+
I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0);
403+
} else if (SI) {
404+
return replacePointerOperandWithPtrCast(
405+
I, SI->getPointerOperand(), SI->getValueOperand()->getType(), 1);
406+
} else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
407+
return replacePointerOperandWithPtrCast(I, LI->getPointerOperand(),
408+
LI->getType(), 0);
409+
} else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
410+
return replacePointerOperandWithPtrCast(I, GEPI->getPointerOperand(),
411+
GEPI->getSourceElementType(), 0);
412+
}
413+
414+
// Handle calls to builtins (non-intrinsics):
415+
CallInst *CI = dyn_cast<CallInst>(I);
416+
if (!CI || CI->isIndirectCall() || CI->getCalledFunction()->isIntrinsic())
417+
return;
418+
419+
std::string DemangledName =
420+
getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
421+
if (DemangledName.empty())
422+
return;
423+
424+
for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) {
425+
if (!isa<PointerType>(CI->getArgOperand(OpIdx)->getType()))
426+
continue;
427+
Type *ExpectedType = SPIRV::parseBuiltinCallArgumentBaseType(
428+
DemangledName, OpIdx, I->getContext());
429+
replacePointerOperandWithPtrCast(CI, CI->getArgOperand(OpIdx), ExpectedType,
430+
OpIdx);
431+
}
432+
}
433+
416434
Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) {
417435
SmallVector<Type *, 4> Types = {I.getType(), I.getOperand(0)->getType(),
418436
I.getOperand(1)->getType(),

llvm/test/CodeGen/SPIRV/half_no_extension.ll

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77

88
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
99

10-
; TODO(#60133): Requires updates following opaque pointer migration.
11-
; XFAIL: *
12-
1310
; CHECK-SPIRV: OpCapability Float16Buffer
1411
; CHECK-SPIRV-NOT: OpCapability Float16
1512

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; This test only intends to check the vstoren builtin name resolution.
3+
; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.
4+
5+
; CHECK-DAG: %[[#IMPORT:]] = OpExtInstImport "OpenCL.std"
6+
7+
; CHECK-DAG: %[[#VOID:]] = OpTypeVoid
8+
; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
9+
; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
10+
; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
11+
; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]
12+
13+
; CHECK: %[[#DATA:]] = OpFunctionParameter %[[#VINT8]]
14+
; CHECK: %[[#OFFSET:]] = OpFunctionParameter %[[#INT64]]
15+
; CHECK: %[[#ADDRESS:]] = OpFunctionParameter %[[#PTRINT8]]
16+
17+
define spir_kernel void @test_fn(<2 x i8> %data, i64 %offset, ptr addrspace(1) %address) {
18+
; CHECK: %[[#]] = OpExtInst %[[#VOID]] %[[#IMPORT]] vstoren %[[#DATA]] %[[#OFFSET]] %[[#ADDRESS]]
19+
call spir_func void @_Z7vstore2Dv2_cmPU3AS1c(<2 x i8> %data, i64 %offset, ptr addrspace(1) %address)
20+
ret void
21+
}
22+
23+
declare spir_func void @_Z7vstore2Dv2_cmPU3AS1c(<2 x i8>, i64, ptr addrspace(1))
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
3+
; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
4+
; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]
5+
6+
define spir_kernel void @test_fn(ptr addrspace(1) %src, ptr addrspace(1) %dummy) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_type_qual !4 !kernel_arg_base_type !3 {
7+
entry:
8+
%g1 = call spir_func i64 @_Z13get_global_idj(i32 0)
9+
%i1 = insertelement <3 x i64> undef, i64 %g1, i32 0
10+
%g2 = call spir_func i64 @_Z13get_global_idj(i32 1)
11+
%i2 = insertelement <3 x i64> %i1, i64 %g2, i32 1
12+
%g3 = call spir_func i64 @_Z13get_global_idj(i32 2)
13+
%i3 = insertelement <3 x i64> %i2, i64 %g3, i32 2
14+
%e = extractelement <3 x i64> %i3, i32 0
15+
%c1 = trunc i64 %e to i32
16+
%c2 = sext i32 %c1 to i64
17+
%b = bitcast ptr addrspace(1) %src to ptr addrspace(1)
18+
19+
; Make sure that builtin call directly uses either a OpBitcast or OpFunctionParameter of i8* type
20+
; CHECK: %[[#BITCASTorPARAMETER:]] = {{.*}} %[[#PTRINT8]] {{.*}}
21+
; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] vloadn %[[#]] %[[#BITCASTorPARAMETER]] 3
22+
%call = call spir_func <3 x i8> @_Z6vload3mPU3AS1Kc(i64 %c2, ptr addrspace(1) %b)
23+
24+
ret void
25+
}
26+
27+
declare spir_func i64 @_Z13get_global_idj(i32)
28+
29+
declare spir_func <3 x i8> @_Z6vload3mPU3AS1Kc(i64, ptr addrspace(1))
30+
31+
!1 = !{i32 1, i32 1}
32+
!2 = !{!"none", !"none"}
33+
!3 = !{!"char3*", !"char*"}
34+
!4 = !{!"", !""}

0 commit comments

Comments
 (0)