Skip to content

Commit 6cce67a

Browse files
[SPIR-V] Fix validity of atomic instructions (#87051)
This PR fixes the validity of atomic instructions and improves type inference. More tests are now accepted by `spirv-val`.
1 parent 77e5c0a commit 6cce67a

File tree

9 files changed

+252
-37
lines changed

9 files changed

+252
-37
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,13 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
251251
cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
252252
}
253253

254+
// Replace PointerType with TypedPointerType to be able to map SPIR-V types to
255+
// LLVM types in a consistent manner.
256+
if (isUntypedPointerTy(OriginalArgType)) {
257+
OriginalArgType =
258+
TypedPointerType::get(Type::getInt8Ty(F.getContext()),
259+
getPointerAddressSpace(OriginalArgType));
260+
}
254261
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
255262
}
256263

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ class SPIRVEmitIntrinsics
6565
Type *deduceElementType(Value *I);
6666
Type *deduceElementTypeHelper(Value *I);
6767
Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited);
68+
Type *deduceElementTypeByValueDeep(Type *ValueTy, Value *Operand,
69+
std::unordered_set<Value *> &Visited);
70+
Type *deduceElementTypeByUsersDeep(Value *Op,
71+
std::unordered_set<Value *> &Visited);
6872

6973
// deduce nested types of composites
7074
Type *deduceNestedTypeHelper(User *U);
@@ -176,6 +180,44 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
176180
false);
177181
}
178182

183+
// Sets the element pointer type to the given value of ValueTy and tries to
184+
// specify this type further (recursively) by Operand value, if needed.
185+
Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
186+
Type *ValueTy, Value *Operand, std::unordered_set<Value *> &Visited) {
187+
Type *Ty = ValueTy;
188+
if (Operand) {
189+
if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
190+
if (Type *NestedTy = deduceElementTypeHelper(Operand, Visited))
191+
Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
192+
} else {
193+
Ty = deduceNestedTypeHelper(dyn_cast<User>(Operand), Ty, Visited);
194+
}
195+
}
196+
return Ty;
197+
}
198+
199+
// Traverse User instructions to deduce an element pointer type of the operand.
200+
Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
201+
Value *Op, std::unordered_set<Value *> &Visited) {
202+
if (!Op || !isPointerTy(Op->getType()))
203+
return nullptr;
204+
205+
if (auto PType = dyn_cast<TypedPointerType>(Op->getType()))
206+
return PType->getElementType();
207+
208+
// maybe we already know operand's element type
209+
if (Type *KnownTy = GR->findDeducedElementType(Op))
210+
return KnownTy;
211+
212+
for (User *OpU : Op->users()) {
213+
if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
214+
if (Type *Ty = deduceElementTypeHelper(Inst, Visited))
215+
return Ty;
216+
}
217+
}
218+
return nullptr;
219+
}
220+
179221
// Deduce and return a successfully deduced Type of the Instruction,
180222
// or nullptr otherwise.
181223
Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
@@ -206,21 +248,27 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
206248
} else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
207249
Ty = Ref->getResultElementType();
208250
} else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
209-
Ty = Ref->getValueType();
210-
if (Value *Op = Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr) {
211-
if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
212-
if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
213-
Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
214-
} else {
215-
Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), Ty, Visited);
216-
}
217-
}
251+
Ty = deduceElementTypeByValueDeep(
252+
Ref->getValueType(),
253+
Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited);
218254
} else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
219255
Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited);
220256
} else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
221257
if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
222258
isPointerTy(Src) && isPointerTy(Dest))
223259
Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited);
260+
} else if (auto *Ref = dyn_cast<AtomicCmpXchgInst>(I)) {
261+
Value *Op = Ref->getNewValOperand();
262+
Ty = deduceElementTypeByValueDeep(Op->getType(), Op, Visited);
263+
} else if (auto *Ref = dyn_cast<AtomicRMWInst>(I)) {
264+
Value *Op = Ref->getValOperand();
265+
Ty = deduceElementTypeByValueDeep(Op->getType(), Op, Visited);
266+
} else if (auto *Ref = dyn_cast<PHINode>(I)) {
267+
for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
268+
Ty = deduceElementTypeByUsersDeep(Ref->getIncomingValue(i), Visited);
269+
if (Ty)
270+
break;
271+
}
224272
}
225273

226274
// remember the found relationship
@@ -293,6 +341,22 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
293341
return NewTy;
294342
}
295343
}
344+
} else if (auto *VecTy = dyn_cast<VectorType>(OrigTy)) {
345+
if (Value *Op = U->getNumOperands() > 0 ? U->getOperand(0) : nullptr) {
346+
Type *OpTy = VecTy->getElementType();
347+
Type *Ty = OpTy;
348+
if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
349+
if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
350+
Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
351+
} else {
352+
Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
353+
}
354+
if (Ty != OpTy) {
355+
Type *NewTy = VectorType::get(Ty, VecTy->getElementCount());
356+
GR->addDeducedCompositeType(U, NewTy);
357+
return NewTy;
358+
}
359+
}
296360
}
297361

