Skip to content

Commit 6c4ae27

Browse files
committed
[SPIR-V] Add support for inline SPIR-V types
Using HLSL's [Inline SPIR-V](https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html) features, users have the ability to use [`SpirvType`](https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html#types) to have fine-grained control over the SPIR-V representation of a type. As explained in the spec, this is useful because it enables vendors to author headers with types for their own extensions. As discussed in [Target Extension Types for Inline SPIR-V and Decorated Types](llvm/wg-hlsl#105), we would like to represent the HLSL SpirvType type using a 'spirv.Type' target extension type in LLVM IR. This pull request lowers that type to SPIR-V.
1 parent ae7f7c4 commit 6c4ae27

File tree

11 files changed

+230
-33
lines changed

11 files changed

+230
-33
lines changed

llvm/lib/IR/Type.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,26 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
968968
if (Name == "spirv.Image")
969969
return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::CanBeGlobal,
970970
TargetExtType::CanBeLocal);
971+
if (Name == "spirv.Type") {
972+
assert(Ty->getNumIntParameters() == 3 &&
973+
"Wrong number of parameters for spirv.Type");
974+
975+
auto Size = Ty->getIntParameter(1);
976+
auto Alignment = Ty->getIntParameter(2);
977+
978+
// LLVM expects variables that can be allocated to have an alignment and
979+
// size. Default to using a 32-bit int as the layout type if none are
980+
// present.
981+
llvm::Type *LayoutType = Type::getInt32Ty(C);
982+
if (Size > 0 && Alignment > 0)
983+
LayoutType =
984+
ArrayType::get(Type::getIntNTy(C, Alignment), Size * 8 / Alignment);
985+
986+
return TargetTypeInfo(LayoutType, TargetExtType::CanBeGlobal,
987+
TargetExtType::CanBeLocal);
988+
}
989+
if (Name == "spirv.IntegralConstant" || Name == "spirv.Literal")
990+
return TargetTypeInfo(Type::getVoidTy(C));
971991
if (Name.starts_with("spirv."))
972992
return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::HasZeroInit,
973993
TargetExtType::CanBeGlobal,

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
116116
recordOpExtInstImport(MI);
117117
} else if (OpCode == SPIRV::OpExtInst) {
118118
printOpExtInst(MI, OS);
119+
} else if (OpCode == SPIRV::UNKNOWN_type) {
120+
printUnknownType(MI, OS);
119121
} else {
120122
// Print any extra operands for variadic instructions.
121123
const MCInstrDesc &MCDesc = MII.get(OpCode);
@@ -314,6 +316,35 @@ void SPIRVInstPrinter::printOpDecorate(const MCInst *MI, raw_ostream &O) {
314316
}
315317
}
316318

