Skip to content

Commit 44cfb6b

Browse files
[SPIR-V] Ensure that a correct pointer type is deduced from the Value argument of OpAtomic* instructions (llvm#127492)
This PR improves the set of rules for type inference by ensuring that a correct pointer type is deduced from the Value argument of OpAtomic* instructions, also when a pointer argument is coming from an `inttoptr .. to` instruction that caused problems earlier. Existing test cases are updated accordingly. This fixes llvm#127491
1 parent a377cdd commit 44cfb6b

File tree

3 files changed

+129
-43
lines changed

3 files changed

+129
-43
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class SPIRVEmitIntrinsics
135135

136136
// deduce Types of operands of the Instruction if possible
137137
void deduceOperandElementType(Instruction *I,
138-
SmallPtrSet<Instruction *, 4> *UncompleteRets,
138+
SmallPtrSet<Instruction *, 4> *IncompleteRets,
139139
const SmallPtrSet<Value *, 4> *AskOps = nullptr,
140140
bool IsPostprocessing = false);
141141

@@ -182,12 +182,12 @@ class SPIRVEmitIntrinsics
182182

183183
bool deduceOperandElementTypeCalledFunction(
184184
CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
185-
Type *&KnownElemTy);
185+
Type *&KnownElemTy, bool &Incomplete);
186186
void deduceOperandElementTypeFunctionPointer(
187187
CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
188188
Type *&KnownElemTy, bool IsPostprocessing);
189189
bool deduceOperandElementTypeFunctionRet(
190-
Instruction *I, SmallPtrSet<Instruction *, 4> *UncompleteRets,
190+
Instruction *I, SmallPtrSet<Instruction *, 4> *IncompleteRets,
191191
const SmallPtrSet<Value *, 4> *AskOps, bool IsPostprocessing,
192192
Type *&KnownElemTy, Value *Op, Function *F);
193193

@@ -893,7 +893,7 @@ static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I,
893893
// indirect function invocation, and true otherwise.
894894
bool SPIRVEmitIntrinsics::deduceOperandElementTypeCalledFunction(
895895
CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
896-
Type *&KnownElemTy) {
896+
Type *&KnownElemTy, bool &Incomplete) {
897897
Function *CalledF = CI->getCalledFunction();
898898
if (!CalledF)
899899
return false;
@@ -915,12 +915,15 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeCalledFunction(
915915
Ops.push_back(std::make_pair(Op, i));
916916
}
917917
} else if (Grp == SPIRV::Atomic || Grp == SPIRV::AtomicFloating) {
918-
if (CI->arg_size() < 2)
918+
if (CI->arg_size() == 0)
919919
return true;
920920
Value *Op = CI->getArgOperand(0);
921921
if (!isPointerTy(Op->getType()))
922922
return true;
923923
switch (Opcode) {
924+
case SPIRV::OpAtomicFAddEXT:
925+
case SPIRV::OpAtomicFMinEXT:
926+
case SPIRV::OpAtomicFMaxEXT:
924927
case SPIRV::OpAtomicLoad:
925928
case SPIRV::OpAtomicCompareExchangeWeak:
926929
case SPIRV::OpAtomicCompareExchange:
@@ -934,9 +937,23 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeCalledFunction(
934937
case SPIRV::OpAtomicUMax:
935938
case SPIRV::OpAtomicSMin:
936939
case SPIRV::OpAtomicSMax: {
937-
KnownElemTy = getAtomicElemTy(GR, CI, Op);
940+
KnownElemTy = isPointerTy(CI->getType()) ? getAtomicElemTy(GR, CI, Op)
941+
: CI->getType();
938942
if (!KnownElemTy)
939943
return true;
944+
Incomplete = isTodoType(Op);
945+
Ops.push_back(std::make_pair(Op, 0));
946+
} break;
947+
case SPIRV::OpAtomicStore: {
948+
if (CI->arg_size() < 4)
949+
return true;
950+
Value *ValOp = CI->getArgOperand(3);
951+
KnownElemTy = isPointerTy(ValOp->getType())
952+
? getAtomicElemTy(GR, CI, Op)
953+
: ValOp->getType();
954+
if (!KnownElemTy)
955+
return true;
956+
Incomplete = isTodoType(Op);
940957
Ops.push_back(std::make_pair(Op, 0));
941958
} break;
942959
}
@@ -954,7 +971,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
954971
return;
955972
Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
956973
FunctionType *FTy = CI->getFunctionType();
957-
bool IsNewFTy = false, IsUncomplete = false;
974+
bool IsNewFTy = false, IsIncomplete = false;
958975
SmallVector<Type *, 4> ArgTys;
959976
for (Value *Arg : CI->args()) {
960977
Type *ArgTy = Arg->getType();
@@ -963,9 +980,9 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
963980
IsNewFTy = true;
964981
ArgTy = getTypedPointerWrapper(ElemTy, getPointerAddressSpace(ArgTy));
965982
if (isTodoType(Arg))
966-
IsUncomplete = true;
983+
IsIncomplete = true;
967984
} else {
968-
IsUncomplete = true;
985+
IsIncomplete = true;
969986
}
970987
}
971988
ArgTys.push_back(ArgTy);
@@ -977,19 +994,19 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
977994
RetTy =
978995
getTypedPointerWrapper(ElemTy, getPointerAddressSpace(CI->getType()));
979996
if (isTodoType(CI))
980-
IsUncomplete = true;
997+
IsIncomplete = true;
981998
} else {
982-
IsUncomplete = true;
999+
IsIncomplete = true;
9831000
}
9841001
}
985-
if (!IsPostprocessing && IsUncomplete)
1002+
if (!IsPostprocessing && IsIncomplete)
9861003
insertTodoType(Op);
9871004
KnownElemTy =
9881005
IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
9891006
}
9901007

