Skip to content

[SPIRV] Add explicit layout #135789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 188 additions & 105 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Large diffs are not rendered by default.

40 changes: 31 additions & 9 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ,
bool EmitIR);
bool ExplicitLayoutRequired, bool EmitIR);
SPIRVType *findSPIRVType(const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier accessQual,
bool EmitIR);
bool ExplicitLayoutRequired, bool EmitIR);
SPIRVType *
restOfCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual,
bool EmitIR);
bool ExplicitLayoutRequired, bool EmitIR);

// Internal function creating the an OpType at the correct position in the
// function by tweaking the passed "MIRBuilder" insertion point and restoring
Expand Down Expand Up @@ -298,10 +298,19 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// EmitIR controls if we emit GMIR or SPV constants (e.g. for array sizes)
// because this method may be called from InstructionSelector and we don't
// want to emit extra IR instructions there.
SPIRVType *getOrCreateSPIRVType(const Type *Type, MachineInstr &I,
SPIRV::AccessQualifier::AccessQualifier AQ,
bool EmitIR) {
MachineIRBuilder MIRBuilder(I);
return getOrCreateSPIRVType(Type, MIRBuilder, AQ, EmitIR);
}

SPIRVType *getOrCreateSPIRVType(const Type *Type,
MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ,
bool EmitIR);
bool EmitIR) {
return getOrCreateSPIRVType(Type, MIRBuilder, AQ, false, EmitIR);
}

const Type *getTypeForSPIRVType(const SPIRVType *Ty) const {
auto Res = SPIRVToLLVMType.find(Ty);
Expand Down Expand Up @@ -364,6 +373,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;

// Returns true if `Type` is a resource type. This could be an image type
// or a struct for a buffer decorated with the block decoration.
bool isResourceType(SPIRVType *Type) const;

// Return number of elements in a vector if the argument is associated with
// a vector type. Return 1 for a scalar type, and 0 for a missing type.
unsigned getScalarOrVectorComponentCount(Register VReg) const;
Expand Down Expand Up @@ -414,6 +427,11 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
const Type *adjustIntTypeByWidth(const Type *Ty) const;
unsigned adjustOpTypeIntWidth(unsigned Width) const;

SPIRVType *getOrCreateSPIRVType(const Type *Type,
MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ,
bool ExplicitLayoutRequired, bool EmitIR);

SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);

Expand All @@ -425,14 +443,15 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
MachineIRBuilder &MIRBuilder);

SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder, bool EmitIR);
MachineIRBuilder &MIRBuilder,
bool ExplicitLayoutRequired, bool EmitIR);

SPIRVType *getOpTypeOpaque(const StructType *Ty,
MachineIRBuilder &MIRBuilder);

SPIRVType *getOpTypeStruct(const StructType *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual,
bool EmitIR);
bool ExplicitLayoutRequired, bool EmitIR);

SPIRVType *getOpTypePointer(SPIRV::StorageClass::StorageClass SC,
SPIRVType *ElemType, MachineIRBuilder &MIRBuilder,
Expand Down Expand Up @@ -475,6 +494,12 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC);

void addStructOffsetDecorations(Register Reg, StructType *Ty,
MachineIRBuilder &MIRBuilder);
void addArrayStrideDecorations(Register Reg, Type *ElementType,
MachineIRBuilder &MIRBuilder);
bool hasBlockDecoration(SPIRVType *Type) const;

public:
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
Expand Down Expand Up @@ -545,9 +570,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
unsigned NumElements, MachineInstr &I,
const SPIRVInstrInfo &TII);
SPIRVType *getOrCreateSPIRVArrayType(SPIRVType *BaseType,
unsigned NumElements, MachineInstr &I,
const SPIRVInstrInfo &TII);

// Returns a pointer to a SPIR-V pointer type with the given base type and
// storage class. The base type will be translated to a SPIR-V type, and the
Expand Down
55 changes: 50 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVIRMapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ enum SpecialTypeKind {
STK_Value,
STK_MachineInstr,
STK_VkBuffer,
STK_ExplictLayoutType,
STK_Last = -1
};

Expand Down Expand Up @@ -150,6 +151,11 @@ inline IRHandle irhandle_vkbuffer(const Type *ElementType,
SpecialTypeKind::STK_VkBuffer);
}

