Skip to content

[SPIRV] Add spirv.VulkanBuffer types to the backend #133475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -243,19 +243,20 @@ using target extension types and are represented as follows:

.. table:: SPIR-V Opaque Types

================== ====================== ===========================================================================================
SPIR-V Type LLVM type name LLVM type arguments
================== ====================== ===========================================================================================
OpTypeImage ``spirv.Image`` sampled type, dimensionality, depth, arrayed, MS, sampled, image format, [access qualifier]
OpTypeSampler ``spirv.Sampler`` (none)
OpTypeSampledImage ``spirv.SampledImage`` sampled type, dimensionality, depth, arrayed, MS, sampled, image format, [access qualifier]
OpTypeEvent ``spirv.Event`` (none)
OpTypeDeviceEvent ``spirv.DeviceEvent`` (none)
OpTypeReserveId ``spirv.ReserveId`` (none)
OpTypeQueue ``spirv.Queue`` (none)
OpTypePipe ``spirv.Pipe`` access qualifier
OpTypePipeStorage ``spirv.PipeStorage`` (none)
================== ====================== ===========================================================================================
================== ======================= ===========================================================================================
SPIR-V Type LLVM type name LLVM type arguments
================== ======================= ===========================================================================================
OpTypeImage ``spirv.Image`` sampled type, dimensionality, depth, arrayed, MS, sampled, image format, [access qualifier]
OpTypeSampler ``spirv.Sampler`` (none)
OpTypeSampledImage ``spirv.SampledImage`` sampled type, dimensionality, depth, arrayed, MS, sampled, image format, [access qualifier]
OpTypeEvent ``spirv.Event`` (none)
OpTypeDeviceEvent ``spirv.DeviceEvent`` (none)
OpTypeReserveId ``spirv.ReserveId`` (none)
OpTypeQueue ``spirv.Queue`` (none)
OpTypePipe ``spirv.Pipe`` access qualifier
OpTypePipeStorage ``spirv.PipeStorage`` (none)
NA ``spirv.VulkanBuffer`` ElementType, StorageClass, IsWriteable
================== ======================= ===========================================================================================

All integer arguments take the same value as they do in their `corresponding
SPIR-V instruction <https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_type_declaration_instructions>`_.
Expand All @@ -266,6 +267,9 @@ parameters of its underlying image type, so that a sampled image for the
previous type has the representation
``target("spirv.SampledImage, void, 1, 1, 0, 0, 0, 0, 0)``.

See `wg-hlsl proposal 0018 <https://github.com/llvm/wg-hlsl/blob/main/proposals/0018-spirv-resource-representation.md>`_
for details on ``spirv.VulkanBuffer``.

.. _inline-spirv-types:

Inline SPIR-V Types
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
printRemainingVariableOps(MI, NumFixedOps, OS, false, true);
break;
}
case SPIRV::OpMemberDecorate:
printRemainingVariableOps(MI, NumFixedOps, OS);
break;
case SPIRV::OpExecutionMode:
case SPIRV::OpExecutionModeId:
case SPIRV::OpLoopMerge: {
Expand Down
18 changes: 18 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3090,6 +3090,22 @@ static SPIRVType *getInlineSpirvType(const TargetExtType *ExtensionType,
Operands);
}

static SPIRVType *getVulkanBufferType(const TargetExtType *ExtensionType,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
assert(ExtensionType->getNumTypeParameters() == 1 &&
"Vulkan buffers have exactly one type for the type of the buffer.");
assert(ExtensionType->getNumIntParameters() == 2 &&
"Vulkan buffer have 2 integer parameters: storage class and is "
"writable.");

auto *T = ExtensionType->getTypeParameter(0);
auto SC = static_cast<SPIRV::StorageClass::StorageClass>(
ExtensionType->getIntParameter(0));
bool IsWritable = ExtensionType->getIntParameter(1);
return GR->getOrCreateVulkanBufferType(MIRBuilder, T, SC, IsWritable);
}