9911008
bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
992-
Instruction *I, SmallPtrSet<Instruction *, 4> *UncompleteRets,
1009+
Instruction *I, SmallPtrSet<Instruction *, 4> *IncompleteRets,
9931010
const SmallPtrSet<Value *, 4> *AskOps, bool IsPostprocessing,
9941011
Type *&KnownElemTy, Value *Op, Function *F) {
9951012
KnownElemTy = GR->findDeducedElementType(F);
@@ -1018,13 +1035,13 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
10181035
// This may happen just once per a function, the latch is a pair of
10191036
// findDeducedElementType(F) / addDeducedElementType(F, ...).
10201037
// With or without the latch it is a non-recursive call due to
1021-
// UncompleteRets set to nullptr in this call.
1022-
if (UncompleteRets)
1023-
for (Instruction *UncompleteRetI : *UncompleteRets)
1024-
deduceOperandElementType(UncompleteRetI, nullptr, AskOps,
1038+
// IncompleteRets set to nullptr in this call.
1039+
if (IncompleteRets)
1040+
for (Instruction *IncompleteRetI : *IncompleteRets)
1041+
deduceOperandElementType(IncompleteRetI, nullptr, AskOps,
10251042
IsPostprocessing);
1026-
} else if (UncompleteRets) {
1027-
UncompleteRets->insert(I);
1043+
} else if (IncompleteRets) {
1044+
IncompleteRets->insert(I);
10281045
}
10291046
TypeValidated.insert(I);
10301047
return true;
@@ -1035,17 +1052,17 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
10351052
// types which differ from expected, this function tries to insert a bitcast to
10361053
// resolve the issue.
10371054
void SPIRVEmitIntrinsics::deduceOperandElementType(
1038-
Instruction *I, SmallPtrSet<Instruction *, 4> *UncompleteRets,
1055+
Instruction *I, SmallPtrSet<Instruction *, 4> *IncompleteRets,
10391056
const SmallPtrSet<Value *, 4> *AskOps, bool IsPostprocessing) {
10401057
SmallVector<std::pair<Value *, unsigned>> Ops;
10411058
Type *KnownElemTy = nullptr;
1042-
bool Uncomplete = false;
1059+
bool Incomplete = false;
10431060
// look for known basic patterns of type inference
10441061
if (auto *Ref = dyn_cast<PHINode>(I)) {
10451062
if (!isPointerTy(I->getType()) ||
10461063
!(KnownElemTy = GR->findDeducedElementType(I)))
10471064
return;
1048-
Uncomplete = isTodoType(I);
1065+
Incomplete = isTodoType(I);
10491066
for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
10501067
Value *Op = Ref->getIncomingValue(i);
10511068
if (isPointerTy(Op->getType()))
@@ -1055,15 +1072,15 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
10551072
KnownElemTy = GR->findDeducedElementType(I);
10561073
if (!KnownElemTy)
10571074
return;
1058-
Uncomplete = isTodoType(I);
1075+
Incomplete = isTodoType(I);
10591076
Ops.push_back(std::make_pair(Ref->getPointerOperand(), 0));
10601077
} else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
10611078
if (!isPointerTy(I->getType()))
10621079
return;
10631080
KnownElemTy = GR->findDeducedElementType(I);
10641081
if (!KnownElemTy)
10651082
return;
1066-
Uncomplete = isTodoType(I);
1083+
Incomplete = isTodoType(I);
10671084
Ops.push_back(std::make_pair(Ref->getOperand(0), 0));
10681085
} else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
10691086
if (GR->findDeducedElementType(Ref->getPointerOperand()))
@@ -1090,22 +1107,28 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
10901107
Ops.push_back(std::make_pair(Ref->getPointerOperand(),
10911108
StoreInst::getPointerOperandIndex()));
10921109
} else if (auto *Ref = dyn_cast<AtomicCmpXchgInst>(I)) {
1093-
KnownElemTy = getAtomicElemTy(GR, I, Ref->getPointerOperand());
1110+
KnownElemTy = isPointerTy(I->getType())
1111+
? getAtomicElemTy(GR, I, Ref->getPointerOperand())
1112+
: I->getType();
10941113
if (!KnownElemTy)
10951114
return;
1115+
Incomplete = isTodoType(Ref->getPointerOperand());
10961116
Ops.push_back(std::make_pair(Ref->getPointerOperand(),
10971117
AtomicCmpXchgInst::getPointerOperandIndex()));
10981118
} else if (auto *Ref = dyn_cast<AtomicRMWInst>(I)) {
1099-
KnownElemTy = getAtomicElemTy(GR, I, Ref->getPointerOperand());
1119+
KnownElemTy = isPointerTy(I->getType())
1120+
? getAtomicElemTy(GR, I, Ref->getPointerOperand())
1121+
: I->getType();
11001122
if (!KnownElemTy)
11011123
return;
1124+
Incomplete = isTodoType(Ref->getPointerOperand());
11021125
Ops.push_back(std::make_pair(Ref->getPointerOperand(),
11031126
AtomicRMWInst::getPointerOperandIndex()));
11041127
} else if (auto *Ref = dyn_cast<SelectInst>(I)) {
11051128
if (!isPointerTy(I->getType()) ||
11061129
!(KnownElemTy = GR->findDeducedElementType(I)))
11071130
return;
1108-
Uncomplete = isTodoType(I);
1131+
Incomplete = isTodoType(I);
11091132
for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
11101133
Value *Op = Ref->getOperand(i);
11111134
if (isPointerTy(Op->getType()))
@@ -1117,11 +1140,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
11171140
Value *Op = Ref->getReturnValue();
11181141
if (!Op)
11191142
return;
1120-
if (deduceOperandElementTypeFunctionRet(I, UncompleteRets, AskOps,
1143+
if (deduceOperandElementTypeFunctionRet(I, IncompleteRets, AskOps,
11211144
IsPostprocessing, KnownElemTy, Op,
11221145
CurrF))
11231146
return;
1124-
Uncomplete = isTodoType(CurrF);
1147+
Incomplete = isTodoType(CurrF);
11251148
Ops.push_back(std::make_pair(Op, 0));
11261149
} else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
11271150
if (!isPointerTy(Ref->getOperand(0)->getType()))
@@ -1132,16 +1155,16 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
11321155
Type *ElemTy1 = GR->findDeducedElementType(Op1);
11331156
if (ElemTy0) {
11341157
KnownElemTy = ElemTy0;
1135-
Uncomplete = isTodoType(Op0);
1158+
Incomplete = isTodoType(Op0);
11361159
Ops.push_back(std::make_pair(Op1, 1));
11371160
} else if (ElemTy1) {
11381161
KnownElemTy = ElemTy1;
1139-
Uncomplete = isTodoType(Op1);
1162+
Incomplete = isTodoType(Op1);
11401163
Ops.push_back(std::make_pair(Op0, 0));
11411164
}
11421165
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
11431166
if (!CI->isIndirectCall())
1144-
deduceOperandElementTypeCalledFunction(CI, Ops, KnownElemTy);
1167+
deduceOperandElementTypeCalledFunction(CI, Ops, KnownElemTy, Incomplete);
11451168
else if (HaveFunPtrs)
11461169
deduceOperandElementTypeFunctionPointer(CI, Ops, KnownElemTy,
11471170
IsPostprocessing);
@@ -1175,7 +1198,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
11751198
Type *PrevElemTy = GR->findDeducedElementType(Op);
11761199
GR->addDeducedElementType(Op, normalizeType(KnownElemTy));
11771200
// check if KnownElemTy is complete
1178-
if (!Uncomplete)
1201+
if (!Incomplete)
11791202
eraseTodoType(Op);
11801203
else if (!IsPostprocessing)
11811204
insertTodoType(Op);
@@ -2394,9 +2417,9 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
23942417

23952418
// Pass backward: use instructions results to specify/update/cast operands
23962419
// where needed.
2397-
SmallPtrSet<Instruction *, 4> UncompleteRets;
2420+
SmallPtrSet<Instruction *, 4> IncompleteRets;
23982421
for (auto &I : llvm::reverse(instructions(Func)))
2399-
deduceOperandElementType(&I, &UncompleteRets);
2422+
deduceOperandElementType(&I, &IncompleteRets);
24002423

24012424
// Pass forward for PHIs only, their operands are not preceed the instruction
24022425
// in meaning of `instructions(Func)`.
@@ -2465,15 +2488,15 @@ bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
24652488

24662489
for (auto &F : M) {
24672490
CurrF = &F;
2468-
SmallPtrSet<Instruction *, 4> UncompleteRets;
2491+
SmallPtrSet<Instruction *, 4> IncompleteRets;
24692492
for (auto &I : llvm::reverse(instructions(F))) {
24702493
auto It = ToProcess.find(&I);
24712494
if (It == ToProcess.end())
24722495
continue;
24732496
It->second.remove_if([this](Value *V) { return !isTodoType(V); });
24742497
if (It->second.size() == 0)
24752498
continue;
2476-
deduceOperandElementType(&I, &UncompleteRets, &It->second, true);
2499+
deduceOperandElementType(&I, &IncompleteRets, &It->second, true);
24772500
if (TodoTypeSz == 0)
24782501
return true;
24792502
}

llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
22

33
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add %s -o - | FileCheck %s
4+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add %s -o - -filetype=obj | spirv-val %}
5+
6+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add %s -o - | FileCheck %s
7+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add %s -o - -filetype=obj | spirv-val %}
48

59
; CHECK-ERROR: LLVM ERROR: The atomic float instruction requires the following SPIR-V extension: SPV_EXT_shader_atomic_float_add
610

@@ -25,9 +29,6 @@
2529
; CHECK: %[[Neg42:[0-9]+]] = OpFNegate %[[TyFP32]] %[[Const42]]
2630
; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeWorkgroup]] %[[WorkgroupMemory]] %[[Neg42]]
2731