319+
void SPIRVInstPrinter::printUnknownType(const MCInst *MI, raw_ostream &O) {
320+
const auto EnumOperand = MI->getOperand(1);
321+
assert(EnumOperand.isImm() &&
322+
"second operand of UNKNOWN_type must be opcode!");
323+
324+
const auto Enumerant = EnumOperand.getImm();
325+
const auto NumOps = MI->getNumOperands();
326+
327+
// Encode the instruction enumerant and word count into the opcode
328+
const auto OpCode = (0xFF & NumOps) << 16 | (0xFF & Enumerant);
329+
330+
// Print the opcode using the spirv-as arbitrary integer syntax
331+
// https://github.com/KhronosGroup/SPIRV-Tools/blob/main/docs/syntax.md#arbitrary-integers
332+
O << "!0x" << Twine::utohexstr(OpCode) << " ";
333+
334+
// The result ID must be printed after the opcode when using this syntax
335+
printOperand(MI, 0, O);
336+
337+
O << " ";
338+
339+
const MCInstrDesc &MCDesc = MII.get(MI->getOpcode());
340+
unsigned NumFixedOps = MCDesc.getNumOperands();
341+
if (NumOps == NumFixedOps)
342+
return;
343+
344+
// Print the rest of the operands
345+
printRemainingVariableOps(MI, NumFixedOps, O, true);
346+
}
347+
317348
static void printExpr(const MCExpr *Expr, raw_ostream &O) {
318349
#ifndef NDEBUG
319350
const MCSymbolRefExpr *SRE;

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class SPIRVInstPrinter : public MCInstPrinter {
3535

3636
void printOpDecorate(const MCInst *MI, raw_ostream &O);
3737
void printOpExtInst(const MCInst *MI, raw_ostream &O);
38+
void printUnknownType(const MCInst *MI, raw_ostream &O);
3839
void printRemainingVariableOps(const MCInst *MI, unsigned StartIndex,
3940
raw_ostream &O, bool SkipFirstSpace = false,
4041
bool SkipImmediates = false);

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class SPIRVMCCodeEmitter : public MCCodeEmitter {
4646
void encodeInstruction(const MCInst &MI, SmallVectorImpl<char> &CB,
4747
SmallVectorImpl<MCFixup> &Fixups,
4848
const MCSubtargetInfo &STI) const override;
49+
void encodeUnknownType(const MCInst &MI, SmallVectorImpl<char> &CB,
50+
SmallVectorImpl<MCFixup> &Fixups,
51+
const MCSubtargetInfo &STI) const;
4952
};
5053

5154
} // end anonymous namespace
@@ -104,10 +107,30 @@ static void emitUntypedInstrOperands(const MCInst &MI,
104107
emitOperand(Op, CB);
105108
}
106109

110+
void SPIRVMCCodeEmitter::encodeUnknownType(const MCInst &MI,
111+
SmallVectorImpl<char> &CB,
112+
SmallVectorImpl<MCFixup> &Fixups,
113+
const MCSubtargetInfo &STI) const {
114+
// Encode the first 32 SPIR-V bytes with the number of args and the opcode.
115+
const uint64_t OpCode = MI.getOperand(1).getImm();
116+
const uint32_t NumWords = MI.getNumOperands();
117+
const uint32_t FirstWord = (NumWords << 16) | OpCode;
118+
support::endian::write(CB, FirstWord, llvm::endianness::little);
119+
120+
emitOperand(MI.getOperand(0), CB);
121+
for (unsigned i = 2; i < NumWords; ++i)
122+
emitOperand(MI.getOperand(i), CB);
123+
}
124+
107125
void SPIRVMCCodeEmitter::encodeInstruction(const MCInst &MI,
108126
SmallVectorImpl<char> &CB,
109127
SmallVectorImpl<MCFixup> &Fixups,
110128
const MCSubtargetInfo &STI) const {
129+
if (MI.getOpcode() == SPIRV::UNKNOWN_type) {
130+
encodeUnknownType(MI, CB, Fixups, STI);
131+
return;
132+
}
133+
111134
// Encode the first 32 SPIR-V bytes with the number of args and the opcode.
112135
const uint64_t OpCode = getBinaryCodeForInstr(MI, Fixups, STI);
113136
const uint32_t NumWords = MI.getNumOperands() + 1;

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 93 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2868,6 +2868,61 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
28682868
return GR->getOrCreateOpTypeSampledImage(OpaqueImageType, MIRBuilder);
28692869
}
28702870

2871+
static SPIRVType *getInlineSpirvType(const TargetExtType *ExtensionType,
2872+
MachineIRBuilder &MIRBuilder,
2873+
SPIRVGlobalRegistry *GR) {
2874+
assert(ExtensionType->getNumIntParameters() == 3 &&
2875+
"Inline SPIR-V type builtin takes an opcode, size, and alignment "
2876+
"parameter");
2877+
auto Opcode = ExtensionType->getIntParameter(0);
2878+
2879+
return GR->getOrCreateUnknownType(
2880+
ExtensionType, MIRBuilder, Opcode,
2881+
[&ExtensionType, &GR, &MIRBuilder](llvm::MachineInstrBuilder Instr) {
2882+
for (llvm::Type *Param : ExtensionType->type_params()) {
2883+
if (const TargetExtType *ParamEType =
2884+
dyn_cast<TargetExtType>(Param)) {
2885+
if (ParamEType->getName() == "spirv.IntegralConstant") {
2886+
assert(ParamEType->getNumTypeParameters() == 1 &&
2887+
"Inline SPIR-V integral constant builtin must have a type "
2888+
"parameter");
2889+
assert(ParamEType->getNumIntParameters() == 1 &&
2890+
"Inline SPIR-V integral constant builtin must have a "
2891+
"value parameter");
2892+
2893+
auto OperandValue = ParamEType->getIntParameter(0);
2894+
auto *OperandType = ParamEType->getTypeParameter(0);
2895+
2896+
const SPIRVType *OperandSPIRVType =
2897+
GR->getOrCreateSPIRVType(OperandType, MIRBuilder);
2898+
2899+
Instr = Instr.addUse(GR->buildConstantInt(
2900+
OperandValue, MIRBuilder, OperandSPIRVType, true));
2901+
continue;
2902+
} else if (ParamEType->getName() == "spirv.Literal") {
2903+
assert(ParamEType->getNumTypeParameters() == 0 &&
2904+
"Inline SPIR-V literal builtin does not take type "
2905+
"parameters");
2906+
assert(ParamEType->getNumIntParameters() == 1 &&
2907+
"Inline SPIR-V literal builtin must have an integer "
2908+
"parameter");
2909+
2910+
auto OperandValue = ParamEType->getIntParameter(0);
2911+
2912+
Instr = Instr.addImm(OperandValue);
2913+
continue;
2914+
}
2915+
}
2916+
const SPIRVType *TypeOperand =
2917+
GR->getOrCreateSPIRVType(Param, MIRBuilder);
2918+
Instr = Instr.addUse(GR->getSPIRVTypeID(TypeOperand));
2919+
}
2920+
return Instr;
2921+
});
2922+
2923+
// GR->getOrCreateSPIRVArrayType();
2924+
}
2925+
28712926
namespace SPIRV {
28722927
TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
28732928
LLVMContext &Context) {
@@ -2940,39 +2995,45 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
29402995
const StringRef Name = BuiltinType->getName();
29412996
LLVM_DEBUG(dbgs() << "Lowering builtin type: " << Name << "\n");
29422997

2943-
// Lookup the demangled builtin type in the TableGen records.
2944-
const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
2945-
if (!TypeRecord)
2946-
report_fatal_error("Missing TableGen record for builtin type: " + Name);
2947-
2948-
// "Lower" the BuiltinType into TargetType. The following get<...>Type methods
2949-
// use the implementation details from TableGen records or TargetExtType
2950-
// parameters to either create a new OpType<...> machine instruction or get an
2951-
// existing equivalent SPIRVType from GlobalRegistry.
29522998
SPIRVType *TargetType;
2953-
switch (TypeRecord->Opcode) {
2954-
case SPIRV::OpTypeImage:
2955-
TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
2956-
break;
2957-
case SPIRV::OpTypePipe:
2958-
TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
2959-
break;
2960-
case SPIRV::OpTypeDeviceEvent:
2961-
TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
2962-
break;
2963-
case SPIRV::OpTypeSampler:
2964-
TargetType = getSamplerType(MIRBuilder, GR);
2965-
break;
2966-
case SPIRV::OpTypeSampledImage:
2967-
TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
2968-
break;
2969-
case SPIRV::OpTypeCooperativeMatrixKHR:
2970-
TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
2971-
break;
2972-
default:
2973-
TargetType =
2974-
getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
2975-
break;
2999+
if (Name == "spirv.Type") {
3000+
TargetType = getInlineSpirvType(BuiltinType, MIRBuilder, GR);
3001+
} else {
3002+
// Lookup the demangled builtin type in the TableGen records.
3003+
const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
3004+
if (!TypeRecord)
3005+
report_fatal_error("Missing TableGen record for builtin type: " + Name);
3006+
3007+
// "Lower" the BuiltinType into TargetType. The following get<...>Type
3008+
// methods use the implementation details from TableGen records or
3009+
// TargetExtType parameters to either create a new OpType<...> machine
3010+
// instruction or get an existing equivalent SPIRVType from
3011+
// GlobalRegistry.
3012+
3013+
switch (TypeRecord->Opcode) {
3014+
case SPIRV::OpTypeImage:
3015+
TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
3016+
break;
3017+
case SPIRV::OpTypePipe:
3018+
TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
3019+
break;
3020+
case SPIRV::OpTypeDeviceEvent:
3021+
TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
3022+
break;
3023+
case SPIRV::OpTypeSampler:
3024+
TargetType = getSamplerType(MIRBuilder, GR);
3025+
break;
3026+
case SPIRV::OpTypeSampledImage:
3027+
TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
3028+
break;
3029+
case SPIRV::OpTypeCooperativeMatrixKHR:
3030+
TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
3031+
break;
3032+
default:
3033+
TargetType =
3034+
getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
3035+
break;
3036+
}
29763037
}
29773038

29783039
// Emit OpName instruction if a new OpType<...> instruction was added

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,6 +1406,21 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
14061406
return SpirvTy;
14071407
}
14081408

1409+
SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType(
1410+
const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode,
1411+
const std::function<llvm::MachineInstrBuilder(llvm::MachineInstrBuilder)>
1412+
&buildInstr) {
1413+
Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
1414+
if (ResVReg.isValid())
1415+
return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
1416+
ResVReg = createTypeVReg(MIRBuilder);
1417+
SPIRVType *SpirvTy = buildInstr(MIRBuilder.buildInstr(SPIRV::UNKNOWN_type)
1418+
.addDef(ResVReg)
1419+
.addImm(Opcode));
1420+
DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
1421+
return SpirvTy;
1422+
}
1423+
14091424
const MachineInstr *
14101425
SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
14111426
MachineIRBuilder &MIRBuilder) {

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,11 @@ class SPIRVGlobalRegistry {
618618
MachineIRBuilder &MIRBuilder,
619619
unsigned Opcode);
620620

621+
SPIRVType *getOrCreateUnknownType(
622+
const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode,
623+
const std::function<llvm::MachineInstrBuilder(llvm::MachineInstrBuilder)>
624+
&buildInstr);
625+
621626
const TargetRegisterClass *getRegClass(SPIRVType *SpvType) const;
622627
LLT getRegType(SPIRVType *SpvType) const;
623628
};

llvm/lib/Target/SPIRV/SPIRVInstrFormats.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
2525
let Pattern = pattern;
2626
}
2727