298362
return OrigTy;
@@ -578,7 +642,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
578642

579643
// Handle calls to builtins (non-intrinsics):
580644
CallInst *CI = dyn_cast<CallInst>(I);
581-
if (!CI || CI->isIndirectCall() || CI->getCalledFunction()->isIntrinsic())
645+
if (!CI || CI->isIndirectCall() || CI->isInlineAsm() ||
646+
!CI->getCalledFunction() || CI->getCalledFunction()->isIntrinsic())
582647
return;
583648

584649
// collect information about formal parameter types
@@ -929,6 +994,10 @@ Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
929994
// maybe we already know operand's element type
930995
if (Type *KnownTy = GR->findDeducedElementType(OpArg))
931996
return KnownTy;
997+
// try to deduce from the operand itself
998+
Visited.clear();
999+
if (Type *Ty = deduceElementTypeHelper(OpArg, Visited))
1000+
return Ty;
9321001
// search in actual parameter's users
9331002
for (User *OpU : OpArg->users()) {
9341003
Instruction *Inst = dyn_cast<Instruction>(OpU);

llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,17 @@ void validateForwardCalls(const SPIRVSubtarget &STI,
201201
}
202202
}
203203

204+
// Validation of an access chain.
205+
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
206+
SPIRVGlobalRegistry &GR, MachineInstr &I) {
207+
SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
208+
if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
209+
SPIRVType *BaseElemType =
210+
GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
211+
validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
212+
}
213+
}
214+
204215
// TODO: the logic of inserting additional bitcasts is to be moved
205216
// to pre-IRTranslation passes eventually
206217
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
@@ -213,16 +224,47 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
213224
MBBI != MBBE;) {
214225
MachineInstr &MI = *MBBI++;
215226
switch (MI.getOpcode()) {
227+
case SPIRV::OpAtomicLoad:
228+
case SPIRV::OpAtomicExchange:
229+
case SPIRV::OpAtomicCompareExchange:
230+
case SPIRV::OpAtomicCompareExchangeWeak:
231+
case SPIRV::OpAtomicIIncrement:
232+
case SPIRV::OpAtomicIDecrement:
233+
case SPIRV::OpAtomicIAdd:
234+
case SPIRV::OpAtomicISub:
235+
case SPIRV::OpAtomicSMin:
236+
case SPIRV::OpAtomicUMin:
237+
case SPIRV::OpAtomicSMax:
238+
case SPIRV::OpAtomicUMax:
239+
case SPIRV::OpAtomicAnd:
240+
case SPIRV::OpAtomicOr:
241+
case SPIRV::OpAtomicXor:
242+
// for the above listed instructions
243+
// OpAtomicXXX <ResType>, ptr %Op, ...
244+
// implies that %Op is a pointer to <ResType>
216245
case SPIRV::OpLoad:
217246
// OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
218247
validatePtrTypes(STI, MRI, GR, MI, 2,
219248
GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
220249
break;
250+
case SPIRV::OpAtomicStore:
251+
// OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
252+
// implies that %Op points to the <Obj>'s type
253+
validatePtrTypes(STI, MRI, GR, MI, 0,
254+
GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
255+
break;
221256
case SPIRV::OpStore:
222257
// OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
223258
validatePtrTypes(STI, MRI, GR, MI, 0,
224259
GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
225260
break;
261+
case SPIRV::OpPtrCastToGeneric:
262+
validateAccessChain(STI, MRI, GR, MI);
263+
break;
264+
case SPIRV::OpInBoundsPtrAccessChain:
265+
if (MI.getNumOperands() == 4)
266+
validateAccessChain(STI, MRI, GR, MI);
267+
break;
226268

227269
case SPIRV::OpFunctionCall:
228270
// ensure there is no mismatch between actual and expected arg types:

llvm/test/CodeGen/SPIRV/ExecutionMode.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
23

34
; CHECK-DAG: %[[#VOID:]] = OpTypeVoid
45

llvm/test/CodeGen/SPIRV/instructions/atomic.ll

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
23

34
; CHECK-DAG: OpName [[ADD:%.*]] "test_add"
45
; CHECK-DAG: OpName [[SUB:%.*]] "test_sub"
@@ -20,7 +21,8 @@
2021
; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
2122
; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
2223
; CHECK-NEXT: OpLabel
23-
; CHECK-NEXT: [[R:%.*]] = OpAtomicIAdd [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
24+
; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
25+
; CHECK-NEXT: [[R:%.*]] = OpAtomicIAdd [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
2426
; CHECK-NEXT: OpReturnValue [[R]]
2527
; CHECK-NEXT: OpFunctionEnd
2628
define i32 @test_add(i32* %ptr, i32 %val) {
@@ -32,7 +34,8 @@ define i32 @test_add(i32* %ptr, i32 %val) {
3234
; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
3335
; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
3436
; CHECK-NEXT: OpLabel
35-
; CHECK-NEXT: [[R:%.*]] = OpAtomicISub [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
37+
; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
38+
; CHECK-NEXT: [[R:%.*]] = OpAtomicISub [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
3639
; CHECK-NEXT: OpReturnValue [[R]]
3740
; CHECK-NEXT: OpFunctionEnd
3841
define i32 @test_sub(i32* %ptr, i32 %val) {
@@ -44,7 +47,8 @@ define i32 @test_sub(i32* %ptr, i32 %val) {
4447
; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
4548
; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
4649
; CHECK-NEXT: OpLabel
47-
; CHECK-NEXT: [[R:%.*]] = OpAtomicSMin [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
50+
; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
51+
; CHECK-NEXT: [[R:%.*]] = OpAtomicSMin [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
4852
; CHECK-NEXT: OpReturnValue [[R]]
4953
; CHECK-NEXT: OpFunctionEnd
5054
define i32 @test_min(i32* %ptr, i32 %val) {
@@ -56,7 +60,8 @@ define i32 @test_min(i32* %ptr, i32 %val) {
5660
; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
5761
; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
5862
; CHECK-NEXT: OpLabel
59-
; CHECK-NEXT: [[R:%.*]] = OpAtomicSMax [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
63+
; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
64+
; CHECK-NEXT: [[R:%.*]] = OpAtomicSMax [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
6065
; CHECK-NEXT: OpReturnValue [[R]]
6166
; CHECK-NEXT: OpFunctionEnd
6267
define i32 @test_max(i32* %ptr, i32 %val) {
@@ -68,7 +73,8 @@ define i32 @test_max(i32* %ptr, i32 %val) {
6873
; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
6974
; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
7075
; CHECK-NEXT: OpLabel
71-
; CHECK-NEXT: [[R:%.*]] = OpAtomicUMin [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
76+
; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
77+
; CHECK-NEXT: [[R:%.*]] = OpAtomicUMin [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
7278
; CHECK-NEXT: OpReturnValue [[R]]
7379
; CHECK-NEXT: OpFunctionEnd
7480
define i32 @test_umin(i32* %ptr, i32 %val) {
@@ -80,7 +86,8 @@ define i32 @test_umin(i32* %ptr, i32 %val) {
8086
; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
8187
; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
8288
; CHECK-NEXT: OpLabel
83-
; CHECK-NEXT: [[R:%.*]] = OpAtomicUMax [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
89+
; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
90+
; CHECK-NEXT: [[R:%.*]] = OpAtomicUMax [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
8491
; CHECK-NEXT: OpReturnValue [[R]]
8592
; CHECK-NEXT: OpFunctionEnd
8693
define i32 @test_umax(i32* %ptr, i32 %val) {
@@ -92,7 +99,8 @@ define i32 @test_umax(i32* %ptr, i32 %val) {
9299
; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
93100
; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
94101
; CHECK-NEXT: OpLabel
95-
; CHECK-NEXT: [[R:%.*]] = OpAtomicAnd [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
102+
; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
103+
; CHECK-NEXT: [[R:%.*]] = OpAtomicAnd [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
96104
; CHECK-NEXT: OpReturnValue [[R]]
97105
; CHECK-NEXT: OpFunctionEnd
98106
define i32 @test_and(i32* %ptr, i32 %val) {
@@ -104,7 +112,8 @@ define i32 @test_and(i32* %ptr, i32 %val) {
104112
; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
105113
; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
106114
; CHECK-NEXT: OpLabel
107-
; CHECK-NEXT: [[R:%.*]] = OpAtomicOr [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
115+
; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
116+
; CHECK-NEXT: [[R:%.*]] = OpAtomicOr [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
108117
; CHECK-NEXT: OpReturnValue [[R]]
109118
; CHECK-NEXT: OpFunctionEnd
110119
define i32 @test_or(i32* %ptr, i32 %val) {
@@ -116,7 +125,8 @@ define i32 @test_or(i32* %ptr, i32 %val) {
116125
; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
117126
; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
118127
; CHECK-NEXT: OpLabel
119-
; CHECK-NEXT: [[R:%.*]] = OpAtomicXor [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
128+
; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
129+
; CHECK-NEXT: [[R:%.*]] = OpAtomicXor [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
120130
; CHECK-NEXT: OpReturnValue [[R]]
121131
; CHECK-NEXT: OpFunctionEnd
122132
define i32 @test_xor(i32* %ptr, i32 %val) {

0 commit comments

Comments
 (0)