Skip to content

Commit 3e79c7f

Browse files
[SPIR-V] Implement OpSpecConstantOp with ptr-cast operation (#109979)
This PR reworks implementation of OpSpecConstantOp with ptr-cast operation (PtrCastToGeneric, GenericCastToPtr). Previous implementation didn't take into account a lot of use cases, including multiple inclusion of pointers, reference to a pointer from OpName, etc. A reproducer is attached as a new test case. This PR also fixes wrong type inference for IR patterns which generate new virtual registers without SPIRV type. Previous implementation assumed always that result has the same address space as a source that is not the fact, and, for example, led to impossibility to emit a ptr-cast operation in the reproducer, because wrong type inference rendered source and destination with the same address space, eliminating translation of G_ADDRSPACE_CAST.
1 parent 8bc8b84 commit 3e79c7f

File tree

10 files changed

+275
-55
lines changed

10 files changed

+275
-55
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,11 @@ SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
11281128
SPIRVType *Type = getSPIRVTypeForVReg(VReg);
11291129
assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
11301130
Type->getOperand(1).isImm() && "Pointer type is expected");
1131+
return getPointerStorageClass(Type);
1132+
}
1133+
1134+
SPIRV::StorageClass::StorageClass
1135+
SPIRVGlobalRegistry::getPointerStorageClass(const SPIRVType *Type) const {
11311136
return static_cast<SPIRV::StorageClass::StorageClass>(
11321137
Type->getOperand(1).getImm());
11331138
}

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,8 @@ class SPIRVGlobalRegistry {
405405

406406
// Gets the storage class of the pointer type assigned to this vreg.
407407
SPIRV::StorageClass::StorageClass getPointerStorageClass(Register VReg) const;
408+
SPIRV::StorageClass::StorageClass
409+
getPointerStorageClass(const SPIRVType *Type) const;
408410

409411
// Return the number of bits SPIR-V pointers and size_t variables require.
410412
unsigned getPointerSize() const { return PointerSize; }

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 107 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
249249

250250
bool selectUnmergeValues(MachineInstr &I) const;
251251

252+
// Utilities
252253
Register buildI32Constant(uint32_t Val, MachineInstr &I,
253254
const SPIRVType *ResType = nullptr) const;
254255

@@ -260,6 +261,14 @@ class SPIRVInstructionSelector : public InstructionSelector {
260261

261262
bool wrapIntoSpecConstantOp(MachineInstr &I,
262263
SmallVector<Register> &CompositeArgs) const;
264+
265+
Register getUcharPtrTypeReg(MachineInstr &I,
266+
SPIRV::StorageClass::StorageClass SC) const;
267+
MachineInstrBuilder buildSpecConstantOp(MachineInstr &I, Register Dest,
268+
Register Src, Register DestType,
269+
uint32_t Opcode) const;
270+
MachineInstrBuilder buildConstGenericPtr(MachineInstr &I, Register SrcPtr,
271+
SPIRVType *SrcPtrTy) const;
263272
};
264273

265274
} // end anonymous namespace
@@ -1244,6 +1253,58 @@ static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
12441253
}
12451254
}
12461255

