Skip to content

[SPIR-V] Improve implementation of the duplicates tracker's storage #95958

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 80 additions & 151 deletions llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,152 +52,87 @@ class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
void addDep(DTSortableEntry *E) { Deps.push_back(E); }
};

struct SpecialTypeDescriptor {
enum SpecialTypeKind {
STK_Empty = 0,
STK_Image,
STK_SampledImage,
STK_Sampler,
STK_Pipe,
STK_DeviceEvent,
STK_Pointer,
STK_Last = -1
};
SpecialTypeKind Kind;

unsigned Hash;

SpecialTypeDescriptor() = delete;
SpecialTypeDescriptor(SpecialTypeKind K) : Kind(K) { Hash = Kind; }

unsigned getHash() const { return Hash; }

virtual ~SpecialTypeDescriptor() {}
};

struct ImageTypeDescriptor : public SpecialTypeDescriptor {
union ImageAttrs {
struct BitFlags {
unsigned Dim : 3;
unsigned Depth : 2;
unsigned Arrayed : 1;
unsigned MS : 1;
unsigned Sampled : 2;
unsigned ImageFormat : 6;
unsigned AQ : 2;
} Flags;
unsigned Val;
};

ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
unsigned Arrayed, unsigned MS, unsigned Sampled,
unsigned ImageFormat, unsigned AQ = 0)
: SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
ImageAttrs Attrs;
Attrs.Val = 0;
Attrs.Flags.Dim = Dim;
Attrs.Flags.Depth = Depth;
Attrs.Flags.Arrayed = Arrayed;
Attrs.Flags.MS = MS;
Attrs.Flags.Sampled = Sampled;
Attrs.Flags.ImageFormat = ImageFormat;
Attrs.Flags.AQ = AQ;
Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^
((Attrs.Val << 8) | Kind);
}

static bool classof(const SpecialTypeDescriptor *TD) {
return TD->Kind == SpecialTypeKind::STK_Image;
}
};

struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
: SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
ImageTypeDescriptor TD(
SampledTy, ImageTy->getOperand(2).getImm(),
ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(),
ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(),
ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm());
Hash = TD.getHash() ^ Kind;
}

static bool classof(const SpecialTypeDescriptor *TD) {
return TD->Kind == SpecialTypeKind::STK_SampledImage;
}
};

struct SamplerTypeDescriptor : public SpecialTypeDescriptor {
SamplerTypeDescriptor()
: SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) {
Hash = Kind;
}

static bool classof(const SpecialTypeDescriptor *TD) {
return TD->Kind == SpecialTypeKind::STK_Sampler;
}
enum SpecialTypeKind {
STK_Empty = 0,
STK_Image,
STK_SampledImage,
STK_Sampler,
STK_Pipe,
STK_DeviceEvent,
STK_Pointer,
STK_Last = -1
};

struct PipeTypeDescriptor : public SpecialTypeDescriptor {

PipeTypeDescriptor(uint8_t AQ)
: SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) {
Hash = (AQ << 8) | Kind;
}

static bool classof(const SpecialTypeDescriptor *TD) {
return TD->Kind == SpecialTypeKind::STK_Pipe;
using SpecialTypeDescriptor = std::tuple<const Type *, unsigned, unsigned>;

union ImageAttrs {
struct BitFlags {
unsigned Dim : 3;
unsigned Depth : 2;
unsigned Arrayed : 1;
unsigned MS : 1;
unsigned Sampled : 2;
unsigned ImageFormat : 6;
unsigned AQ : 2;
} Flags;
unsigned Val;

ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0) {
Val = 0;
Flags.Dim = Dim;
Flags.Depth = Depth;
Flags.Arrayed = Arrayed;
Flags.MS = MS;
Flags.Sampled = Sampled;
Flags.ImageFormat = ImageFormat;
Flags.AQ = AQ;
}
};

struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor {

DeviceEventTypeDescriptor()
: SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent) {
Hash = Kind;
}

static bool classof(const SpecialTypeDescriptor *TD) {
return TD->Kind == SpecialTypeKind::STK_DeviceEvent;
}
};

struct PointerTypeDescriptor : public SpecialTypeDescriptor {
const Type *ElementType;
unsigned AddressSpace;

PointerTypeDescriptor() = delete;
PointerTypeDescriptor(const Type *ElementType, unsigned AddressSpace)
: SpecialTypeDescriptor(SpecialTypeKind::STK_Pointer),
ElementType(ElementType), AddressSpace(AddressSpace) {
Hash = (DenseMapInfo<Type *>().getHashValue(ElementType) & 0xffff) ^
((AddressSpace << 8) | Kind);
}

static bool classof(const SpecialTypeDescriptor *TD) {
return TD->Kind == SpecialTypeKind::STK_Pointer;
}
};
inline SpecialTypeDescriptor
make_descr_image(const Type *SampledTy, unsigned Dim, unsigned Depth,
unsigned Arrayed, unsigned MS, unsigned Sampled,
unsigned ImageFormat, unsigned AQ = 0) {
return std::make_tuple(
SampledTy,
ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val,
SpecialTypeKind::STK_Image);
}

inline SpecialTypeDescriptor
make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
return std::make_tuple(
SampledTy,
ImageAttrs(
ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(),
ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(),
ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(),
ImageTy->getOperand(8).getImm())
.Val,
SpecialTypeKind::STK_SampledImage);
}

inline SpecialTypeDescriptor make_descr_sampler() {
return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
}

