Skip to content

Commit 19d67bd

Browse files
authored
Refactoring Two functions into One (#3124)
- Refactored addUntypedPointerKHRType and addPointerType functions into a single function (It was marked as TODO) - Combined both functions into one with return type SPIRVType * - Also re inserted an assert and removed unnecessary if (Marked as TODO)
1 parent 024ae2b commit 19d67bd

File tree

4 files changed

+31
-35
lines changed

4 files changed

+31
-35
lines changed

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -735,8 +735,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(Type *ET, unsigned AddrSpc) {
735735
SPIRVType *TranslatedTy = nullptr;
736736
if (ET->isPointerTy() &&
737737
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers)) {
738-
TranslatedTy = BM->addUntypedPointerKHRType(
739-
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)));
738+
TranslatedTy = BM->addPointerType(
739+
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)),
740+
nullptr);
740741
} else {
741742
ElementType = transType(ET);
742743
TranslatedTy = transPointerType(ElementType, AddrSpc);
@@ -761,8 +762,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
761762
return transPointerType(ET, SPIRAS_Private);
762763
if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers) &&
763764
!(ET->isTypeArray() || ET->isTypeVector() || ET->isSPIRVOpaqueType())) {
764-
TranslatedTy = BM->addUntypedPointerKHRType(
765-
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)));
765+
TranslatedTy = BM->addPointerType(
766+
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)),
767+
nullptr);
766768
} else {
767769
TranslatedTy = BM->addPointerType(
768770
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)), ET);
@@ -2347,10 +2349,8 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
23472349
}
23482350
SPIRVType *VarTy = TranslatedTy;
23492351
if (V->getType()->getPointerAddressSpace() == SPIRAS_Generic) {
2350-
// TODO: refactor addPointerType and addUntypedPointerKHRType in one
2351-
// method if possible.
23522352
if (TranslatedTy->isTypeUntypedPointerKHR())
2353-
VarTy = BM->addUntypedPointerKHRType(StorageClassFunction);
2353+
VarTy = BM->addPointerType(StorageClassFunction, nullptr);
23542354
else
23552355
VarTy = BM->addPointerType(StorageClassFunction,
23562356
TranslatedTy->getPointerElementType());
@@ -2697,11 +2697,8 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
26972697
SPIRVType *LLVMToSPIRVBase::mapType(Type *T, SPIRVType *BT) {
26982698
assert(!T->isPointerTy() && "Pointer types cannot be stored in the type map");
26992699
auto EmplaceStatus = TypeMap.try_emplace(T, BT);
2700-
// TODO: Uncomment the assertion, once the type mapping issue is resolved
2701-
// assert(EmplaceStatus.second && "The type was already added to the map");
2700+
assert(EmplaceStatus.second && "The type was already added to the map");
27022701
SPIRVDBG(dbgs() << "[mapType] " << *T << " => "; spvdbgs() << *BT << '\n');
2703-
if (!EmplaceStatus.second)
2704-
return TypeMap[T];
27052702
return BT;
27062703
}
27072704

@@ -4302,8 +4299,8 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
43024299
SPIRVType *IntegralTy = transType(II->getType()->getStructElementType(1));
43034300
// IntegralTy is the type of the result. We want to create a pointer to this
43044301
// that we can pass to OpenCLLIB::modf to store the integral part.
4305-
SPIRVTypePointer *IntegralPtrTy =
4306-
BM->addPointerType(StorageClassFunction, IntegralTy);
4302+
SPIRVType *GenericPtrTy = BM->addPointerType(StorageClassFunction, IntegralTy);
4303+
auto *IntegralPtrTy = dyn_cast<SPIRVTypePointer>(GenericPtrTy);
43074304
// We need to use the entry BB of the function calling llvm.modf.*, instead
43084305
// of the current BB. For that, we'll find current BB's parent and get its
43094306
// first BB, which is the entry BB of the function.
@@ -4829,7 +4826,7 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
48294826
auto *SrcTy = PtrOp->getType();
48304827
SPIRVType *DstTy = nullptr;
48314828
if (SrcTy->isTypeUntypedPointerKHR())
4832-
DstTy = BM->addUntypedPointerKHRType(StorageClassFunction);
4829+
DstTy = BM->addPointerType(StorageClassFunction, nullptr);
48334830
else
48344831
DstTy = BM->addPointerType(StorageClassFunction,
48354832
SrcTy->getPointerElementType());

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,7 @@ class SPIRVModuleImpl : public SPIRVModule {
259259
const std::vector<SPIRVType *> &) override;
260260
SPIRVTypeInt *addIntegerType(unsigned BitWidth) override;
261261
SPIRVTypeOpaque *addOpaqueType(const std::string &) override;
262-
SPIRVTypePointer *addPointerType(SPIRVStorageClassKind, SPIRVType *) override;
263-
SPIRVTypeUntypedPointerKHR *
264-
addUntypedPointerKHRType(SPIRVStorageClassKind) override;
262+
SPIRVType *addPointerType(SPIRVStorageClassKind, SPIRVType *) override;
265263
SPIRVTypeImage *addImageType(SPIRVType *,
266264
const SPIRVTypeImageDescriptor &) override;
267265
SPIRVTypeImage *addImageType(SPIRVType *, const SPIRVTypeImageDescriptor &,
@@ -1023,29 +1021,30 @@ SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth,
10231021
return addType(Ty);
10241022
}
10251023

1026-
SPIRVTypePointer *
1027-
SPIRVModuleImpl::addPointerType(SPIRVStorageClassKind StorageClass,
1028-
SPIRVType *ElementType) {
1024+
SPIRVType *SPIRVModuleImpl::addPointerType(SPIRVStorageClassKind StorageClass,
1025+
SPIRVType *ElementType = nullptr) {
1026+
if (ElementType == nullptr) {
1027+
// Untyped pointer
1028+
auto Loc = UntypedPtrTyMap.find(StorageClass);
1029+
if (Loc != UntypedPtrTyMap.end())
1030+
return Loc->second;
1031+
1032+
auto *Ty = new SPIRVTypeUntypedPointerKHR(this, getId(), StorageClass);
1033+
UntypedPtrTyMap[StorageClass] = Ty;
1034+
return addType(Ty);
1035+
}
1036+
1037+
// Typed pointer
10291038
auto Desc = std::make_pair(StorageClass, ElementType);
10301039
auto Loc = PointerTypeMap.find(Desc);
10311040
if (Loc != PointerTypeMap.end())
10321041
return Loc->second;
1042+
10331043
auto *Ty = new SPIRVTypePointer(this, getId(), StorageClass, ElementType);
10341044
PointerTypeMap[Desc] = Ty;
10351045
return addType(Ty);
10361046
}
10371047

1038-
SPIRVTypeUntypedPointerKHR *
1039-
SPIRVModuleImpl::addUntypedPointerKHRType(SPIRVStorageClassKind StorageClass) {
1040-
auto Loc = UntypedPtrTyMap.find(StorageClass);
1041-
if (Loc != UntypedPtrTyMap.end())
1042-
return Loc->second;
1043-
1044-
auto *Ty = new SPIRVTypeUntypedPointerKHR(this, getId(), StorageClass);
1045-
UntypedPtrTyMap[StorageClass] = Ty;
1046-
return addType(Ty);
1047-
}
1048-
10491048
SPIRVTypeFunction *SPIRVModuleImpl::addFunctionType(
10501049
SPIRVType *ReturnType, const std::vector<SPIRVType *> &ParameterTypes) {
10511050
return addType(

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,7 @@ class SPIRVModule {
257257
virtual SPIRVTypeSampledImage *addSampledImageType(SPIRVTypeImage *T) = 0;
258258
virtual SPIRVTypeInt *addIntegerType(unsigned) = 0;
259259
virtual SPIRVTypeOpaque *addOpaqueType(const std::string &) = 0;
260-
virtual SPIRVTypePointer *addPointerType(SPIRVStorageClassKind,
261-
SPIRVType *) = 0;
262-
virtual SPIRVTypeUntypedPointerKHR *
263-
addUntypedPointerKHRType(SPIRVStorageClassKind) = 0;
260+
virtual SPIRVType *addPointerType(SPIRVStorageClassKind, SPIRVType *) = 0;
264261
virtual SPIRVTypeStruct *openStructType(unsigned, const std::string &) = 0;
265262
virtual SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) = 0;
266263
virtual void closeStructType(SPIRVTypeStruct *, bool) = 0;

lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,9 @@ class SPIRVTypePointer : public SPIRVTypePointerBase<OpTypePointer, 4> {
323323
std::vector<SPIRVEntry *> getNonLiteralOperands() const override {
324324
return std::vector<SPIRVEntry *>(1, getEntry(ElemTypeId));
325325
}
326+
static bool classof(const SPIRVEntry *E) {
327+
return E->getOpCode() == OpTypePointer;
328+
}
326329

327330
protected:
328331
_SPIRV_DEF_ENCDEC3(Id, ElemStorageClass, ElemTypeId)

0 commit comments

Comments
 (0)