1256+
// Returns true ResVReg is referred only from global vars and OpName's.
1257+
static bool isASCastInGVar(MachineRegisterInfo *MRI, Register ResVReg) {
1258+
bool IsGRef = false;
1259+
bool IsAllowedRefs =
1260+
std::all_of(MRI->use_instr_begin(ResVReg), MRI->use_instr_end(),
1261+
[&IsGRef](auto const &It) {
1262+
unsigned Opcode = It.getOpcode();
1263+
if (Opcode == SPIRV::OpConstantComposite ||
1264+
Opcode == SPIRV::OpVariable ||
1265+
isSpvIntrinsic(It, Intrinsic::spv_init_global))
1266+
return IsGRef = true;
1267+
return Opcode == SPIRV::OpName;
1268+
});
1269+
return IsAllowedRefs && IsGRef;
1270+
}
1271+
1272+
Register SPIRVInstructionSelector::getUcharPtrTypeReg(
1273+
MachineInstr &I, SPIRV::StorageClass::StorageClass SC) const {
1274+
return GR.getSPIRVTypeID(GR.getOrCreateSPIRVPointerType(
1275+
GR.getOrCreateSPIRVIntegerType(8, I, TII), I, TII, SC));
1276+
}
1277+
1278+
MachineInstrBuilder
1279+
SPIRVInstructionSelector::buildSpecConstantOp(MachineInstr &I, Register Dest,
1280+
Register Src, Register DestType,
1281+
uint32_t Opcode) const {
1282+
return BuildMI(*I.getParent(), I, I.getDebugLoc(),
1283+
TII.get(SPIRV::OpSpecConstantOp))
1284+
.addDef(Dest)
1285+
.addUse(DestType)
1286+
.addImm(Opcode)
1287+
.addUse(Src);
1288+
}
1289+
1290+
MachineInstrBuilder
1291+
SPIRVInstructionSelector::buildConstGenericPtr(MachineInstr &I, Register SrcPtr,
1292+
SPIRVType *SrcPtrTy) const {
1293+
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
1294+
GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
1295+
Register Tmp = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
1296+
MRI->setType(Tmp, LLT::pointer(storageClassToAddressSpace(
1297+
SPIRV::StorageClass::Generic),
1298+
GR.getPointerSize()));
1299+
MachineFunction *MF = I.getParent()->getParent();
1300+
GR.assignSPIRVTypeToVReg(GenericPtrTy, Tmp, *MF);
1301+
MachineInstrBuilder MIB = buildSpecConstantOp(
1302+
I, Tmp, SrcPtr, GR.getSPIRVTypeID(GenericPtrTy),
1303+
static_cast<uint32_t>(SPIRV::Opcode::PtrCastToGeneric));
1304+
GR.add(MIB.getInstr(), MF, Tmp);
1305+
return MIB;
1306+
}
1307+
12471308
// In SPIR-V address space casting can only happen to and from the Generic
12481309
// storage class. We can also only cast Workgroup, CrossWorkgroup, or Function
12491310
// pointers to and from Generic pointers. As such, we can convert e.g. from
@@ -1252,36 +1313,57 @@ static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
12521313
bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
12531314
const SPIRVType *ResType,
12541315
MachineInstr &I) const {
1255-
// If the AddrSpaceCast user is single and in OpConstantComposite or
1256-
// OpVariable, we should select OpSpecConstantOp.
1257-
auto UIs = MRI->use_instructions(ResVReg);
1258-
if (!UIs.empty() && ++UIs.begin() == UIs.end() &&
1259-
(UIs.begin()->getOpcode() == SPIRV::OpConstantComposite ||
1260-
UIs.begin()->getOpcode() == SPIRV::OpVariable ||
1261-
isSpvIntrinsic(*UIs.begin(), Intrinsic::spv_init_global))) {
1262-
Register NewReg = I.getOperand(1).getReg();
1263-
MachineBasicBlock &BB = *I.getParent();
1264-
SPIRVType *SpvBaseTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
1265-
ResType = GR.getOrCreateSPIRVPointerType(SpvBaseTy, I, TII,
1266-
SPIRV::StorageClass::Generic);
1267-
bool Result =
1268-
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSpecConstantOp))
1269-
.addDef(ResVReg)
1270-
.addUse(GR.getSPIRVTypeID(ResType))
1271-
.addImm(static_cast<uint32_t>(SPIRV::Opcode::PtrCastToGeneric))
1272-
.addUse(NewReg)
1273-
.constrainAllUses(TII, TRI, RBI);
1274-
return Result;
1275-
}
1316+
MachineBasicBlock &BB = *I.getParent();
1317+
const DebugLoc &DL = I.getDebugLoc();
1318+
12761319
Register SrcPtr = I.getOperand(1).getReg();
12771320
SPIRVType *SrcPtrTy = GR.getSPIRVTypeForVReg(SrcPtr);
1278-
SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
1279-
SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
1321+
1322+
// don't generate a cast for a null that may be represented by OpTypeInt
1323+
if (SrcPtrTy->getOpcode() != SPIRV::OpTypePointer ||
1324+
ResType->getOpcode() != SPIRV::OpTypePointer)
1325+
return BuildMI(BB, I, DL, TII.get(TargetOpcode::COPY))
1326+
.addDef(ResVReg)
1327+
.addUse(SrcPtr)
1328+
.constrainAllUses(TII, TRI, RBI);
1329+
1330+
SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtrTy);
1331+
SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResType);
1332+
1333+
if (isASCastInGVar(MRI, ResVReg)) {
1334+
// AddrSpaceCast uses within OpVariable and OpConstantComposite instructions
1335+
// are expressed by OpSpecConstantOp with an Opcode.
1336+
// TODO: maybe insert a check whether the Kernel capability was declared and
1337+
// so PtrCastToGeneric/GenericCastToPtr are available.
1338+
unsigned SpecOpcode =
1339+
DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC)
1340+
? static_cast<uint32_t>(SPIRV::Opcode::PtrCastToGeneric)
1341+
: (SrcSC == SPIRV::StorageClass::Generic &&
1342+
isGenericCastablePtr(DstSC)
1343+
? static_cast<uint32_t>(SPIRV::Opcode::GenericCastToPtr)
1344+
: 0);
1345+
// TODO: OpConstantComposite expects i8*, so we are forced to forget a
1346+
// correct value of ResType and use general i8* instead. Maybe this should
1347+
// be addressed in the emit-intrinsic step to infer a correct
1348+
// OpConstantComposite type.
1349+
if (SpecOpcode) {
1350+
return buildSpecConstantOp(I, ResVReg, SrcPtr,
1351+
getUcharPtrTypeReg(I, DstSC), SpecOpcode)
1352+
.constrainAllUses(TII, TRI, RBI);
1353+
} else if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
1354+
MachineInstrBuilder MIB = buildConstGenericPtr(I, SrcPtr, SrcPtrTy);
1355+
return MIB.constrainAllUses(TII, TRI, RBI) &&
1356+
buildSpecConstantOp(
1357+
I, ResVReg, MIB->getOperand(0).getReg(),
1358+
getUcharPtrTypeReg(I, DstSC),
1359+
static_cast<uint32_t>(SPIRV::Opcode::GenericCastToPtr))
1360+
.constrainAllUses(TII, TRI, RBI);
1361+
}
1362+
}
12801363

