Skip to content

Commit 3212555

Browse files
authored
[SPIRV] Reapply explicit layout PRs (#138867)
The asan failure was fixed by #138695, but another failure was introduced in the meantime. The cause for the other failure has been fixed. I will reapply the two PRs. Reapply "[SPIRV] Add explicit layout (#135789)" This reverts commit 0fb5720. Reapply "[SPIRV] Fix asan failure (#138695)" This reverts commit df90ab9.
1 parent e0a951f commit 3212555

File tree

7 files changed

+523
-124
lines changed

7 files changed

+523
-124
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 188 additions & 105 deletions
Large diffs are not rendered by default.

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,14 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
9090
// Add a new OpTypeXXX instruction without checking for duplicates.
9191
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
9292
SPIRV::AccessQualifier::AccessQualifier AQ,
93-
bool EmitIR);
93+
bool ExplicitLayoutRequired, bool EmitIR);
9494
SPIRVType *findSPIRVType(const Type *Ty, MachineIRBuilder &MIRBuilder,
9595
SPIRV::AccessQualifier::AccessQualifier accessQual,
96-
bool EmitIR);
96+
bool ExplicitLayoutRequired, bool EmitIR);
9797
SPIRVType *
9898
restOfCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
9999
SPIRV::AccessQualifier::AccessQualifier AccessQual,
100-
bool EmitIR);
100+
bool ExplicitLayoutRequired, bool EmitIR);
101101

102102
// Internal function creating the an OpType at the correct position in the
103103
// function by tweaking the passed "MIRBuilder" insertion point and restoring
@@ -298,10 +298,19 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
298298
// EmitIR controls if we emit GMIR or SPV constants (e.g. for array sizes)
299299
// because this method may be called from InstructionSelector and we don't
300300
// want to emit extra IR instructions there.
301+
SPIRVType *getOrCreateSPIRVType(const Type *Type, MachineInstr &I,
302+
SPIRV::AccessQualifier::AccessQualifier AQ,
303+
bool EmitIR) {
304+
MachineIRBuilder MIRBuilder(I);
305+
return getOrCreateSPIRVType(Type, MIRBuilder, AQ, EmitIR);
306+
}
307+
301308
SPIRVType *getOrCreateSPIRVType(const Type *Type,
302309
MachineIRBuilder &MIRBuilder,
303310
SPIRV::AccessQualifier::AccessQualifier AQ,
304-
bool EmitIR);
311+
bool EmitIR) {
312+
return getOrCreateSPIRVType(Type, MIRBuilder, AQ, false, EmitIR);
313+
}
305314

306315
const Type *getTypeForSPIRVType(const SPIRVType *Ty) const {
307316
auto Res = SPIRVToLLVMType.find(Ty);
@@ -364,6 +373,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
364373
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
365374
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
366375

376+
// Returns true if `Type` is a resource type. This could be an image type
377+
// or a struct for a buffer decorated with the block decoration.
378+
bool isResourceType(SPIRVType *Type) const;
379+
367380
// Return number of elements in a vector if the argument is associated with
368381
// a vector type. Return 1 for a scalar type, and 0 for a missing type.
369382
unsigned getScalarOrVectorComponentCount(Register VReg) const;
@@ -414,6 +427,11 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
414427
const Type *adjustIntTypeByWidth(const Type *Ty) const;
415428
unsigned adjustOpTypeIntWidth(unsigned Width) const;
416429

430+
SPIRVType *getOrCreateSPIRVType(const Type *Type,
431+
MachineIRBuilder &MIRBuilder,
432+
SPIRV::AccessQualifier::AccessQualifier AQ,
433+
bool ExplicitLayoutRequired, bool EmitIR);
434+
417435
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
418436
bool IsSigned = false);
419437

@@ -425,14 +443,15 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
425443
MachineIRBuilder &MIRBuilder);
426444

427445
SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
428-
MachineIRBuilder &MIRBuilder, bool EmitIR);
446+
MachineIRBuilder &MIRBuilder,
447+
bool ExplicitLayoutRequired, bool EmitIR);
429448

