Skip to content

Commit 7dacb7c

Browse files
authored
SPV_KHR_untyped_pointers - implement OpTypeUntypedPointerKHR (#2687)
This is the first part of the extension implementation. Introducing untyped pointer type. Spec: https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_untyped_pointers.html
1 parent 2b5f15d commit 7dacb7c

19 files changed

+372
-52
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ EXT(SPV_KHR_subgroup_rotate)
1717
EXT(SPV_KHR_non_semantic_info)
1818
EXT(SPV_KHR_shader_clock)
1919
EXT(SPV_KHR_cooperative_matrix)
20+
EXT(SPV_KHR_untyped_pointers)
2021
EXT(SPV_INTEL_subgroups)
2122
EXT(SPV_INTEL_media_block_io)
2223
EXT(SPV_INTEL_device_side_avc_motion_estimation)

lib/SPIRV/SPIRVReader.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,11 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
355355
return TypedPointerType::get(ElementTy, AS);
356356
return mapType(T, PointerType::get(ElementTy, AS));
357357
}
358+
case OpTypeUntypedPointerKHR: {
359+
const unsigned AS =
360+
SPIRSPIRVAddrSpaceMap::rmap(T->getPointerStorageClass());
361+
return mapType(T, PointerType::get(*Context, AS));
362+
}
358363
case OpTypeVector:
359364
return mapType(T,
360365
FixedVectorType::get(transType(T->getVectorComponentType()),
@@ -558,6 +563,8 @@ std::string SPIRVToLLVM::transTypeToOCLTypeName(SPIRVType *T, bool IsSigned) {
558563
}
559564
return transTypeToOCLTypeName(ET) + "*";
560565
}
566+
case OpTypeUntypedPointerKHR:
567+
return "int*";
561568
case OpTypeVector:
562569
return transTypeToOCLTypeName(T->getVectorComponentType()) +
563570
T->getVectorComponentCount();

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,8 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
413413
// A pointer to image or pipe type in LLVM is translated to a SPIRV
414414
// (non-pointer) image or pipe type.
415415
if (T->isPointerTy()) {
416-
auto *ET = Type::getInt8Ty(T->getContext());
417416
auto AddrSpc = T->getPointerAddressSpace();
417+
auto *ET = Type::getInt8Ty(T->getContext());
418418
return transPointerType(ET, AddrSpc);
419419
}
420420

@@ -716,7 +716,6 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(Type *ET, unsigned AddrSpc) {
716716
transType(ET)));
717717
}
718718
} else {
719-
SPIRVType *ElementType = transType(ET);
720719
// ET, as a recursive type, may contain exactly the same pointer T, so it
721720
// may happen that after translation of ET we already have translated T,
722721
// added the translated pointer to the SPIR-V module and mapped T to this
@@ -725,7 +724,17 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(Type *ET, unsigned AddrSpc) {
725724
if (Loc != PointeeTypeMap.end()) {
726725
return Loc->second;
727726
}
728-
SPIRVType *TranslatedTy = transPointerType(ElementType, AddrSpc);
727+
728+
SPIRVType *ElementType = nullptr;
729+
SPIRVType *TranslatedTy = nullptr;
730+
if (ET->isPointerTy() &&
731+
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers)) {
732+
TranslatedTy = BM->addUntypedPointerKHRType(
733+
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)));
734+
} else {
735+
ElementType = transType(ET);
736+
TranslatedTy = transPointerType(ElementType, AddrSpc);
737+
}
729738
PointeeTypeMap[TypeKey] = TranslatedTy;
730739
return TranslatedTy;
731740
}
@@ -740,8 +749,16 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
740749
if (Loc != PointeeTypeMap.end())
741750
return Loc->second;
742751

743-
SPIRVType *TranslatedTy = BM->addPointerType(
744-
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)), ET);
752+
SPIRVType *TranslatedTy = nullptr;
753+
if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers) &&
754+
!(ET->isTypeArray() || ET->isTypeVector() || ET->isTypeImage() ||
755+
ET->isTypeSampler() || ET->isTypePipe())) {
756+
TranslatedTy = BM->addUntypedPointerKHRType(
757+
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)));
758+
} else {
759+
TranslatedTy = BM->addPointerType(
760+
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)), ET);
761+
}
745762
PointeeTypeMap[TypeKey] = TranslatedTy;
746763
return TranslatedTy;
747764
}
@@ -2176,8 +2193,13 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
21762193
MemoryAccessNoAliasINTELMaskMask);
21772194
if (MemoryAccess.front() == 0)
21782195
MemoryAccess.clear();
2179-
return mapValue(V, BM->addLoadInst(transValue(LD->getPointerOperand(), BB),
2180-
MemoryAccess, BB));
2196+
return mapValue(
2197+
V,
2198+
BM->addLoadInst(
2199+
transValue(LD->getPointerOperand(), BB), MemoryAccess, BB,
2200+
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers)
2201+
? transType(LD->getType())
2202+
: nullptr));
21812203
}
21822204