namespace SPIRV {
TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
LLVMContext &Context) {
Expand Down Expand Up @@ -3165,6 +3181,8 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
SPIRVType *TargetType;
if (Name == "spirv.Type") {
TargetType = getInlineSpirvType(BuiltinType, MIRBuilder, GR);
} else if (Name == "spirv.VulkanBuffer") {
TargetType = getVulkanBufferType(BuiltinType, MIRBuilder, GR);
} else {
// Lookup the demangled builtin type in the TableGen records.
const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
Expand Down
23 changes: 16 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,13 +671,22 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(

auto *II = dyn_cast<IntrinsicInst>(I);
if (II && II->getIntrinsicID() == Intrinsic::spv_resource_getpointer) {
auto *ImageType = cast<TargetExtType>(II->getOperand(0)->getType());
assert(ImageType->getTargetExtName() == "spirv.Image");
(void)ImageType;
if (II->hasOneUse()) {
auto *U = *II->users().begin();
Ty = cast<Instruction>(U)->getAccessType();
assert(Ty && "Unable to get type for resource pointer.");
auto *HandleType = cast<TargetExtType>(II->getOperand(0)->getType());
if (HandleType->getTargetExtName() == "spirv.Image") {
if (II->hasOneUse()) {
auto *U = *II->users().begin();
Ty = cast<Instruction>(U)->getAccessType();
assert(Ty && "Unable to get type for resource pointer.");
}
} else if (HandleType->getTargetExtName() == "spirv.VulkanBuffer") {
// This call is supposed to index into an array
Ty = HandleType->getTypeParameter(0);
assert(Ty->isArrayTy() &&
"spv_resource_getpointer indexes into an array, so the type of "
"the buffer should be an array.");
Ty = Ty->getArrayElementType();
} else {
llvm_unreachable("Unknown handle type for spv_resource_getpointer.");
}
} else if (Function *CalledF = CI->getCalledFunction()) {
std::string DemangledName =
Expand Down
96 changes: 76 additions & 20 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,23 +767,25 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(

static std::string GetSpirvImageTypeName(const SPIRVType *Type,
MachineIRBuilder &MIRBuilder,
const std::string &Prefix);
const std::string &Prefix,
SPIRVGlobalRegistry &GR);

static std::string buildSpirvTypeName(const SPIRVType *Type,
MachineIRBuilder &MIRBuilder) {
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry &GR) {
switch (Type->getOpcode()) {
case SPIRV::OpTypeSampledImage: {
return GetSpirvImageTypeName(Type, MIRBuilder, "sampled_image_");
return GetSpirvImageTypeName(Type, MIRBuilder, "sampled_image_", GR);
}
case SPIRV::OpTypeImage: {
return GetSpirvImageTypeName(Type, MIRBuilder, "image_");
return GetSpirvImageTypeName(Type, MIRBuilder, "image_", GR);
}
case SPIRV::OpTypeArray: {
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register ElementTypeReg = Type->getOperand(1).getReg();
auto *ElementType = MRI->getUniqueVRegDef(ElementTypeReg);
uint32_t ArraySize = getArrayComponentCount(MRI, Type);
return (buildSpirvTypeName(ElementType, MIRBuilder) + Twine("[") +
return (buildSpirvTypeName(ElementType, MIRBuilder, GR) + Twine("[") +
Twine(ArraySize) + Twine("]"))
.str();
}
Expand All @@ -795,17 +797,35 @@ static std::string buildSpirvTypeName(const SPIRVType *Type,
if (Type->getOperand(2).getImm())
return ("i" + Twine(Type->getOperand(1).getImm())).str();
return ("u" + Twine(Type->getOperand(1).getImm())).str();
case SPIRV::OpTypePointer: {
uint32_t StorageClass = GR.getPointerStorageClass(Type);
SPIRVType *PointeeType = GR.getPointeeType(Type);
return ("p_" + Twine(StorageClass) + Twine("_") +
buildSpirvTypeName(PointeeType, MIRBuilder, GR))
.str();
}
case SPIRV::OpTypeStruct: {
std::string TypeName = "{";
for (uint32_t I = 2; I < Type->getNumOperands(); ++I) {
SPIRVType *MemberType =
GR.getSPIRVTypeForVReg(Type->getOperand(I).getReg());
TypeName = '_' + buildSpirvTypeName(MemberType, MIRBuilder, GR);
}
return TypeName + "}";
}
default:
llvm_unreachable("Trying to the the name of an unknown type.");
}
}

static std::string GetSpirvImageTypeName(const SPIRVType *Type,
MachineIRBuilder &MIRBuilder,
const std::string &Prefix) {
const std::string &Prefix,
SPIRVGlobalRegistry &GR) {
Register SampledTypeReg = Type->getOperand(1).getReg();
auto *SampledType = MIRBuilder.getMRI()->getUniqueVRegDef(SampledTypeReg);
std::string TypeName = Prefix + buildSpirvTypeName(SampledType, MIRBuilder);
std::string TypeName =
Prefix + buildSpirvTypeName(SampledType, MIRBuilder, GR);
for (uint32_t I = 2; I < Type->getNumOperands(); ++I) {
TypeName = (TypeName + '_' + Twine(Type->getOperand(I).getImm())).str();
}
Expand All @@ -815,20 +835,19 @@ static std::string GetSpirvImageTypeName(const SPIRVType *Type,
Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding(
const SPIRVType *VarType, uint32_t Set, uint32_t Binding,
MachineIRBuilder &MIRBuilder) {
SPIRVType *VarPointerTypeReg = getOrCreateSPIRVPointerType(
VarType, MIRBuilder, SPIRV::StorageClass::UniformConstant);
Register VarReg =
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);

// TODO: The name should come from the llvm-ir, but how that name will be
// passed from the HLSL to the backend has not been decided. Using this place
// holder for now.
std::string Name = ("__resource_" + buildSpirvTypeName(VarType, MIRBuilder) +
"_" + Twine(Set) + "_" + Twine(Binding))
.str();
buildGlobalVariable(VarReg, VarPointerTypeReg, Name, nullptr,
SPIRV::StorageClass::UniformConstant, nullptr, false,
false, SPIRV::LinkageType::Import, MIRBuilder, false);
std::string Name =
("__resource_" + buildSpirvTypeName(VarType, MIRBuilder, *this) + "_" +
Twine(Set) + "_" + Twine(Binding))
.str();
buildGlobalVariable(VarReg, VarType, Name, nullptr,
getPointerStorageClass(VarType), nullptr, false, false,
SPIRV::LinkageType::Import, MIRBuilder, false);

buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::DescriptorSet, {Set});
buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::Binding, {Binding});
Expand All @@ -842,13 +861,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
"Invalid array element type");
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
Register NumElementsVReg =
buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);

if (NumElems != 0) {
Register NumElementsVReg =
buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeArray)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(ElemType))
.addUse(NumElementsVReg);
});
}

return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeArray)
return MIRBuilder.buildInstr(SPIRV::OpTypeRuntimeArray)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(ElemType))
.addUse(NumElementsVReg);
.addUse(getSPIRVTypeID(ElemType));
});
}