28-
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
29-
target triple = "spir64"
30-
3132
@f = common dso_local local_unnamed_addr addrspace(1) global float 0.000000e+00, align 8
3233

3334
define dso_local spir_func void @test1() local_unnamed_addr {
@@ -55,5 +56,31 @@ entry:
5556
declare spir_func float @_Z25atomic_fetch_add_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1), float, i32)
5657
declare spir_func float @_Z25atomic_fetch_sub_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1), float, i32)
5758

59+
; CHECK: %[[#Ptr1:]] = OpConvertUToPtr %[[TyFP32Ptr]] %[[#]]
60+
; CHECK: %[[#]] = OpAtomicFAddEXT %[[TyFP32]] %[[#Ptr1]] %[[#]] %[[#]] %[[#]]
61+
; CHECK: %[[#Ptr2:]] = OpConvertUToPtr %[[TyFP32Ptr]] %[[#]]
62+
; CHECK: %[[#]] = OpAtomicFAddEXT %[[TyFP32]] %[[#Ptr2]] %[[#]] %[[#]] %[[#]]
63+
; CHECK: %[[#Ptr3:]] = OpConvertUToPtr %[[TyFP32Ptr]] %[[#]]
64+
; CHECK: %[[#]] = OpAtomicFAddEXT %[[TyFP32]] %[[#Ptr3]] %[[#]] %[[#]] %[[#]]
65+
; CHECK: %[[#Ptr4:]] = OpConvertUToPtr %[[TyFP32Ptr]] %[[#]]
66+
; CHECK: %[[#]] = OpAtomicFAddEXT %[[TyFP32]] %[[#Ptr4]] %[[#]] %[[#]] %[[#]]
67+
; CHECK: %[[#Ptr5:]] = OpConvertUToPtr %[[TyFP32Ptr]] %[[#]]
68+
; CHECK: %[[#]] = OpAtomicFAddEXT %[[TyFP32]] %[[#Ptr5]] %[[#]] %[[#]] %[[#]]
69+
70+
define dso_local spir_func void @test4(i64 noundef %arg, float %val) local_unnamed_addr {
71+
entry:
72+
%ptr1 = inttoptr i64 %arg to float addrspace(1)*
73+
%v1 = atomicrmw fadd ptr addrspace(1) %ptr1, float %val seq_cst, align 4
74+
%ptr2 = inttoptr i64 %arg to float addrspace(1)*
75+
%v2 = atomicrmw fsub ptr addrspace(1) %ptr2, float %val seq_cst, align 4
76+
%ptr3 = inttoptr i64 %arg to float addrspace(1)*
77+
%v3 = tail call spir_func float @_Z21__spirv_AtomicFAddEXT(ptr addrspace(1) %ptr3, i32 1, i32 16, float %val)
78+
%ptr4 = inttoptr i64 %arg to float addrspace(1)*
79+
%v4 = tail call spir_func float @_Z25atomic_fetch_add_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1) %ptr4, float %val, i32 0)
80+
%ptr5 = inttoptr i64 %arg to float addrspace(1)*
81+
%v5 = tail call spir_func float @_Z25atomic_fetch_sub_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1) %ptr5, float %val, i32 0)
82+
ret void
83+
}
84+
5885
!llvm.module.flags = !{!0}
5986
!0 = !{i32 1, !"wchar_size", i32 4}

0 commit comments

Comments
 (0)