Skip to content

Commit 864a83d

Browse files
[SPIR-V] Add support for inline SPIR-V types (#125316)
Using HLSL's [Inline SPIR-V](https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html) features, users have the ability to use [`SpirvType`](https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html#types) to have fine-grained control over the SPIR-V representation of a type. As explained in the spec, this is useful because it enables vendors to author headers with types for their own extensions. As discussed in [Target Extension Types for Inline SPIR-V and Decorated Types](llvm/wg-hlsl#105), we would like to represent the HLSL SpirvType type using a 'spirv.Type' target extension type in LLVM IR. This pull request lowers that type to SPIR-V.
1 parent dd191d3 commit 864a83d

13 files changed

+324
-33
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,54 @@ parameters of its underlying image type, so that a sampled image for the
266266
previous type has the representation
267267
``target("spirv.SampledImage, void, 1, 1, 0, 0, 0, 0, 0)``.
268268

269+
.. _inline-spirv-types:
270+
271+
Inline SPIR-V Types
272+
-------------------
273+
274+
HLSL allows users to create types representing specific SPIR-V types, using ``vk::SpirvType`` and
275+
``vk::SpirvOpaqueType``. These are specified in the `Inline SPIR-V`_ proposal. They may be
276+
represented using target extension types:
277+
278+
.. _Inline SPIR-V: https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html#types
279+
280+
.. table:: Inline SPIR-V Types
281+
282+
========================== =================== =========================
283+
LLVM type name LLVM type arguments LLVM integer arguments
284+
========================== =================== =========================
285+
``spirv.Type`` SPIR-V operands opcode, size, alignment
286+
``spirv.IntegralConstant`` integral type value
287+
``spirv.Literal`` (none) value
288+
========================== =================== =========================
289+
290+
The operand arguments to ``spirv.Type`` may be either a ``spirv.IntegralConstant`` type,
291+
representing an ``OpConstant`` id operand, a ``spirv.Literal`` type, representing an immediate
292+
literal operand, or any other type, representing the id of that type as an operand.
293+
``spirv.IntegralConstant`` and ``spirv.Literal`` may not be used outside of this context.
294+
295+
For example, ``OpTypeArray`` (opcode 28) takes an id for the element type and an id for the element
296+
length, so an array of 16 integers could be declared as:
297+
298+
``target("spirv.Type", i32, target("spirv.IntegralConstant", i32, 16), 28, 64, 32)``
299+
300+
This will be lowered to:
301+
302+
``OpTypeArray %int %int_16``
303+
304+
``OpTypeVector`` takes an id for the component type and a literal for the component count, so a
305+
4-integer vector could be declared as:
306+
307+
``target("spirv.Type", i32, target("spirv.Literal", 4), 23, 16, 32)``
308+
309+
This will be lowered to:
310+
311+
``OpTypeVector %int 4``
312+
313+
See `Target Extension Types for Inline SPIR-V and Decorated Types`_ for further details.
314+
315+
.. _Target Extension Types for Inline SPIR-V and Decorated Types: https://github.com/llvm/wg-hlsl/blob/main/proposals/0017-inline-spirv-and-decorated-types.md
316+
269317
.. _spirv-intrinsics:
270318

271319
Target Intrinsics

llvm/lib/IR/Type.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,29 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
968968
if (Name == "spirv.Image")
969969
return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::CanBeGlobal,
970970
TargetExtType::CanBeLocal);
971+
if (Name == "spirv.Type") {
972+
assert(Ty->getNumIntParameters() == 3 &&
973+
"Wrong number of parameters for spirv.Type");
974+
975+
auto Size = Ty->getIntParameter(1);
976+
auto Alignment = Ty->getIntParameter(2);
977+
978+
llvm::Type *LayoutType = nullptr;
979+
if (Size > 0 && Alignment > 0) {
980+
LayoutType =
981+
ArrayType::get(Type::getIntNTy(C, Alignment), Size * 8 / Alignment);
982+
} else {
983+
// LLVM expects variables that can be allocated to have an alignment and
984+
// size. Default to using a 32-bit int as the layout type if none are
985+
// present.
986+
LayoutType = Type::getInt32Ty(C);
987+
}
988+
989+
return TargetTypeInfo(LayoutType, TargetExtType::CanBeGlobal,
990+
TargetExtType::CanBeLocal);
991+
}
992+
if (Name == "spirv.IntegralConstant" || Name == "spirv.Literal")
993+
return TargetTypeInfo(Type::getVoidTy(C));
971994
if (Name.starts_with("spirv."))
972995
return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::HasZeroInit,
973996
TargetExtType::CanBeGlobal,

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
114114
recordOpExtInstImport(MI);
115115
} else if (OpCode == SPIRV::OpExtInst) {
116116
printOpExtInst(MI, OS);
117+
} else if (OpCode == SPIRV::UNKNOWN_type) {
118+
printUnknownType(MI, OS);
117119
} else {
118120
// Print any extra operands for variadic instructions.
119121
const MCInstrDesc &MCDesc = MII.get(OpCode);
@@ -312,6 +314,31 @@ void SPIRVInstPrinter::printOpDecorate(const MCInst *MI, raw_ostream &O) {
312314
}
313315
}
314316

