Skip to content

Commit 0daeed6

Browse files
[SPIR-V] Improve implementation of the duplicates tracker's storage (#95958)
This PR continues #94952, managing FunctionType in the same way as a pointee types in #94952 (that is working with TypedPointers pointee types rather than with original llvm's untyped pointers). This PR also fully reworks the base type for the duplicates tracker's storage to conform with and reuse DenseMapInfo. Previous implementation didn't store enough info to differ between key values (see isEqual() implemented as equality of derived from arguments hash values). This, in turn, led to random crashes in very rare occasions when hash value of an actual key matched hash values of empty and tombstone instances. In this PR we use std::tuple instead of a tailor-made class hierarchy, both reusing DenseMapInfo templates and getting rid of the crash condition.
1 parent 0dd4377 commit 0daeed6

File tree

3 files changed

+113
-160
lines changed

3 files changed

+113
-160
lines changed

llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h

Lines changed: 80 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -52,152 +52,87 @@ class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
5252
void addDep(DTSortableEntry *E) { Deps.push_back(E); }
5353
};
5454

55-
struct SpecialTypeDescriptor {
56-
enum SpecialTypeKind {
57-
STK_Empty = 0,
58-
STK_Image,
59-
STK_SampledImage,
60-
STK_Sampler,
61-
STK_Pipe,
62-
STK_DeviceEvent,
63-
STK_Pointer,
64-
STK_Last = -1
65-
};
66-
SpecialTypeKind Kind;
67-
68-
unsigned Hash;
69-
70-
SpecialTypeDescriptor() = delete;
71-
SpecialTypeDescriptor(SpecialTypeKind K) : Kind(K) { Hash = Kind; }
72-
73-
unsigned getHash() const { return Hash; }
74-
75-
virtual ~SpecialTypeDescriptor() {}
76-
};
77-
78-
struct ImageTypeDescriptor : public SpecialTypeDescriptor {
79-
union ImageAttrs {
80-
struct BitFlags {
81-
unsigned Dim : 3;
82-
unsigned Depth : 2;
83-
unsigned Arrayed : 1;
84-
unsigned MS : 1;
85-
unsigned Sampled : 2;
86-
unsigned ImageFormat : 6;
87-
unsigned AQ : 2;
88-
} Flags;
89-
unsigned Val;
90-
};
91-
92-
ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
93-
unsigned Arrayed, unsigned MS, unsigned Sampled,
94-
unsigned ImageFormat, unsigned AQ = 0)
95-
: SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
96-
ImageAttrs Attrs;
97-
Attrs.Val = 0;
98-
Attrs.Flags.Dim = Dim;
99-
Attrs.Flags.Depth = Depth;
100-
Attrs.Flags.Arrayed = Arrayed;
101-
Attrs.Flags.MS = MS;
102-
Attrs.Flags.Sampled = Sampled;
103-
Attrs.Flags.ImageFormat = ImageFormat;
104-
Attrs.Flags.AQ = AQ;
105-
Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^
106-
((Attrs.Val << 8) | Kind);
107-
}
108-
109-
static bool classof(const SpecialTypeDescriptor *TD) {
110-
return TD->Kind == SpecialTypeKind::STK_Image;
111-
}
112-
};
113-
114-
struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
115-
SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
116-
: SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
117-
assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
118-
ImageTypeDescriptor TD(
119-
SampledTy, ImageTy->getOperand(2).getImm(),
120-
ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(),
121-
ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(),
122-
ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm());
123-
Hash = TD.getHash() ^ Kind;
124-
}
125-
126-
static bool classof(const SpecialTypeDescriptor *TD) {
127-
return TD->Kind == SpecialTypeKind::STK_SampledImage;
128-
}
129-
};
130-
131-
struct SamplerTypeDescriptor : public SpecialTypeDescriptor {
132-
SamplerTypeDescriptor()
133-
: SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) {
134-
Hash = Kind;
135-
}
136-
137-
static bool classof(const SpecialTypeDescriptor *TD) {
138-
return TD->Kind == SpecialTypeKind::STK_Sampler;
139-
}
55+
enum SpecialTypeKind {
56+
STK_Empty = 0,
57+
STK_Image,
58+
STK_SampledImage,
59+
STK_Sampler,
60+
STK_Pipe,
61+
STK_DeviceEvent,
62+
STK_Pointer,
63+
STK_Last = -1
14064
};
14165