430449
SPIRVType *getOpTypeOpaque(const StructType *Ty,
431450
MachineIRBuilder &MIRBuilder);
432451

433452
SPIRVType *getOpTypeStruct(const StructType *Ty, MachineIRBuilder &MIRBuilder,
434453
SPIRV::AccessQualifier::AccessQualifier AccQual,
435-
bool EmitIR);
454+
bool ExplicitLayoutRequired, bool EmitIR);
436455

437456
SPIRVType *getOpTypePointer(SPIRV::StorageClass::StorageClass SC,
438457
SPIRVType *ElemType, MachineIRBuilder &MIRBuilder,
@@ -475,6 +494,12 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
475494
MachineIRBuilder &MIRBuilder,
476495
SPIRV::StorageClass::StorageClass SC);
477496

497+
void addStructOffsetDecorations(Register Reg, StructType *Ty,
498+
MachineIRBuilder &MIRBuilder);
499+
void addArrayStrideDecorations(Register Reg, Type *ElementType,
500+
MachineIRBuilder &MIRBuilder);
501+
bool hasBlockDecoration(SPIRVType *Type) const;
502+
478503
public:
479504
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
480505
SPIRVType *SpvType, bool EmitIR,
@@ -545,9 +570,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
545570
SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
546571
unsigned NumElements, MachineInstr &I,
547572
const SPIRVInstrInfo &TII);
548-
SPIRVType *getOrCreateSPIRVArrayType(SPIRVType *BaseType,
549-
unsigned NumElements, MachineInstr &I,
550-
const SPIRVInstrInfo &TII);
551573

552574
// Returns a pointer to a SPIR-V pointer type with the given base type and
553575
// storage class. The base type will be translated to a SPIR-V type, and the

llvm/lib/Target/SPIRV/SPIRVIRMapping.h

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ enum SpecialTypeKind {
6666
STK_Value,
6767
STK_MachineInstr,
6868
STK_VkBuffer,
69+
STK_ExplictLayoutType,
6970
STK_Last = -1
7071
};
7172

@@ -150,6 +151,11 @@ inline IRHandle irhandle_vkbuffer(const Type *ElementType,
150151
SpecialTypeKind::STK_VkBuffer);
151152
}
152153

154+
inline IRHandle irhandle_explict_layout_type(const Type *Ty) {
155+
const Type *WrpTy = unifyPtrType(Ty);
156+
return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_ExplictLayoutType);
157+
}
158+
153159
inline IRHandle handle(const Type *Ty) {
154160
const Type *WrpTy = unifyPtrType(Ty);
155161
return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_Type);
@@ -163,6 +169,10 @@ inline IRHandle handle(const MachineInstr *KeyMI) {
163169
return irhandle_ptr(KeyMI, SPIRV::to_hash(KeyMI), STK_MachineInstr);
164170
}
165171

172+
inline bool type_has_layout_decoration(const Type *T) {
173+
return (isa<StructType>(T) || isa<ArrayType>(T));
174+
}
175+
166176
} // namespace SPIRV
167177

168178
// Bi-directional mappings between LLVM entities and (v-reg, machine function)
@@ -238,14 +248,49 @@ class SPIRVIRMapping {
238248
return findMI(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MF);
239249
}
240250