21832205
if (BinaryOperator *B = dyn_cast<BinaryOperator>(V)) {
@@ -2387,14 +2409,17 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
23872409

23882410
if (auto *Phi = dyn_cast<PHINode>(V)) {
23892411
std::vector<SPIRVValue *> IncomingPairs;
2412+
SPIRVType *Ty = transScavengedType(Phi);
23902413

23912414
for (size_t I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) {
2392-
IncomingPairs.push_back(transValue(Phi->getIncomingValue(I), BB, true,
2393-
FuncTransMode::Pointer));
2415+
SPIRVValue *Val = transValue(Phi->getIncomingValue(I), BB, true,
2416+
FuncTransMode::Pointer);
2417+
if (Val->getType() != Ty)
2418+
Val = BM->addUnaryInst(OpBitcast, Ty, Val, BB);
2419+
IncomingPairs.push_back(Val);
23942420
IncomingPairs.push_back(transValue(Phi->getIncomingBlock(I), nullptr));
23952421
}
2396-
return mapValue(V,
2397-
BM->addPhiInst(transScavengedType(Phi), IncomingPairs, BB));
2422+
return mapValue(V, BM->addPhiInst(Ty, IncomingPairs, BB));
23982423
}
23992424

24002425
if (auto *Ext = dyn_cast<ExtractValueInst>(V)) {
@@ -6650,9 +6675,12 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
66506675
assert((Pointee == Args[I] || !isa<Function>(Pointee)) &&
66516676
"Illegal use of a function pointer type");
66526677
}
6653-
SPArgs.push_back(SPI->isOperandLiteral(I)
6654-
? cast<ConstantInt>(Args[I])->getZExtValue()
6655-
: transValue(Args[I], BB)->getId());
6678+
if (!SPI->isOperandLiteral(I)) {
6679+
SPIRVValue *Val = transValue(Args[I], BB);
6680+
SPArgs.push_back(Val->getId());
6681+
} else {
6682+
SPArgs.push_back(cast<ConstantInt>(Args[I])->getZExtValue());
6683+
}
66566684
}
66576685
BM->addInstTemplate(SPI, SPArgs, BB, SPRetTy);
66586686
if (!SPRetTy || !SPRetTy->isTypeStruct())

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -577,9 +577,15 @@ class SPIRVStore : public SPIRVInstruction, public SPIRVMemoryAccess {
577577
SPIRVInstruction::validate();
578578
if (getSrc()->isForward() || getDst()->isForward())
579579
return;
580-
assert(getValueType(PtrId)->getPointerElementType() ==
581-
getValueType(ValId) &&
582-
"Inconsistent operand types");
580+
assert(
581+
(getValueType(PtrId)
582+
->getPointerElementType()
583+
->isTypeUntypedPointerKHR() ||
584+
// TODO: This check should be removed once we support untyped
585+
// variables.
586+
getValueType(ValId)->isTypeUntypedPointerKHR() ||
587+
getValueType(PtrId)->getPointerElementType() == getValueType(ValId)) &&
588+
"Inconsistent operand types");
583589
}
584590