142-
struct PipeTypeDescriptor : public SpecialTypeDescriptor {
143-
144-
PipeTypeDescriptor(uint8_t AQ)
145-
: SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) {
146-
Hash = (AQ << 8) | Kind;
147-
}
148-
149-
static bool classof(const SpecialTypeDescriptor *TD) {
150-
return TD->Kind == SpecialTypeKind::STK_Pipe;
66+
using SpecialTypeDescriptor = std::tuple<const Type *, unsigned, unsigned>;
67+
68+
union ImageAttrs {
69+
struct BitFlags {
70+
unsigned Dim : 3;
71+
unsigned Depth : 2;
72+
unsigned Arrayed : 1;
73+
unsigned MS : 1;
74+
unsigned Sampled : 2;
75+
unsigned ImageFormat : 6;
76+
unsigned AQ : 2;
77+
} Flags;
78+
unsigned Val;
79+
80+
ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
81+
unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0) {
82+
Val = 0;
83+
Flags.Dim = Dim;
84+
Flags.Depth = Depth;
85+
Flags.Arrayed = Arrayed;
86+
Flags.MS = MS;
87+
Flags.Sampled = Sampled;
88+
Flags.ImageFormat = ImageFormat;
89+
Flags.AQ = AQ;
15190
}
15291
};
15392

154-
struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor {
155-
156-
DeviceEventTypeDescriptor()
157-
: SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent) {
158-
Hash = Kind;
159-
}
160-
161-
static bool classof(const SpecialTypeDescriptor *TD) {
162-
return TD->Kind == SpecialTypeKind::STK_DeviceEvent;
163-
}
164-
};
165-
166-
struct PointerTypeDescriptor : public SpecialTypeDescriptor {
167-
const Type *ElementType;
168-
unsigned AddressSpace;
169-
170-
PointerTypeDescriptor() = delete;
171-
PointerTypeDescriptor(const Type *ElementType, unsigned AddressSpace)
172-
: SpecialTypeDescriptor(SpecialTypeKind::STK_Pointer),
173-
ElementType(ElementType), AddressSpace(AddressSpace) {
174-
Hash = (DenseMapInfo<Type *>().getHashValue(ElementType) & 0xffff) ^
175-
((AddressSpace << 8) | Kind);
176-
}
177-
178-
static bool classof(const SpecialTypeDescriptor *TD) {
179-
return TD->Kind == SpecialTypeKind::STK_Pointer;
180-
}
181-
};
93+
inline SpecialTypeDescriptor
94+
make_descr_image(const Type *SampledTy, unsigned Dim, unsigned Depth,
95+
unsigned Arrayed, unsigned MS, unsigned Sampled,
96+
unsigned ImageFormat, unsigned AQ = 0) {
97+
return std::make_tuple(
98+
SampledTy,
99+
ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val,
100+
SpecialTypeKind::STK_Image);
101+
}
102+
103+
inline SpecialTypeDescriptor
104+
make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
105+
assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
106+
return std::make_tuple(
107+
SampledTy,
108+
ImageAttrs(
109+
ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(),
110+
ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(),
111+
ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(),
112+
ImageTy->getOperand(8).getImm())
113+
.Val,
114+
SpecialTypeKind::STK_SampledImage);
115+
}
116+
117+
inline SpecialTypeDescriptor make_descr_sampler() {
118+
return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
119+
}
120+
121+
inline SpecialTypeDescriptor make_descr_pipe(uint8_t AQ) {
122+
return std::make_tuple(nullptr, AQ, SpecialTypeKind::STK_Pipe);
123+
}
124+
125+
inline SpecialTypeDescriptor make_descr_event() {
126+
return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent);
127+
}
128+
129+
inline SpecialTypeDescriptor make_descr_pointee(const Type *ElementType,
130+
unsigned AddressSpace) {
131+
return std::make_tuple(ElementType, AddressSpace,
132+
SpecialTypeKind::STK_Pointer);
133+
}
182134
} // namespace SPIRV
183135

184-
template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> {
185-
static inline SPIRV::SpecialTypeDescriptor getEmptyKey() {
186-
return SPIRV::SpecialTypeDescriptor(
187-
SPIRV::SpecialTypeDescriptor::STK_Empty);
188-
}
189-
static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() {
190-
return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last);
191-
}
192-
static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) {
193-
return Val.getHash();
194-
}
195-
static bool isEqual(SPIRV::SpecialTypeDescriptor LHS,
196-
SPIRV::SpecialTypeDescriptor RHS) {
197-
return getHashValue(LHS) == getHashValue(RHS);
198-
}
199-
};
200-
201136
template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
202137
public:
203138
// NOTE: using MapVector instead of DenseMap helps getting everything ordered
@@ -283,16 +218,13 @@ class SPIRVGeneralDuplicatesTracker {
283218
MachineModuleInfo *MMI);
284219

285220
void add(const Type *Ty, const MachineFunction *MF, Register R) {
286-
TT.add(Ty, MF, R);
221+
TT.add(unifyPtrType(Ty), MF, R);
287222
}
288223

289224
void add(const Type *PointeeTy, unsigned AddressSpace,
290225
const MachineFunction *MF, Register R) {
291-
if (isUntypedPointerTy(PointeeTy))
292-
PointeeTy =
293-
TypedPointerType::get(IntegerType::getInt8Ty(PointeeTy->getContext()),
294-
getPointerAddressSpace(PointeeTy));
295-
ST.add(SPIRV::PointerTypeDescriptor(PointeeTy, AddressSpace), MF, R);
226+
ST.add(SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF,
227+
R);
296228
}
297229