241-
template <typename T> bool add(const T *Obj, const MachineInstr *MI) {
251+
bool add(const Value *V, const MachineInstr *MI) {
252+
return add(SPIRV::handle(V), MI);
253+
}
254+
255+
bool add(const Type *T, bool RequiresExplicitLayout, const MachineInstr *MI) {
256+
if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) {
257+
return add(SPIRV::irhandle_explict_layout_type(T), MI);
258+
}
259+
return add(SPIRV::handle(T), MI);
260+
}
261+
262+
bool add(const MachineInstr *Obj, const MachineInstr *MI) {
242263
return add(SPIRV::handle(Obj), MI);
243264
}
244-
template <typename T> Register find(const T *Obj, const MachineFunction *MF) {
245-
return find(SPIRV::handle(Obj), MF);
265+
266+
Register find(const Value *V, const MachineFunction *MF) {
267+
return find(SPIRV::handle(V), MF);
268+
}
269+
270+
Register find(const Type *T, bool RequiresExplicitLayout,
271+
const MachineFunction *MF) {
272+
if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T))
273+
return find(SPIRV::irhandle_explict_layout_type(T), MF);
274+
return find(SPIRV::handle(T), MF);
275+
}
276+
277+
Register find(const MachineInstr *MI, const MachineFunction *MF) {
278+
return find(SPIRV::handle(MI), MF);
279+
}
280+
281+
const MachineInstr *findMI(const Value *Obj, const MachineFunction *MF) {
282+
return findMI(SPIRV::handle(Obj), MF);
283+
}
284+
285+
const MachineInstr *findMI(const Type *T, bool RequiresExplicitLayout,
286+
const MachineFunction *MF) {
287+
if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T))
288+
return findMI(SPIRV::irhandle_explict_layout_type(T), MF);
289+
return findMI(SPIRV::handle(T), MF);
246290
}
247-
template <typename T>
248-
const MachineInstr *findMI(const T *Obj, const MachineFunction *MF) {
291+
292+
const MachineInstr *findMI(const MachineInstr *Obj,
293+
const MachineFunction *MF) {
249294
return findMI(SPIRV::handle(Obj), MF);
250295
}
251296
};

llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,42 @@
2525

2626
using namespace llvm;
2727

28+
// Returns true of the types logically match, as defined in
29+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
30+
static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2,
31+
SPIRVGlobalRegistry &GR) {
32+
if (Ty1->getOpcode() != Ty2->getOpcode())
33+
return false;
34+
35+
if (Ty1->getNumOperands() != Ty2->getNumOperands())
36+
return false;
37+
38+
if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
39+
// Array must have the same size.
40+
if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg())
41+
return false;
42+
43+
SPIRVType *ElemType1 = GR.getSPIRVTypeForVReg(Ty1->getOperand(1).getReg());
44+
SPIRVType *ElemType2 = GR.getSPIRVTypeForVReg(Ty2->getOperand(1).getReg());
45+
return ElemType1 == ElemType2 ||
46+
typesLogicallyMatch(ElemType1, ElemType2, GR);
47+
}
48+
49+
if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
50+
for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
51+
SPIRVType *ElemType1 =
52+
GR.getSPIRVTypeForVReg(Ty1->getOperand(I).getReg());
53+
SPIRVType *ElemType2 =
54+
GR.getSPIRVTypeForVReg(Ty2->getOperand(I).getReg());
55+
if (ElemType1 != ElemType2 &&
56+
!typesLogicallyMatch(ElemType1, ElemType2, GR))
57+
return false;
58+
}
59+
return true;
60+
}
61+
return false;
62+
}
63+
2864
unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
2965
LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
3066
// This code avoids CallLowering fail inside getVectorTypeBreakdown
@@ -374,6 +410,9 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
374410
// implies that %Op is a pointer to <ResType>
375411
case SPIRV::OpLoad:
376412
// OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
413+
if (enforcePtrTypeCompatibility(MI, 2, 0))
414+
break;
415+
377416
validatePtrTypes(STI, MRI, GR, MI, 2,
378417
GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
379418
break;
@@ -531,3 +570,58 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
531570
ProcessedMF.insert(&MF);
532571
TargetLowering::finalizeLowering(MF);
533572
}
573+
574+
// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
575+
// PtrOpIdx matches the type for operand OpIdx. Returns true if they already
576+
// match or if the instruction was modified to make them match.
577+
bool SPIRVTargetLowering::enforcePtrTypeCompatibility(
578+
MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
579+
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
580+
SPIRVType *PtrType = GR.getResultType(I.getOperand(PtrOpIdx).getReg());
581+
SPIRVType *PointeeType = GR.getPointeeType(PtrType);
582+
SPIRVType *OpType = GR.getResultType(I.getOperand(OpIdx).getReg());
583+
584+
if (PointeeType == OpType)
585+
return true;
586+
587+
if (typesLogicallyMatch(PointeeType, OpType, GR)) {
588+
// Apply OpCopyLogical to OpIdx.
589+
if (I.getOperand(OpIdx).isDef() &&
590+
insertLogicalCopyOnResult(I, PointeeType)) {
591+
return true;
592+
}
593+
594+
llvm_unreachable("Unable to add OpCopyLogical yet.");
595+
return false;
596+
}
597+
598+
return false;
599+
}
600+
601+
bool SPIRVTargetLowering::insertLogicalCopyOnResult(
602+
MachineInstr &I, SPIRVType *NewResultType) const {
603+
MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
604+
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
605+
606+
Register NewResultReg =
607+
createVirtualRegister(NewResultType, &GR, MRI, *I.getMF());
608+
Register NewTypeReg = GR.getSPIRVTypeID(NewResultType);
609+
610+
assert(std::distance(I.defs().begin(), I.defs().end()) == 1 &&
611+
"Expected only one def");
612+
MachineOperand &OldResult = *I.defs().begin();
613+
Register OldResultReg = OldResult.getReg();
614+
MachineOperand &OldType = *I.uses().begin();
615+
Register OldTypeReg = OldType.getReg();
616+
617+
OldResult.setReg(NewResultReg);
618+
OldType.setReg(NewTypeReg);
619+
620+
MachineIRBuilder MIB(*I.getNextNode());
621+
return MIB.buildInstr(SPIRV::OpCopyLogical)
622+
.addDef(OldResultReg)
623+
.addUse(OldTypeReg)
624+
.addUse(NewResultReg)
625+
.constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
626+
*STI.getRegBankInfo());
627+
}

