Skip to content

Commit 4a602d9

Browse files
Add support for the SPV_INTEL_usm_storage_classes extension (#82247)
Add support for the SPV_INTEL_usm_storage_classes extension: * https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_usm_storage_classes.asciidoc
1 parent 6193233 commit 4a602d9

File tree

12 files changed

+183
-28
lines changed

12 files changed

+183
-28
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
150150

151151
static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
152152
SPIRVGlobalRegistry *GR,
153-
MachineIRBuilder &MIRBuilder) {
153+
MachineIRBuilder &MIRBuilder,
154+
const SPIRVSubtarget &ST) {
154155
// Read argument's access qualifier from metadata or default.
155156
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
156157
getArgAccessQual(F, ArgIdx);
@@ -169,8 +170,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
169170
if (MDTypeStr.ends_with("*"))
170171
ResArgType = GR->getOrCreateSPIRVTypeByName(
171172
MDTypeStr, MIRBuilder,
172-
addressSpaceToStorageClass(
173-
OriginalArgType->getPointerAddressSpace()));
173+
addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace(),
174+
ST));
174175
else if (MDTypeStr.ends_with("_t"))
175176
ResArgType = GR->getOrCreateSPIRVTypeByName(
176177
"opencl." + MDTypeStr.str(), MIRBuilder,
@@ -206,6 +207,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
206207
assert(GR && "Must initialize the SPIRV type registry before lowering args.");
207208
GR->setCurrentFunc(MIRBuilder.getMF());
208209

210+
// Get access to information about available extensions
211+
const SPIRVSubtarget *ST =
212+
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
213+
209214
// Assign types and names to all args, and store their types for later.
210215
FunctionType *FTy = getOriginalFunctionType(F);
211216
SmallVector<SPIRVType *, 4> ArgTypeVRegs;
@@ -216,7 +221,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
216221
// TODO: handle the case of multiple registers.
217222
if (VRegs[i].size() > 1)
218223
return false;
219-
auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
224+
auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
220225
GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
221226
ArgTypeVRegs.push_back(SpirvTy);
222227

@@ -318,10 +323,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
318323
if (F.hasName())
319324
buildOpName(FuncVReg, F.getName(), MIRBuilder);
320325

321-
// Get access to information about available extensions
322-
const auto *ST =
323-
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
324-
325326
// Handle entry points and function linkage.
326327
if (isEntryPoint(F)) {
327328
const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,10 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
709709
// TODO: change the implementation once opaque pointers are supported
710710
// in the SPIR-V specification.
711711
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
712-
auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
712+
// Get access to information about available extensions
713+
const SPIRVSubtarget *ST =
714+
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
715+
auto SC = addressSpaceToStorageClass(PType->getAddressSpace(), *ST);
713716
// Null pointer means we have a loop in type definitions, make and
714717
// return corresponding OpTypeForwardPointer.
715718
if (SpvElementType == nullptr) {

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,10 @@ def OpGenericCastToPtrExplicit : Op<123, (outs ID:$r), (ins TYPE:$t, ID:$p, Stor
430430
"$r = OpGenericCastToPtrExplicit $t $p $s">;
431431
def OpBitcast : UnOp<"OpBitcast", 124>;
432432

433+
// SPV_INTEL_usm_storage_classes
434+
def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
435+
def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
436+
433437
// 3.42.12 Composite Instructions
434438

435439
def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -828,8 +828,18 @@ static bool isGenericCastablePtr(SPIRV::StorageClass::StorageClass SC) {
828828
}
829829
}
830830

831+
static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
832+
switch (SC) {
833+
case SPIRV::StorageClass::DeviceOnlyINTEL:
834+
case SPIRV::StorageClass::HostOnlyINTEL:
835+
return true;
836+
default:
837+
return false;
838+
}
839+
}
840+
831841
// In SPIR-V address space casting can only happen to and from the Generic
832-
// storage class. We can also only case Workgroup, CrossWorkgroup, or Function
842+
// storage class. We can also only cast Workgroup, CrossWorkgroup, or Function
833843
// pointers to and from Generic pointers. As such, we can convert e.g. from
834844
// Workgroup to Function by going via a Generic pointer as an intermediary. All
835845
// other combinations can only be done by a bitcast, and are probably not safe.
@@ -862,13 +872,17 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
862872
SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
863873
SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
864874

865-
// Casting from an eligable pointer to Generic.
875+
// don't generate a cast between identical storage classes
876+
if (SrcSC == DstSC)
877+
return true;
878+
879+
// Casting from an eligible pointer to Generic.
866880
if (DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC))
867881
return selectUnOp(ResVReg, ResType, I, SPIRV::OpPtrCastToGeneric);
868-
// Casting from Generic to an eligable pointer.
882+
// Casting from Generic to an eligible pointer.
869883
if (SrcSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(DstSC))
870884
return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
871-
// Casting between 2 eligable pointers using Generic as an intermediary.
885+
// Casting between 2 eligible pointers using Generic as an intermediary.
872886
if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
873887
Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
874888
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
@@ -886,6 +900,16 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
886900
.addUse(Tmp)
887901
.constrainAllUses(TII, TRI, RBI);
888902
}
903+
904+
// Check if instructions from the SPV_INTEL_usm_storage_classes extension may
905+
// be applied
906+
if (isUSMStorageClass(SrcSC) && DstSC == SPIRV::StorageClass::CrossWorkgroup)
907+
return selectUnOp(ResVReg, ResType, I,
908+
SPIRV::OpPtrCastToCrossWorkgroupINTEL);
909+
if (SrcSC == SPIRV::StorageClass::CrossWorkgroup && isUSMStorageClass(DstSC))
910+
return selectUnOp(ResVReg, ResType, I,
911+
SPIRV::OpCrossWorkgroupCastToPtrINTEL);
912+
889913
// TODO Should this case just be disallowed completely?
890914
// We're casting 2 other arbitrary address spaces, so have to bitcast.
891915
return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
@@ -1545,7 +1569,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
15451569
}
15461570
SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
15471571
PointerBaseType, I, TII,
1548-
addressSpaceToStorageClass(GV->getAddressSpace()));
1572+
addressSpaceToStorageClass(GV->getAddressSpace(), STI));
15491573