Expand Down Expand Up @@ -1296,6 +1324,34 @@ SPIRVGlobalRegistry::getPointerStorageClass(const SPIRVType *Type) const {
Type->getOperand(1).getImm());
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanBufferType(
MachineIRBuilder &MIRBuilder, Type *ElemType,
SPIRV::StorageClass::StorageClass SC, bool IsWritable, bool EmitIr) {
auto Key = SPIRV::irhandle_vkbuffer(ElemType, SC, IsWritable);
if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
return MI;

// TODO(134119): The SPIRVType for `ElemType` will not have an explicit
// layout. This generates invalid SPIR-V.
auto *T = StructType::create(ElemType);
auto *BlockType =
getOrCreateSPIRVType(T, MIRBuilder, SPIRV::AccessQualifier::None, EmitIr);

buildOpDecorate(BlockType->defs().begin()->getReg(), MIRBuilder,
SPIRV::Decoration::Block, {});
buildOpMemberDecorate(BlockType->defs().begin()->getReg(), MIRBuilder,
SPIRV::Decoration::Offset, 0, {0});

if (!IsWritable) {
buildOpMemberDecorate(BlockType->defs().begin()->getReg(), MIRBuilder,
SPIRV::Decoration::NonWritable, 0, {});
}

SPIRVType *R = getOrCreateSPIRVPointerType(BlockType, MIRBuilder, SC);
add(Key, R);
return R;
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,11 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);

SPIRVType *getOrCreateVulkanBufferType(MachineIRBuilder &MIRBuilder,
Type *ElemType,
SPIRV::StorageClass::StorageClass SC,
bool IsWritable, bool EmitIr = false);

SPIRVType *
getOrCreateOpTypeImage(MachineIRBuilder &MIRBuilder, SPIRVType *SampledType,
SPIRV::Dim::Dim Dim, uint32_t Depth, uint32_t Arrayed,
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVIRMapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ enum SpecialTypeKind {
STK_Type,
STK_Value,
STK_MachineInstr,
STK_VkBuffer,
STK_Last = -1
};

Expand Down Expand Up @@ -142,6 +143,13 @@ inline IRHandle irhandle_ptr(const void *Ptr, unsigned Arg,
return std::make_tuple(Ptr, Arg, STK);
}

inline IRHandle irhandle_vkbuffer(const Type *ElementType,
StorageClass::StorageClass SC,
bool IsWriteable) {
return std::make_tuple(ElementType, (SC << 1) | IsWriteable,
SpecialTypeKind::STK_VkBuffer);
}

inline IRHandle handle(const Type *Ty) {
const Type *WrpTy = unifyPtrType(Ty);
return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_Type);
Expand Down
Loading