llvm/lib/Target/SPIRV/SPIRVISelLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ class SPIRVTargetLowering : public TargetLowering {
7171
EVT ConditionVT) const override {
7272
return ConditionVT.getSimpleVT();
7373
}
74+
75+
bool enforcePtrTypeCompatibility(MachineInstr &I, unsigned PtrOpIdx,
76+
unsigned OpIdx) const;
77+
bool insertLogicalCopyOnResult(MachineInstr &I,
78+
SPIRVType *NewResultType) const;
7479
};
7580
} // namespace llvm
7681

llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,18 @@ declare target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handle
1111

1212
; CHECK: OpDecorate [[BufferVar:%.+]] DescriptorSet 0
1313
; CHECK: OpDecorate [[BufferVar]] Binding 0
14-
; CHECK: OpDecorate [[BufferType:%.+]] Block
15-
; CHECK: OpMemberDecorate [[BufferType]] 0 Offset 0
14+
; CHECK: OpMemberDecorate [[BufferType:%.+]] 0 Offset 0
15+
; CHECK: OpDecorate [[BufferType]] Block
1616
; CHECK: OpMemberDecorate [[BufferType]] 0 NonWritable
1717
; CHECK: OpDecorate [[RWBufferVar:%.+]] DescriptorSet 0
1818
; CHECK: OpDecorate [[RWBufferVar]] Binding 1
19-
; CHECK: OpDecorate [[RWBufferType:%.+]] Block
20-
; CHECK: OpMemberDecorate [[RWBufferType]] 0 Offset 0
19+
; CHECK: OpDecorate [[ArrayType:%.+]] ArrayStride 4
20+
; CHECK: OpMemberDecorate [[RWBufferType:%.+]] 0 Offset 0
21+
; CHECK: OpDecorate [[RWBufferType]] Block
2122

2223

2324
; CHECK: [[int:%[0-9]+]] = OpTypeInt 32 0
24-
; CHECK: [[ArrayType:%.+]] = OpTypeRuntimeArray
25+
; CHECK: [[ArrayType]] = OpTypeRuntimeArray
2526
; CHECK: [[RWBufferType]] = OpTypeStruct [[ArrayType]]
2627
; CHECK: [[RWBufferPtrType:%.+]] = OpTypePointer StorageBuffer [[RWBufferType]]
2728
; CHECK: [[BufferType]] = OpTypeStruct [[ArrayType]]

0 commit comments

Comments
 (0)