inline IRHandle irhandle_explict_layout_type(const Type *Ty) {
const Type *WrpTy = unifyPtrType(Ty);
return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_ExplictLayoutType);
}

inline IRHandle handle(const Type *Ty) {
const Type *WrpTy = unifyPtrType(Ty);
return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_Type);
Expand All @@ -163,6 +169,10 @@ inline IRHandle handle(const MachineInstr *KeyMI) {
return irhandle_ptr(KeyMI, SPIRV::to_hash(KeyMI), STK_MachineInstr);
}

inline bool type_has_layout_decoration(const Type *T) {
return (isa<StructType>(T) || isa<ArrayType>(T));
}

} // namespace SPIRV

// Bi-directional mappings between LLVM entities and (v-reg, machine function)
Expand Down Expand Up @@ -238,14 +248,49 @@ class SPIRVIRMapping {
return findMI(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MF);
}

template <typename T> bool add(const T *Obj, const MachineInstr *MI) {
bool add(const Value *V, const MachineInstr *MI) {
return add(SPIRV::handle(V), MI);
}

bool add(const Type *T, bool RequiresExplicitLayout, const MachineInstr *MI) {
if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) {
return add(SPIRV::irhandle_explict_layout_type(T), MI);
}
return add(SPIRV::handle(T), MI);
}

bool add(const MachineInstr *Obj, const MachineInstr *MI) {
return add(SPIRV::handle(Obj), MI);
}
template <typename T> Register find(const T *Obj, const MachineFunction *MF) {
return find(SPIRV::handle(Obj), MF);

Register find(const Value *V, const MachineFunction *MF) {
return find(SPIRV::handle(V), MF);
}

Register find(const Type *T, bool RequiresExplicitLayout,
const MachineFunction *MF) {
if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T))
return find(SPIRV::irhandle_explict_layout_type(T), MF);
return find(SPIRV::handle(T), MF);
}

Register find(const MachineInstr *MI, const MachineFunction *MF) {
return find(SPIRV::handle(MI), MF);
}

const MachineInstr *findMI(const Value *Obj, const MachineFunction *MF) {
return findMI(SPIRV::handle(Obj), MF);
}

const MachineInstr *findMI(const Type *T, bool RequiresExplicitLayout,
const MachineFunction *MF) {
if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T))
return findMI(SPIRV::irhandle_explict_layout_type(T), MF);
return findMI(SPIRV::handle(T), MF);
}
template <typename T>
const MachineInstr *findMI(const T *Obj, const MachineFunction *MF) {

const MachineInstr *findMI(const MachineInstr *Obj,
const MachineFunction *MF) {
return findMI(SPIRV::handle(Obj), MF);
}
};
Expand Down
94 changes: 94 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,42 @@

using namespace llvm;

// Returns true of the types logically match, as defined in
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2,
SPIRVGlobalRegistry &GR) {
if (Ty1->getOpcode() != Ty2->getOpcode())
return false;

if (Ty1->getNumOperands() != Ty2->getNumOperands())
return false;

if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
// Array must have the same size.
if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg())
return false;

SPIRVType *ElemType1 = GR.getSPIRVTypeForVReg(Ty1->getOperand(1).getReg());
SPIRVType *ElemType2 = GR.getSPIRVTypeForVReg(Ty2->getOperand(1).getReg());
return ElemType1 == ElemType2 ||
typesLogicallyMatch(ElemType1, ElemType2, GR);
}

if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
SPIRVType *ElemType1 =
GR.getSPIRVTypeForVReg(Ty1->getOperand(I).getReg());
SPIRVType *ElemType2 =
GR.getSPIRVTypeForVReg(Ty2->getOperand(I).getReg());
if (ElemType1 != ElemType2 &&
!typesLogicallyMatch(ElemType1, ElemType2, GR))
return false;
}
return true;
}
return false;
}

unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
// This code avoids CallLowering fail inside getVectorTypeBreakdown
Expand Down Expand Up @@ -374,6 +410,9 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
// implies that %Op is a pointer to <ResType>
case SPIRV::OpLoad:
// OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
if (enforcePtrTypeCompatibility(MI, 2, 0))
break;