585591
private:
@@ -594,11 +600,12 @@ class SPIRVLoad : public SPIRVInstruction, public SPIRVMemoryAccess {
594600
// Complete constructor
595601
SPIRVLoad(SPIRVId TheId, SPIRVId PointerId,
596602
const std::vector<SPIRVWord> &TheMemoryAccess,
597-
SPIRVBasicBlock *TheBB)
603+
SPIRVBasicBlock *TheBB, SPIRVType *TheType = nullptr)
598604
: SPIRVInstruction(
599605
FixedWords + TheMemoryAccess.size(), OpLoad,
600-
TheBB->getValueType(PointerId)->getPointerElementType(), TheId,
601-
TheBB),
606+
TheType ? TheType
607+
: TheBB->getValueType(PointerId)->getPointerElementType(),
608+
TheId, TheBB),
602609
SPIRVMemoryAccess(TheMemoryAccess), PtrId(PointerId),
603610
MemoryAccess(TheMemoryAccess) {
604611
validate();
@@ -628,6 +635,12 @@ class SPIRVLoad : public SPIRVInstruction, public SPIRVMemoryAccess {
628635
void validate() const override {
629636
SPIRVInstruction::validate();
630637
assert((getValue(PtrId)->isForward() ||
638+
getValueType(PtrId)
639+
->getPointerElementType()
640+
->isTypeUntypedPointerKHR() ||
641+
// TODO: This check should be removed once we support untyped
642+
// variables.
643+
Type->isTypeUntypedPointerKHR() ||
631644
Type == getValueType(PtrId)->getPointerElementType()) &&
632645
"Inconsistent types");
633646
}
@@ -2010,7 +2023,8 @@ class SPIRVCompositeExtractBase : public SPIRVInstTemplateBase {
20102023
(void)Composite;
20112024
assert(getValueType(Composite)->isTypeArray() ||
20122025
getValueType(Composite)->isTypeStruct() ||
2013-
getValueType(Composite)->isTypeVector());
2026+
getValueType(Composite)->isTypeVector() ||
2027+
getValueType(Composite)->isTypeUntypedPointerKHR());
20142028
}
20152029
};
20162030