317+
void SPIRVInstPrinter::printUnknownType(const MCInst *MI, raw_ostream &O) {
318+
const auto EnumOperand = MI->getOperand(1);
319+
assert(EnumOperand.isImm() &&
320+
"second operand of UNKNOWN_type must be opcode!");
321+
322+
const auto Enumerant = EnumOperand.getImm();
323+
const auto NumOps = MI->getNumOperands();
324+
325+
// Print the opcode using the spirv-as unknown opcode syntax
326+
O << "OpUnknown(" << Enumerant << ", " << NumOps << ") ";
327+
328+
// The result ID must be printed after the opcode when using this syntax
329+
printOperand(MI, 0, O);
330+
331+
O << " ";
332+
333+
const MCInstrDesc &MCDesc = MII.get(MI->getOpcode());
334+
unsigned NumFixedOps = MCDesc.getNumOperands();
335+
if (NumOps == NumFixedOps)
336+
return;
337+
338+
// Print the rest of the operands
339+
printRemainingVariableOps(MI, NumFixedOps, O, true);
340+
}
341+
315342
static void printExpr(const MCExpr *Expr, raw_ostream &O) {
316343
#ifndef NDEBUG
317344
const MCSymbolRefExpr *SRE;

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class SPIRVInstPrinter : public MCInstPrinter {
3535

3636
void printOpDecorate(const MCInst *MI, raw_ostream &O);
3737
void printOpExtInst(const MCInst *MI, raw_ostream &O);
38+
void printUnknownType(const MCInst *MI, raw_ostream &O);
3839
void printRemainingVariableOps(const MCInst *MI, unsigned StartIndex,
3940
raw_ostream &O, bool SkipFirstSpace = false,
4041
bool SkipImmediates = false);

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ class SPIRVMCCodeEmitter : public MCCodeEmitter {
4545
void encodeInstruction(const MCInst &MI, SmallVectorImpl<char> &CB,
4646
SmallVectorImpl<MCFixup> &Fixups,
4747
const MCSubtargetInfo &STI) const override;
48+
void encodeUnknownType(const MCInst &MI, SmallVectorImpl<char> &CB,
49+
SmallVectorImpl<MCFixup> &Fixups,
50+
const MCSubtargetInfo &STI) const;
4851
};
4952

5053
} // end anonymous namespace
@@ -104,10 +107,32 @@ static void emitUntypedInstrOperands(const MCInst &MI,
104107
emitOperand(Op, CB);
105108
}
106109

110+
void SPIRVMCCodeEmitter::encodeUnknownType(const MCInst &MI,
111+
SmallVectorImpl<char> &CB,
112+
SmallVectorImpl<MCFixup> &Fixups,
113+
const MCSubtargetInfo &STI) const {
114+
// Encode the first 32 SPIR-V bits with the number of args and the opcode.
115+
const uint64_t OpCode = MI.getOperand(1).getImm();
116+
const uint32_t NumWords = MI.getNumOperands();
117+
const uint32_t FirstWord = (0xFFFF & NumWords) << 16 | (0xFFFF & OpCode);
118+
119+
// encoding: <opcode+len> <result type> [<operand0> <operand1> ...]
120+
support::endian::write(CB, FirstWord, llvm::endianness::little);
121+
122+
emitOperand(MI.getOperand(0), CB);
123+
for (unsigned i = 2; i < NumWords; ++i)
124+
emitOperand(MI.getOperand(i), CB);
125+
}
126+
107127
void SPIRVMCCodeEmitter::encodeInstruction(const MCInst &MI,
108128
SmallVectorImpl<char> &CB,
109129
SmallVectorImpl<MCFixup> &Fixups,
110130
const MCSubtargetInfo &STI) const {
131+
if (MI.getOpcode() == SPIRV::UNKNOWN_type) {
132+
encodeUnknownType(MI, CB, Fixups, STI);
133+
return;
134+
}
135+
111136
// Encode the first 32 SPIR-V bytes with the number of args and the opcode.
112137
const uint64_t OpCode = getBinaryCodeForInstr(MI, Fixups, STI);
113138
const uint32_t NumWords = MI.getNumOperands() + 1;

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 89 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3041,6 +3041,57 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
30413041
return GR->getOrCreateOpTypeSampledImage(OpaqueImageType, MIRBuilder);
30423042
}
30433043

3044+
static SPIRVType *getInlineSpirvType(const TargetExtType *ExtensionType,
3045+
MachineIRBuilder &MIRBuilder,
3046+
SPIRVGlobalRegistry *GR) {
3047+
assert(ExtensionType->getNumIntParameters() == 3 &&
3048+
"Inline SPIR-V type builtin takes an opcode, size, and alignment "
3049+
"parameter");
3050+
auto Opcode = ExtensionType->getIntParameter(0);
3051+
3052+
SmallVector<MCOperand> Operands;
3053+
for (llvm::Type *Param : ExtensionType->type_params()) {
3054+
if (const TargetExtType *ParamEType = dyn_cast<TargetExtType>(Param)) {
3055+
if (ParamEType->getName() == "spirv.IntegralConstant") {
3056+
assert(ParamEType->getNumTypeParameters() == 1 &&
3057+
"Inline SPIR-V integral constant builtin must have a type "
3058+
"parameter");
3059+
assert(ParamEType->getNumIntParameters() == 1 &&
3060+
"Inline SPIR-V integral constant builtin must have a "
3061+
"value parameter");
3062+
3063+
auto OperandValue = ParamEType->getIntParameter(0);
3064+
auto *OperandType = ParamEType->getTypeParameter(0);
3065+
3066+
const SPIRVType *OperandSPIRVType = GR->getOrCreateSPIRVType(
3067+
OperandType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
3068+
3069+
Operands.push_back(MCOperand::createReg(GR->buildConstantInt(
3070+
OperandValue, MIRBuilder, OperandSPIRVType, true)));
3071+
continue;
3072+
} else if (ParamEType->getName() == "spirv.Literal") {
3073+
assert(ParamEType->getNumTypeParameters() == 0 &&
3074+
"Inline SPIR-V literal builtin does not take type "
3075+
"parameters");
3076+
assert(ParamEType->getNumIntParameters() == 1 &&
3077+
"Inline SPIR-V literal builtin must have an integer "
3078+
"parameter");
3079+
3080+
auto OperandValue = ParamEType->getIntParameter(0);
3081+
3082+
Operands.push_back(MCOperand::createImm(OperandValue));
3083+
continue;
3084+
}
3085+
}
3086+
const SPIRVType *TypeOperand = GR->getOrCreateSPIRVType(
3087+
Param, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
3088+
Operands.push_back(MCOperand::createReg(GR->getSPIRVTypeID(TypeOperand)));
3089+
}
3090+
3091+
return GR->getOrCreateUnknownType(ExtensionType, MIRBuilder, Opcode,
3092+
Operands);
3093+
}
3094+
30443095
namespace SPIRV {
30453096
TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
30463097
LLVMContext &Context) {
@@ -3113,39 +3164,45 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
31133164
const StringRef Name = BuiltinType->getName();
31143165
LLVM_DEBUG(dbgs() << "Lowering builtin type: " << Name << "\n");
31153166

3116-
// Lookup the demangled builtin type in the TableGen records.
3117-
const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
3118-
if (!TypeRecord)
3119-
report_fatal_error("Missing TableGen record for builtin type: " + Name);
3120-
3121-
// "Lower" the BuiltinType into TargetType. The following get<...>Type methods
3122-
// use the implementation details from TableGen records or TargetExtType
3123-
// parameters to either create a new OpType<...> machine instruction or get an
3124-
// existing equivalent SPIRVType from GlobalRegistry.
31253167
SPIRVType *TargetType;
3126-
switch (TypeRecord->Opcode) {
3127-
case SPIRV::OpTypeImage:
3128-
TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
3129-
break;
3130-
case SPIRV::OpTypePipe:
3131-
TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
3132-
break;
3133-
case SPIRV::OpTypeDeviceEvent:
3134-
TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
3135-
break;
3136-
case SPIRV::OpTypeSampler:
3137-
TargetType = getSamplerType(MIRBuilder, GR);
3138-
break;
3139-
case SPIRV::OpTypeSampledImage:
3140-
TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
3141-
break;
3142-
case SPIRV::OpTypeCooperativeMatrixKHR:
3143-
TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
3144-
break;
3145-
default:
3146-
TargetType =
3147-
getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
3148-
break;
3168+
if (Name == "spirv.Type") {
3169+
TargetType = getInlineSpirvType(BuiltinType, MIRBuilder, GR);
3170+
} else {
3171+
// Lookup the demangled builtin type in the TableGen records.
3172+
const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
3173+
if (!TypeRecord)
3174+
report_fatal_error("Missing TableGen record for builtin type: " + Name);
3175+
3176+
// "Lower" the BuiltinType into TargetType. The following get<...>Type
3177+
// methods use the implementation details from TableGen records or
3178+
// TargetExtType parameters to either create a new OpType<...> machine
3179+
// instruction or get an existing equivalent SPIRVType from
3180+
// GlobalRegistry.
3181+
3182+
switch (TypeRecord->Opcode) {
3183+
case SPIRV::OpTypeImage:
3184+
TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
3185+
break;
3186+
case SPIRV::OpTypePipe:
3187+
TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
3188+
break;
3189+
case SPIRV::OpTypeDeviceEvent:
3190+
TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
3191+
break;
3192+
case SPIRV::OpTypeSampler:
3193+
TargetType = getSamplerType(MIRBuilder, GR);
3194+
break;
3195+
case SPIRV::OpTypeSampledImage:
3196+
TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
3197+
break;
3198+
case SPIRV::OpTypeCooperativeMatrixKHR:
3199+
TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
3200+
break;
3201+
default:
3202+
TargetType =
3203+
getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
3204+
break;
3205+
}
31493206
}
31503207

31513208
// Emit OpName instruction if a new OpType<...> instruction was added

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,6 +1456,32 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
14561456
return SpirvTy;
14571457
}
14581458

1459+
SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType(
1460+
const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode,
1461+
const ArrayRef<MCOperand> Operands) {
1462+
Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
1463+
if (ResVReg.isValid())
1464+
return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
1465+
ResVReg = createTypeVReg(MIRBuilder);
1466+
1467+
DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
1468+
1469+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1470+
MachineInstrBuilder MIB = MIRBuilder.buildInstr(SPIRV::UNKNOWN_type)
1471+
.addDef(ResVReg)
1472+
.addImm(Opcode);
1473+
for (MCOperand Operand : Operands) {
1474+
if (Operand.isReg()) {
1475+
MIB.addUse(Operand.getReg());
1476+
} else if (Operand.isImm()) {
1477+
MIB.addImm(Operand.getImm());
1478+
}
1479+
}
1480+
1481+
return MIB;
1482+
});
1483+
}
1484+
14591485
const MachineInstr *
14601486
SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
14611487
MachineIRBuilder &MIRBuilder) {

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,11 @@ class SPIRVGlobalRegistry {
621621
MachineIRBuilder &MIRBuilder,
622622
unsigned Opcode);
623623

624+
SPIRVType *getOrCreateUnknownType(const Type *Ty,
625+
MachineIRBuilder &MIRBuilder,
626+
unsigned Opcode,
627+
const ArrayRef<MCOperand> Operands);
628+
624629
const TargetRegisterClass *getRegClass(SPIRVType *SpvType) const;
625630
LLT getRegType(SPIRVType *SpvType) const;
626631

llvm/lib/Target/SPIRV/SPIRVInstrFormats.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
2525
let Pattern = pattern;
2626
}
2727

28+
class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
29+
: Op<0, outs, ins, asmstr, pattern> {
30+
let isPseudo = 1;
31+
}
32+
2833
// Pseudo instructions
2934
class Pseudo<dag outs, dag ins> : Op<0, outs, ins, ""> {
3035
let isPseudo = 1;

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ let isCodeGenOnly=1 in {
2525
def GET_vpID: Pseudo<(outs vpID:$dst_id), (ins vpID:$src)>;
2626
}
2727

28+
def UNKNOWN_type
29+
: UnknownOp<(outs TYPE:$type), (ins i32imm:$opcode, variable_ops), " ">;
30+
2831
def SPVTypeBin : SDTypeProfile<1, 2, []>;
2932

3033
def assigntype : SDNode<"SPIRVISD::AssignType", SPVTypeBin>;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,8 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
496496
bool HasDefs = I.getNumDefs() > 0;
497497
Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0);
498498
SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr;
499-
assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
499+
assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
500+
I.getOpcode() == TargetOpcode::G_IMPLICIT_DEF);
500501
if (spvSelect(ResVReg, ResType, I)) {
501502
if (HasDefs) // Make all vregs 64 bits (for SPIR-V IDs).
502503
for (unsigned i = 0; i < I.getNumDefs(); ++i)

0 commit comments

Comments
 (0)