12811364
// don't generate a cast between identical storage classes
12821365
if (SrcSC == DstSC)
1283-
return BuildMI(*I.getParent(), I, I.getDebugLoc(),
1284-
TII.get(TargetOpcode::COPY))
1366+
return BuildMI(BB, I, DL, TII.get(TargetOpcode::COPY))
12851367
.addDef(ResVReg)
12861368
.addUse(SrcPtr)
12871369
.constrainAllUses(TII, TRI, RBI);
@@ -1297,8 +1379,6 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
12971379
Register Tmp = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
12981380
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
12991381
GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
1300-
MachineBasicBlock &BB = *I.getParent();
1301-
const DebugLoc &DL = I.getDebugLoc();
13021382
bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
13031383
.addDef(Tmp)
13041384
.addUse(GR.getSPIRVTypeID(GenericPtrTy))

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,21 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
294294
default:
295295
break;
296296
}
297-
if (SpvType)
297+
if (SpvType) {
298+
// check if the address space needs correction
299+
LLT RegType = MRI.getType(Reg);
300+
if (SpvType->getOpcode() == SPIRV::OpTypePointer &&
301+
RegType.isPointer() &&
302+
storageClassToAddressSpace(GR->getPointerStorageClass(SpvType)) !=
303+
RegType.getAddressSpace()) {
304+
const SPIRVSubtarget &ST =
305+
MI->getParent()->getParent()->getSubtarget<SPIRVSubtarget>();
306+
SpvType = GR->getOrCreateSPIRVPointerType(
307+
GR->getPointeeType(SpvType), *MI, *ST.getInstrInfo(),
308+
addressSpaceToStorageClass(RegType.getAddressSpace(), ST));
309+
}
298310
GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
311+
}
299312
if (!MRI.getRegClassOrNull(Reg))
300313
MRI.setRegClass(Reg, SpvType ? GR->getRegClass(SpvType)
301314
: &SPIRV::iIDRegClass);
@@ -519,6 +532,14 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
519532
? MI.getOperand(1).getCImm()->getType()
520533
: TargetExtIt->second;
521534
const ConstantInt *OpCI = MI.getOperand(1).getCImm();
535+
// TODO: we may wish to analyze here if OpCI is zero and LLT RegType =
536+
// MRI.getType(Reg); RegType.isPointer() is true, so that we observe
537+
// at this point not i64/i32 constant but null pointer in the
538+
// corresponding address space of RegType.getAddressSpace(). This may
539+
// help to successfully validate the case when a OpConstantComposite's
540+
// constituent has type that does not match Result Type of
541+
// OpConstantComposite (see, for example,
542+
// pointers/PtrCast-null-in-OpSpecConstantOp.ll).
522543
Register PrimaryReg = GR->find(OpCI, &MF);
523544
if (!PrimaryReg.isValid()) {
524545
GR->add(OpCI, &MF, Reg);

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,6 +1631,7 @@ multiclass OpcodeOperand<bits<32> value> {
16311631
defm InBoundsAccessChain : OpcodeOperand<66>;
16321632
defm InBoundsPtrAccessChain : OpcodeOperand<70>;
16331633
defm PtrCastToGeneric : OpcodeOperand<121>;
1634+
defm GenericCastToPtr : OpcodeOperand<122>;
16341635
defm Bitcast : OpcodeOperand<124>;
16351636
defm ConvertPtrToU : OpcodeOperand<117>;
16361637
defm ConvertUToPtr : OpcodeOperand<120>;

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
4545
}
4646

4747
// Get length including padding and null terminator.
48-
static size_t getPaddedLen(const StringRef &Str) { return Str.size() + 4 & ~3; }
48+
static size_t getPaddedLen(const StringRef &Str) {
49+
return (Str.size() + 4) & ~3;
50+
}
4951

5052
void addStringImm(const StringRef &Str, MCInst &Inst) {
5153
const size_t PaddedLen = getPaddedLen(Str);
@@ -160,31 +162,6 @@ void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
160162
}
161163
}
162164