@@ -2036,7 +2050,8 @@ class SPIRVCompositeInsertBase : public SPIRVInstTemplateBase {
20362050
(void)Composite;
20372051
assert(getValueType(Composite)->isTypeArray() ||
20382052
getValueType(Composite)->isTypeStruct() ||
2039-
getValueType(Composite)->isTypeVector());
2053+
getValueType(Composite)->isTypeVector() ||
2054+
getValueType(Composite)->isTypeUntypedPointerKHR());
20402055
assert(Type == getValueType(Composite));
20412056
}
20422057
};
@@ -2383,7 +2398,8 @@ template <Op OC> class SPIRVLifetime : public SPIRVInstruction {
23832398
// Signedness of 1, its sign bit cannot be set.
23842399
if (!(ObjType->getPointerElementType()->isTypeVoid() ||
23852400
// (void *) is i8* in LLVM IR
2386-
ObjType->getPointerElementType()->isTypeInt(8)) ||
2401+
ObjType->getPointerElementType()->isTypeInt(8) ||
2402+
ObjType->getPointerElementType()->isTypeUntypedPointerKHR()) ||
23872403
!Module->hasCapability(CapabilityAddresses))
23882404
assert(Size == 0 && "Size must be 0");
23892405
}

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,8 @@ class SPIRVModuleImpl : public SPIRVModule {
252252
SPIRVTypeInt *addIntegerType(unsigned BitWidth) override;
253253
SPIRVTypeOpaque *addOpaqueType(const std::string &) override;
254254
SPIRVTypePointer *addPointerType(SPIRVStorageClassKind, SPIRVType *) override;
255+
SPIRVTypeUntypedPointerKHR *
256+
addUntypedPointerKHRType(SPIRVStorageClassKind) override;
255257
SPIRVTypeImage *addImageType(SPIRVType *,
256258
const SPIRVTypeImageDescriptor &) override;
257259
SPIRVTypeImage *addImageType(SPIRVType *, const SPIRVTypeImageDescriptor &,
@@ -353,7 +355,7 @@ class SPIRVModuleImpl : public SPIRVModule {
353355
SPIRVInstruction *addCmpInst(Op, SPIRVType *, SPIRVValue *, SPIRVValue *,
354356
SPIRVBasicBlock *) override;
355357
SPIRVInstruction *addLoadInst(SPIRVValue *, const std::vector<SPIRVWord> &,
356-
SPIRVBasicBlock *) override;
358+
SPIRVBasicBlock *, SPIRVType *) override;
357359
SPIRVInstruction *addPhiInst(SPIRVType *, std::vector<SPIRVValue *>,
358360
SPIRVBasicBlock *) override;
359361
SPIRVInstruction *addCompositeConstructInst(SPIRVType *,
@@ -563,6 +565,8 @@ class SPIRVModuleImpl : public SPIRVModule {
563565
SPIRVUnknownStructFieldMap UnknownStructFieldMap;
564566
SPIRVTypeBool *BoolTy;
565567
SPIRVTypeVoid *VoidTy;
568+
SmallDenseMap<SPIRVStorageClassKind, SPIRVTypeUntypedPointerKHR *>
569+
UntypedPtrTyMap;
566570
SmallDenseMap<unsigned, SPIRVTypeInt *, 4> IntTypeMap;
567571
SmallDenseMap<unsigned, SPIRVTypeFloat *, 4> FloatTypeMap;
568572
SmallDenseMap<std::pair<unsigned, SPIRVType *>, SPIRVTypePointer *, 4>
@@ -1014,6 +1018,17 @@ SPIRVModuleImpl::addPointerType(SPIRVStorageClassKind StorageClass,
10141018
return addType(Ty);
10151019
}
10161020

1021+
SPIRVTypeUntypedPointerKHR *
1022+
SPIRVModuleImpl::addUntypedPointerKHRType(SPIRVStorageClassKind StorageClass) {
1023+
auto Loc = UntypedPtrTyMap.find(StorageClass);
1024+
if (Loc != UntypedPtrTyMap.end())
1025+
return Loc->second;
1026+
1027+
auto *Ty = new SPIRVTypeUntypedPointerKHR(this, getId(), StorageClass);
1028+
UntypedPtrTyMap[StorageClass] = Ty;
1029+
return addType(Ty);
1030+
}
1031+
10171032
SPIRVTypeFunction *SPIRVModuleImpl::addFunctionType(
10181033
SPIRVType *ReturnType, const std::vector<SPIRVType *> &ParameterTypes) {
10191034
return addType(
@@ -1430,9 +1445,10 @@ SPIRVModuleImpl::addInstruction(SPIRVInstruction *Inst, SPIRVBasicBlock *BB,
14301445
SPIRVInstruction *
14311446
SPIRVModuleImpl::addLoadInst(SPIRVValue *Source,
14321447
const std::vector<SPIRVWord> &TheMemoryAccess,
1433-
SPIRVBasicBlock *BB) {
1448+
SPIRVBasicBlock *BB, SPIRVType *TheType) {
14341449
return addInstruction(
1435-
new SPIRVLoad(getId(), Source->getId(), TheMemoryAccess, BB), BB);
1450+
new SPIRVLoad(getId(), Source->getId(), TheMemoryAccess, BB, TheType),
1451+
BB);
14361452
}
14371453

14381454
SPIRVInstruction *
@@ -1925,11 +1941,13 @@ class TopologicalSort {
19251941
// We've found a recursive data type, e.g. a structure having a member
19261942
// which is a pointer to the same structure.
19271943
State = Unvisited; // Forget about it
1928-
if (E->getOpCode() == OpTypePointer) {
1944+
if (E->getOpCode() == OpTypePointer ||
1945+
E->getOpCode() == OpTypeUntypedPointerKHR) {
19291946
// If we have a pointer in the recursive chain, we can break the
19301947
// cyclic dependency by inserting a forward declaration of that
19311948
// pointer.
1932-
SPIRVTypePointer *Ptr = static_cast<SPIRVTypePointer *>(E);
1949+
SPIRVTypePointerBase<> *Ptr =
1950+
static_cast<SPIRVTypePointerBase<> *>(E);
19331951
SPIRVModule *BM = E->getModule();
19341952
ForwardPointerSet.insert(BM->add(new SPIRVTypeForwardPointer(
19351953
BM, Ptr->getId(), Ptr->getPointerStorageClass())));

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class SPIRVTypeFunction;
7171
class SPIRVTypeInt;
7272
class SPIRVTypeOpaque;
7373
class SPIRVTypePointer;
74+
class SPIRVTypeUntypedPointerKHR;
7475
class SPIRVTypeImage;
7576
class SPIRVTypeSampler;
7677
class SPIRVTypeSampledImage;
@@ -257,6 +258,8 @@ class SPIRVModule {
257258
virtual SPIRVTypeOpaque *addOpaqueType(const std::string &) = 0;
258259
virtual SPIRVTypePointer *addPointerType(SPIRVStorageClassKind,
259260
SPIRVType *) = 0;
261+
virtual SPIRVTypeUntypedPointerKHR *
262+
addUntypedPointerKHRType(SPIRVStorageClassKind) = 0;
260263
virtual SPIRVTypeStruct *openStructType(unsigned, const std::string &) = 0;
261264
virtual SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) = 0;
262265
virtual void closeStructType(SPIRVTypeStruct *, bool) = 0;
@@ -396,7 +399,8 @@ class SPIRVModule {
396399
SPIRVBasicBlock *BB, SPIRVType *Ty) = 0;
397400
virtual SPIRVInstruction *addLoadInst(SPIRVValue *,
398401
const std::vector<SPIRVWord> &,
399-
SPIRVBasicBlock *) = 0;
402+
SPIRVBasicBlock *,
403+
SPIRVType *TheType = nullptr) = 0;
400404
virtual SPIRVInstruction *addLifetimeInst(Op OC, SPIRVValue *Object,
401405
SPIRVWord Size,
402406
SPIRVBasicBlock *BB) = 0;

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
463463
add(CapabilityRoundingModeRTZ, "RoundingModeRTZ");
464464
add(CapabilityRayQueryProvisionalKHR, "RayQueryProvisionalKHR");
465465
add(CapabilityRayQueryKHR, "RayQueryKHR");
466+
add(CapabilityUntypedPointersKHR, "UntypedPointersKHR");
466467
add(CapabilityRayTraversalPrimitiveCullingKHR,
467468
"RayTraversalPrimitiveCullingKHR");
468469
add(CapabilityRayTracingKHR, "RayTracingKHR");

lib/SPIRV/libSPIRV/SPIRVOpCode.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ inline bool isTypeOpCode(Op OpCode) {
232232
OC == internal::OpTypeJointMatrixINTEL ||
233233
OC == internal::OpTypeJointMatrixINTELv2 ||
234234
OC == OpTypeCooperativeMatrixKHR ||
235-
OC == internal::OpTypeTaskSequenceINTEL;
235+
OC == internal::OpTypeTaskSequenceINTEL ||
236+
OC == OpTypeUntypedPointerKHR;
236237
}
237238

238239
inline bool isSpecConstantOpCode(Op OpCode) {

lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ _SPIRV_OP(PtrEqual, 401)
333333
_SPIRV_OP(PtrNotEqual, 402)
334334
_SPIRV_OP(PtrDiff, 403)
335335
_SPIRV_OP(CopyLogical, 400)
336+
_SPIRV_OP(TypeUntypedPointerKHR, 4417)
336337
_SPIRV_OP(GroupNonUniformRotateKHR, 4431)
337338
_SPIRV_OP(SDotKHR, 4450)
338339
_SPIRV_OP(UDotKHR, 4451)

lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,16 @@ SPIRVType *SPIRVType::getFunctionReturnType() const {
8686
}
8787

8888
SPIRVType *SPIRVType::getPointerElementType() const {
89-
assert(OpCode == OpTypePointer && "Not a pointer type");
89+
assert((OpCode == OpTypePointer || OpCode == OpTypeUntypedPointerKHR) &&
90+
"Not a pointer type");
91+
if (OpCode == OpTypeUntypedPointerKHR)
92+
return const_cast<SPIRVType *>(this);
9093
return static_cast<const SPIRVTypePointer *>(this)->getElementType();
9194
}
9295

9396
SPIRVStorageClassKind SPIRVType::getPointerStorageClass() const {
94-
assert(OpCode == OpTypePointer && "Not a pointer type");
97+
assert((OpCode == OpTypePointer || OpCode == OpTypeUntypedPointerKHR) &&
98+
"Not a pointer type");
9599
return static_cast<const SPIRVTypePointer *>(this)->getStorageClass();
96100
}
97101

@@ -183,7 +187,13 @@ bool SPIRVType::isTypeInt(unsigned Bits) const {
183187
return isType<SPIRVTypeInt>(this, Bits);
184188
}
185189

186-
bool SPIRVType::isTypePointer() const { return OpCode == OpTypePointer; }
190+
bool SPIRVType::isTypePointer() const {
191+
return OpCode == OpTypePointer || OpCode == OpTypeUntypedPointerKHR;
192+
}
193+
194+
bool SPIRVType::isTypeUntypedPointerKHR() const {
195+
return OpCode == OpTypeUntypedPointerKHR;
196+
}
187197

188198
bool SPIRVType::isTypeOpaque() const { return OpCode == OpTypeOpaque; }
189199

0 commit comments

Comments
 (0)