28+
class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
29+
: Op<0, outs, ins, asmstr, pattern> {
30+
let isPseudo = 1;
31+
}
32+
2833
// Pseudo instructions
2934
class Pseudo<dag outs, dag ins> : Op<0, outs, ins, ""> {
3035
let isPseudo = 1;

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ let isCodeGenOnly=1 in {
2525
def GET_vpID: Pseudo<(outs vpID:$dst_id), (ins vpID:$src)>;
2626
}
2727

28+
def UNKNOWN_type
29+
: UnknownOp<(outs TYPE:$type), (ins i32imm:$opcode, variable_ops), " ">;
30+
2831
def SPVTypeBin : SDTypeProfile<1, 2, []>;
2932

3033
def assigntype : SDNode<"SPIRVISD::AssignType", SPVTypeBin>;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,8 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
495495
bool HasDefs = I.getNumDefs() > 0;
496496
Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0);
497497
SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr;
498-
assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
498+
assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
499+
I.getOpcode() == TargetOpcode::G_IMPLICIT_DEF);
499500
if (spvSelect(ResVReg, ResType, I)) {
500501
if (HasDefs) // Make all vregs 64 bits (for SPIR-V IDs).
501502
for (unsigned i = 0; i < I.getNumDefs(); ++i)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - | spirv-as - -o - | spirv-val %}
4+
5+
; CHECK: [[uint32_t:%[0-9]+]] = OpTypeInt 32 0
6+
7+
; CHECK: [[image_t:%[0-9]+]] = OpTypeImage %3 2D 2 0 0 1 Unknown
8+
%type_2d_image = type target("spirv.Image", float, 1, 2, 0, 0, 1, 0)
9+
10+
%literal_false = type target("spirv.Literal", 0)
11+
%literal_8 = type target("spirv.Literal", 8)
12+
13+
; CHECK: [[uint32_4:%[0-9]+]] = OpConstant [[uint32_t]] 4
14+
%integral_constant_4 = type target("spirv.IntegralConstant", i32, 4)
15+
16+
; CHECK: !0x4001c [[array_t:%[0-9]+]] [[image_t]] [[uint32_4]]
17+
%ArrayTex2D = type target("spirv.Type", %type_2d_image, %integral_constant_4, 28, 0, 0)
18+
19+
; CHECK: [[getTexArray_t:%[0-9]+]] = OpTypeFunction [[array_t]]
20+
21+
; CHECK: [[getTexArray:%[0-9]+]] = OpFunction [[array_t]] None [[getTexArray_t]]
22+
declare %ArrayTex2D @getTexArray()
23+
24+
define void @main() #1 {
25+
entry:
26+
%images = alloca %ArrayTex2D
27+
28+
; CHECK: {{%[0-9]+}} = OpFunctionCall [[array_t]] [[getTexArray]]
29+
%retTex = call %ArrayTex2D @getTexArray()
30+
31+
ret void
32+
}

0 commit comments

Comments
 (0)