163-
// TODO: maybe the following two functions should be handled in the subtarget
164-
// to allow for different OpenCL vs Vulkan handling.
165-
unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
166-
switch (SC) {
167-
case SPIRV::StorageClass::Function:
168-
return 0;
169-
case SPIRV::StorageClass::CrossWorkgroup:
170-
return 1;
171-
case SPIRV::StorageClass::UniformConstant:
172-
return 2;
173-
case SPIRV::StorageClass::Workgroup:
174-
return 3;
175-
case SPIRV::StorageClass::Generic:
176-
return 4;
177-
case SPIRV::StorageClass::DeviceOnlyINTEL:
178-
return 5;
179-
case SPIRV::StorageClass::HostOnlyINTEL:
180-
return 6;
181-
case SPIRV::StorageClass::Input:
182-
return 7;
183-
default:
184-
report_fatal_error("Unable to get address space id");
185-
}
186-
}
187-
188165
SPIRV::StorageClass::StorageClass
189166
addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
190167
switch (AddrSpace) {

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,31 @@ void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
134134
const MDNode *GVarMD);
135135

136136
// Convert a SPIR-V storage class to the corresponding LLVM IR address space.
137-
unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC);
137+
// TODO: maybe the following two functions should be handled in the subtarget
138+
// to allow for different OpenCL vs Vulkan handling.
139+
constexpr unsigned
140+
storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
141+
switch (SC) {
142+
case SPIRV::StorageClass::Function:
143+
return 0;
144+
case SPIRV::StorageClass::CrossWorkgroup:
145+
return 1;
146+
case SPIRV::StorageClass::UniformConstant:
147+
return 2;
148+
case SPIRV::StorageClass::Workgroup:
149+
return 3;
150+
case SPIRV::StorageClass::Generic:
151+
return 4;
152+
case SPIRV::StorageClass::DeviceOnlyINTEL:
153+
return 5;
154+
case SPIRV::StorageClass::HostOnlyINTEL:
155+
return 6;
156+
case SPIRV::StorageClass::Input:
157+
return 7;
158+
default:
159+
report_fatal_error("Unable to get address space id");
160+
}
161+
}
138162

139163
// Convert an LLVM IR address space to a SPIR-V storage class.
140164
SPIRV::StorageClass::StorageClass
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; The goal of this test case is to check that cases covered by pointers/PtrCast-in-OpSpecConstantOp.ll and
2+
; pointers/PtrCast-null-in-OpSpecConstantOp.ll (that is OpSpecConstantOp with ptr-cast operation) correctly
3+
; work also for function pointers.
4+
5+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - --spirv-ext=+SPV_INTEL_function_pointers | FileCheck %s
6+
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
7+
8+
; Running with -verify-machineinstrs would lead to "Reading virtual register without a def"
9+
; error, because OpConstantFunctionPointerINTEL forward-refers to a function definition.
10+
11+
; CHECK-COUNT-3: %[[#]] = OpSpecConstantOp %[[#]] 121 %[[#]]
12+
; CHECK-COUNT-3: OpPtrCastToGeneric
13+
14+
@G1 = addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @foo to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr @bar to ptr addrspace(4))] }
15+
@G2 = addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) addrspacecast (ptr null to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr @bar to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr @foo to ptr addrspace(4))] }
16+
17+
define void @foo(ptr addrspace(4) %p) {
18+
entry:
19+
%r1 = addrspacecast ptr @foo to ptr addrspace(4)
20+
%r2 = addrspacecast ptr null to ptr addrspace(4)
21+
ret void
22+
}
23+
24+
define void @bar(ptr addrspace(4) %p) {
25+
entry:
26+
%r1 = addrspacecast ptr @bar to ptr addrspace(4)
27+
ret void
28+
}

0 commit comments

Comments
 (0)