298230
void add(const Constant *C, const MachineFunction *MF, Register R) {
@@ -321,16 +253,13 @@ class SPIRVGeneralDuplicatesTracker {
321253
}
322254

323255
Register find(const Type *Ty, const MachineFunction *MF) {
324-
return TT.find(const_cast<Type *>(Ty), MF);
256+
return TT.find(unifyPtrType(Ty), MF);
325257
}
326258

327259
Register find(const Type *PointeeTy, unsigned AddressSpace,
328260
const MachineFunction *MF) {
329-
if (isUntypedPointerTy(PointeeTy))
330-
PointeeTy =
331-
TypedPointerType::get(IntegerType::getInt8Ty(PointeeTy->getContext()),
332-
getPointerAddressSpace(PointeeTy));
333-
return ST.find(SPIRV::PointerTypeDescriptor(PointeeTy, AddressSpace), MF);
261+
return ST.find(
262+
SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF);
334263
}
335264

336265
Register find(const Constant *C, const MachineFunction *MF) {

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -936,7 +936,7 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
936936
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
937937
TypesInProcessing.erase(Ty);
938938
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
939-
SPIRVToLLVMType[SpirvType] = Ty;
939+
SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
940940
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
941941
// Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
942942
// will be added later. For special types it is already added to DT.
@@ -1122,9 +1122,9 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
11221122
uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
11231123
SPIRV::ImageFormat::ImageFormat ImageFormat,
11241124
SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1125-
SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth,
1126-
Arrayed, Multisampled, Sampled, ImageFormat,
1127-
AccessQual);
1125+
auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim,
1126+
Depth, Arrayed, Multisampled, Sampled,
1127+
ImageFormat, AccessQual);
11281128
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
11291129
return Res;
11301130
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1143,7 +1143,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
11431143

11441144
SPIRVType *
11451145
SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
1146-
SPIRV::SamplerTypeDescriptor TD;
1146+
auto TD = SPIRV::make_descr_sampler();
11471147
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
11481148
return Res;
11491149
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1154,7 +1154,7 @@ SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
11541154
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
11551155
MachineIRBuilder &MIRBuilder,
11561156
SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1157-
SPIRV::PipeTypeDescriptor TD(AccessQual);
1157+
auto TD = SPIRV::make_descr_pipe(AccessQual);
11581158
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
11591159
return Res;
11601160
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1166,7 +1166,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
11661166

11671167
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
11681168
MachineIRBuilder &MIRBuilder) {
1169-
SPIRV::DeviceEventTypeDescriptor TD;
1169+
auto TD = SPIRV::make_descr_event();
11701170
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
11711171
return Res;
11721172
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1176,7 +1176,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
11761176

11771177
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
11781178
SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
1179-
SPIRV::SampledImageTypeDescriptor TD(
1179+
auto TD = SPIRV::make_descr_sampled_image(
11801180
SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
11811181
ImageType->getOperand(1).getReg())),
11821182
ImageType);
@@ -1268,7 +1268,7 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
12681268
SPIRVType *SpirvType) {
12691269
assert(CurMF == SpirvType->getMF());
12701270
VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
1271-
SPIRVToLLVMType[SpirvType] = LLVMTy;
1271+
SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
12721272
return SpirvType;
12731273
}
12741274

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,5 +160,29 @@ inline Type *toTypedPointer(Type *Ty) {
160160
: Ty;
161161
}
162162

163+
inline Type *toTypedFunPointer(FunctionType *FTy) {
164+
Type *OrigRetTy = FTy->getReturnType();
165+
Type *RetTy = toTypedPointer(OrigRetTy);
166+
bool IsUntypedPtr = false;
167+
for (Type *PTy : FTy->params()) {
168+
if (isUntypedPointerTy(PTy)) {
169+
IsUntypedPtr = true;
170+
break;
171+
}
172+
}
173+
if (!IsUntypedPtr && RetTy == OrigRetTy)
174+
return FTy;
175+
SmallVector<Type *> ParamTys;
176+
for (Type *PTy : FTy->params())
177+
ParamTys.push_back(toTypedPointer(PTy));
178+
return FunctionType::get(RetTy, ParamTys, FTy->isVarArg());
179+
}
180+
181+
inline const Type *unifyPtrType(const Type *Ty) {
182+
if (auto FTy = dyn_cast<FunctionType>(Ty))
183+
return toTypedFunPointer(const_cast<FunctionType *>(FTy));
184+
return toTypedPointer(const_cast<Type *>(Ty));
185+
}
186+
163187
} // namespace llvm
164188
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H

0 commit comments

Comments
 (0)