inline SpecialTypeDescriptor make_descr_pipe(uint8_t AQ) {
return std::make_tuple(nullptr, AQ, SpecialTypeKind::STK_Pipe);
}

inline SpecialTypeDescriptor make_descr_event() {
return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent);
}

inline SpecialTypeDescriptor make_descr_pointee(const Type *ElementType,
unsigned AddressSpace) {
return std::make_tuple(ElementType, AddressSpace,
SpecialTypeKind::STK_Pointer);
}
} // namespace SPIRV

template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> {
static inline SPIRV::SpecialTypeDescriptor getEmptyKey() {
return SPIRV::SpecialTypeDescriptor(
SPIRV::SpecialTypeDescriptor::STK_Empty);
}
static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() {
return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last);
}
static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) {
return Val.getHash();
}
static bool isEqual(SPIRV::SpecialTypeDescriptor LHS,
SPIRV::SpecialTypeDescriptor RHS) {
return getHashValue(LHS) == getHashValue(RHS);
}
};

template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
public:
// NOTE: using MapVector instead of DenseMap helps getting everything ordered
Expand Down Expand Up @@ -283,16 +218,13 @@ class SPIRVGeneralDuplicatesTracker {
MachineModuleInfo *MMI);

void add(const Type *Ty, const MachineFunction *MF, Register R) {
TT.add(Ty, MF, R);
TT.add(unifyPtrType(Ty), MF, R);
}

void add(const Type *PointeeTy, unsigned AddressSpace,
const MachineFunction *MF, Register R) {
if (isUntypedPointerTy(PointeeTy))
PointeeTy =
TypedPointerType::get(IntegerType::getInt8Ty(PointeeTy->getContext()),
getPointerAddressSpace(PointeeTy));
ST.add(SPIRV::PointerTypeDescriptor(PointeeTy, AddressSpace), MF, R);
ST.add(SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF,
R);
}

void add(const Constant *C, const MachineFunction *MF, Register R) {
Expand Down Expand Up @@ -321,16 +253,13 @@ class SPIRVGeneralDuplicatesTracker {
}

Register find(const Type *Ty, const MachineFunction *MF) {
return TT.find(const_cast<Type *>(Ty), MF);
return TT.find(unifyPtrType(Ty), MF);
}

Register find(const Type *PointeeTy, unsigned AddressSpace,
const MachineFunction *MF) {
if (isUntypedPointerTy(PointeeTy))
PointeeTy =
TypedPointerType::get(IntegerType::getInt8Ty(PointeeTy->getContext()),
getPointerAddressSpace(PointeeTy));
return ST.find(SPIRV::PointerTypeDescriptor(PointeeTy, AddressSpace), MF);
return ST.find(
SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF);
}

Register find(const Constant *C, const MachineFunction *MF) {
Expand Down
18 changes: 9 additions & 9 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
TypesInProcessing.erase(Ty);
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
SPIRVToLLVMType[SpirvType] = Ty;
SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
// Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
// will be added later. For special types it is already added to DT.
Expand Down Expand Up @@ -1122,9 +1122,9 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
SPIRV::ImageFormat::ImageFormat ImageFormat,
SPIRV::AccessQualifier::AccessQualifier AccessQual) {
SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth,
Arrayed, Multisampled, Sampled, ImageFormat,
AccessQual);
auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim,
Depth, Arrayed, Multisampled, Sampled,
ImageFormat, AccessQual);
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
Expand All @@ -1143,7 +1143,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(

SPIRVType *
SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
SPIRV::SamplerTypeDescriptor TD;
auto TD = SPIRV::make_descr_sampler();
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
Expand All @@ -1154,7 +1154,7 @@ SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual) {
SPIRV::PipeTypeDescriptor TD(AccessQual);
auto TD = SPIRV::make_descr_pipe(AccessQual);
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
Expand All @@ -1166,7 +1166,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(

SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
MachineIRBuilder &MIRBuilder) {
SPIRV::DeviceEventTypeDescriptor TD;
auto TD = SPIRV::make_descr_event();
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
Expand All @@ -1176,7 +1176,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(

SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
SPIRV::SampledImageTypeDescriptor TD(
auto TD = SPIRV::make_descr_sampled_image(
SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
ImageType->getOperand(1).getReg())),
ImageType);
Expand Down Expand Up @@ -1268,7 +1268,7 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
SPIRVType *SpirvType) {
assert(CurMF == SpirvType->getMF());
VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
SPIRVToLLVMType[SpirvType] = LLVMTy;
SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
return SpirvType;
}

Expand Down
24 changes: 24 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,29 @@ inline Type *toTypedPointer(Type *Ty) {
: Ty;
}

inline Type *toTypedFunPointer(FunctionType *FTy) {
Type *OrigRetTy = FTy->getReturnType();
Type *RetTy = toTypedPointer(OrigRetTy);
bool IsUntypedPtr = false;
for (Type *PTy : FTy->params()) {
if (isUntypedPointerTy(PTy)) {
IsUntypedPtr = true;
break;
}
}
if (!IsUntypedPtr && RetTy == OrigRetTy)
return FTy;
SmallVector<Type *> ParamTys;
for (Type *PTy : FTy->params())
ParamTys.push_back(toTypedPointer(PTy));
return FunctionType::get(RetTy, ParamTys, FTy->isVarArg());
}

inline const Type *unifyPtrType(const Type *Ty) {
if (auto FTy = dyn_cast<FunctionType>(Ty))
return toTypedFunPointer(const_cast<FunctionType *>(FTy));
return toTypedPointer(const_cast<Type *>(Ty));
}

} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
Loading