validatePtrTypes(STI, MRI, GR, MI, 2,
GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
break;
Expand Down Expand Up @@ -531,3 +570,58 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
ProcessedMF.insert(&MF);
TargetLowering::finalizeLowering(MF);
}

// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
// PtrOpIdx matches the type for operand OpIdx. Returns true if they already
// match or if the instruction was modified to make them match.
bool SPIRVTargetLowering::enforcePtrTypeCompatibility(
MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
SPIRVType *PtrType = GR.getResultType(I.getOperand(PtrOpIdx).getReg());
SPIRVType *PointeeType = GR.getPointeeType(PtrType);
SPIRVType *OpType = GR.getResultType(I.getOperand(OpIdx).getReg());

if (PointeeType == OpType)
return true;

if (typesLogicallyMatch(PointeeType, OpType, GR)) {
// Apply OpCopyLogical to OpIdx.
if (I.getOperand(OpIdx).isDef() &&
insertLogicalCopyOnResult(I, PointeeType)) {
return true;
}

llvm_unreachable("Unable to add OpCopyLogical yet.");
return false;
}

return false;
}

bool SPIRVTargetLowering::insertLogicalCopyOnResult(
MachineInstr &I, SPIRVType *NewResultType) const {
MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();

Register NewResultReg =
createVirtualRegister(NewResultType, &GR, MRI, *I.getMF());
Register NewTypeReg = GR.getSPIRVTypeID(NewResultType);

assert(std::distance(I.defs().begin(), I.defs().end()) == 1 &&
"Expected only one def");
MachineOperand &OldResult = *I.defs().begin();
Register OldResultReg = OldResult.getReg();
MachineOperand &OldType = *I.uses().begin();
Register OldTypeReg = OldType.getReg();

OldResult.setReg(NewResultReg);
OldType.setReg(NewTypeReg);

MachineIRBuilder MIB(*I.getNextNode());
return MIB.buildInstr(SPIRV::OpCopyLogical)
.addDef(OldResultReg)
.addUse(OldTypeReg)
.addUse(NewResultReg)
.constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
*STI.getRegBankInfo());
}
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ class SPIRVTargetLowering : public TargetLowering {
EVT ConditionVT) const override {
return ConditionVT.getSimpleVT();
}

bool enforcePtrTypeCompatibility(MachineInstr &I, unsigned PtrOpIdx,
unsigned OpIdx) const;
bool insertLogicalCopyOnResult(MachineInstr &I,
SPIRVType *NewResultType) const;
};
} // namespace llvm

Expand Down
11 changes: 6 additions & 5 deletions llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ declare target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handle

; CHECK: OpDecorate [[BufferVar:%.+]] DescriptorSet 0
; CHECK: OpDecorate [[BufferVar]] Binding 0
; CHECK: OpDecorate [[BufferType:%.+]] Block
; CHECK: OpMemberDecorate [[BufferType]] 0 Offset 0
; CHECK: OpMemberDecorate [[BufferType:%.+]] 0 Offset 0
; CHECK: OpDecorate [[BufferType]] Block
; CHECK: OpMemberDecorate [[BufferType]] 0 NonWritable
; CHECK: OpDecorate [[RWBufferVar:%.+]] DescriptorSet 0
; CHECK: OpDecorate [[RWBufferVar]] Binding 1
; CHECK: OpDecorate [[RWBufferType:%.+]] Block
; CHECK: OpMemberDecorate [[RWBufferType]] 0 Offset 0
; CHECK: OpDecorate [[ArrayType:%.+]] ArrayStride 4
; CHECK: OpMemberDecorate [[RWBufferType:%.+]] 0 Offset 0
; CHECK: OpDecorate [[RWBufferType]] Block


; CHECK: [[int:%[0-9]+]] = OpTypeInt 32 0
; CHECK: [[ArrayType:%.+]] = OpTypeRuntimeArray
; CHECK: [[ArrayType]] = OpTypeRuntimeArray
; CHECK: [[RWBufferType]] = OpTypeStruct [[ArrayType]]
; CHECK: [[RWBufferPtrType:%.+]] = OpTypePointer StorageBuffer [[RWBufferType]]
; CHECK: [[BufferType]] = OpTypeStruct [[ArrayType]]
Expand Down
Loading