15501574
std::string GlobalIdent;
15511575
if (!GV->hasName()) {
@@ -1618,7 +1642,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
16181642

16191643
unsigned AddrSpace = GV->getAddressSpace();
16201644
SPIRV::StorageClass::StorageClass Storage =
1621-
addressSpaceToStorageClass(AddrSpace);
1645+
addressSpaceToStorageClass(AddrSpace, STI);
16221646
bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
16231647
Storage != SPIRV::StorageClass::Function;
16241648
SPIRV::LinkageType::LinkageType LnkType =

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,16 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
102102
const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103103
const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104104
const LLT p4 = LLT::pointer(4, PSize); // Generic
105-
const LLT p5 = LLT::pointer(5, PSize); // Input
105+
const LLT p5 =
106+
LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
107+
const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
106108

107109
// TODO: remove copy-pasting here by using concatenation in some way.
108110
auto allPtrsScalarsAndVectors = {
109-
p0, p1, p2, p3, p4, p5, s1, s8, s16,
110-
s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
111-
v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
112-
v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
111+
p0, p1, p2, p3, p4, p5, p6, s1, s8, s16,
112+
s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
113+
v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
114+
v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113115

114116
auto allScalarsAndVectors = {
115117
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
@@ -133,8 +135,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
133135

134136
auto allFloatAndIntScalars = allIntScalars;
135137

136-
auto allPtrs = {p0, p1, p2, p3, p4, p5};
137-
auto allWritablePtrs = {p0, p1, p3, p4};
138+
auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
139+
auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
138140

139141
for (auto Opc : TypeFoldingSupportingOpcs)
140142
getActionDefinitionsBuilder(Opc).custom();

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,13 @@ void addInstrRequirements(const MachineInstr &MI,
10631063
Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
10641064
}
10651065
break;
1066+
case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1067+
case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1068+
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1069+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1070+
Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1071+
}
1072+
break;
10661073
case SPIRV::OpConstantFunctionPointerINTEL:
10671074
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
10681075
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
122122

123123
static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
124124
MachineIRBuilder MIB) {
125+
// Get access to information about available extensions
126+
const SPIRVSubtarget *ST =
127+
static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
125128
SmallVector<MachineInstr *, 10> ToErase;
126129
for (MachineBasicBlock &MBB : MF) {
127130
for (MachineInstr &MI : MBB) {
@@ -141,7 +144,7 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
141144
getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
142145
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
143146
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
144-
addressSpaceToStorageClass(MI.getOperand(4).getImm()));
147+
addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
145148

146149
// If the bitcast would be redundant, replace all uses with the source
147150
// register.
@@ -250,6 +253,10 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
250253

251254
static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
252255
MachineIRBuilder MIB) {
256+
// Get access to information about available extensions
257+
const SPIRVSubtarget *ST =
258+
static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
259+
253260
MachineRegisterInfo &MRI = MF.getRegInfo();
254261
SmallVector<MachineInstr *, 10> ToErase;
255262

@@ -269,7 +276,7 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
269276
getMDOperandAsType(MI.getOperand(2).getMetadata(), 0), MIB);
270277
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
271278
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
272-
addressSpaceToStorageClass(MI.getOperand(3).getImm()));
279+
addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
273280
MachineInstr *Def = MRI.getVRegDef(Reg);
274281
assert(Def && "Expecting an instruction that defines the register");
275282
insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,

llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ cl::list<SPIRV::Extension::Extension> Extensions(
4949
clEnumValN(SPIRV::Extension::SPV_INTEL_optnone, "SPV_INTEL_optnone",
5050
"Adds OptNoneINTEL value for Function Control mask that "
5151
"indicates a request to not optimize the function."),
52+
clEnumValN(SPIRV::Extension::SPV_INTEL_usm_storage_classes,
53+
"SPV_INTEL_usm_storage_classes",
54+
"Introduces two new storage classes that are sub classes of "
55+
"the CrossWorkgroup storage class "
56+
"that provides additional information that can enable "
57+
"optimization."),
5258
clEnumValN(SPIRV::Extension::SPV_INTEL_subgroups, "SPV_INTEL_subgroups",
5359
"Allows work items in a subgroup to share data without the "
5460
"use of local memory and work group barriers, and to "

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,7 @@ defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atom
463463
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
464464
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
465465
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
466+
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
466467

467468
//===----------------------------------------------------------------------===//
468469
// Multiclass used to define SourceLanguage enum values and at the same time
@@ -700,6 +701,8 @@ defm IncomingRayPayloadNV : StorageClassOperand<5342, [RayTracingNV]>;
700701
defm ShaderRecordBufferNV : StorageClassOperand<5343, [RayTracingNV]>;
701702
defm PhysicalStorageBufferEXT : StorageClassOperand<5349, [PhysicalStorageBufferAddressesEXT]>;
702703
defm CodeSectionINTEL : StorageClassOperand<5605, [FunctionPointersINTEL]>;
704+
defm DeviceOnlyINTEL : StorageClassOperand<5936, [USMStorageClassesINTEL]>;
705+
defm HostOnlyINTEL : StorageClassOperand<5937, [USMStorageClassesINTEL]>;
703706

704707
//===----------------------------------------------------------------------===//
705708
// Multiclass used to define Dim enum values and at the same time

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "MCTargetDesc/SPIRVBaseInfo.h"
1515
#include "SPIRV.h"
1616
#include "SPIRVInstrInfo.h"
17+
#include "SPIRVSubtarget.h"
1718
#include "llvm/ADT/StringRef.h"
1819
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
1920
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
@@ -146,15 +147,19 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
146147
return 3;
147148
case SPIRV::StorageClass::Generic:
148149
return 4;
150+
case SPIRV::StorageClass::DeviceOnlyINTEL:
151+
return 5;
152+
case SPIRV::StorageClass::HostOnlyINTEL:
153+
return 6;
149154
case SPIRV::StorageClass::Input:
150155
return 7;
151156
default:
152-
llvm_unreachable("Unable to get address space id");
157+
report_fatal_error("Unable to get address space id");
153158
}
154159
}
155160

156161
SPIRV::StorageClass::StorageClass
157-
addressSpaceToStorageClass(unsigned AddrSpace) {
162+
addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
158163
switch (AddrSpace) {
159164
case 0:
160165
return SPIRV::StorageClass::Function;
@@ -166,10 +171,18 @@ addressSpaceToStorageClass(unsigned AddrSpace) {
166171
return SPIRV::StorageClass::Workgroup;
167172
case 4:
168173
return SPIRV::StorageClass::Generic;
174+
case 5:
175+
return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
176+
? SPIRV::StorageClass::DeviceOnlyINTEL
177+
: SPIRV::StorageClass::CrossWorkgroup;
178+
case 6:
179+
return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
180+
? SPIRV::StorageClass::HostOnlyINTEL
181+
: SPIRV::StorageClass::CrossWorkgroup;
169182
case 7:
170183
return SPIRV::StorageClass::Input;
171184
default:
172-
llvm_unreachable("Unknown address space");
185+
report_fatal_error("Unknown address space");
173186
}
174187
}
175188

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class MachineRegisterInfo;
2727
class Register;
2828
class StringRef;
2929
class SPIRVInstrInfo;
30+
class SPIRVSubtarget;
3031

3132
// Add the given string as a series of integer operand, inserting null
3233
// terminators and padding to make sure the operands all have 32-bit
@@ -62,7 +63,7 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC);
6263

6364
// Convert an LLVM IR address space to a SPIR-V storage class.
6465
SPIRV::StorageClass::StorageClass
65-
addressSpaceToStorageClass(unsigned AddrSpace);
66+
addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI);
6667

6768
SPIRV::MemorySemantics::MemorySemantics
6869
getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC);

0 commit comments

Comments
 (0)