Skip to content

Commit c78bbd0

Browse files
AlexeySotkin authored and vladimirlaz committed
Translate LLVM's cmpxchg instruction to SPIR-V
Signed-off-by: Alexey Sotkin <[email protected]>
1 parent f8cb9a5 commit c78bbd0

File tree

2 files changed

+114
-6
lines changed

2 files changed

+114
-6
lines changed

llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include "llvm/Support/Debug.h"
5050

5151
#include <set>
52+
#include <vector>
5253

5354
using namespace llvm;
5455
using namespace SPIRV;
@@ -112,17 +113,18 @@ bool SPIRVRegularizeLLVM::regularize() {
112113
continue;
113114
}
114115

115-
for (auto BI = F->begin(), BE = F->end(); BI != BE; ++BI) {
116-
for (auto II = BI->begin(), IE = BI->end(); II != IE; ++II) {
117-
if (auto Call = dyn_cast<CallInst>(II)) {
116+
std::vector<Instruction *> ToErase;
117+
for (BasicBlock &BB : *F) {
118+
for (Instruction &II : BB) {
119+
if (auto Call = dyn_cast<CallInst>(&II)) {
118120
Call->setTailCall(false);
119121
Function *CF = Call->getCalledFunction();
120122
if (CF && CF->isIntrinsic())
121123
removeFnAttr(Call, Attribute::NoUnwind);
122124
}
123125

124126
// Remove optimization info not supported by SPIRV
125-
if (auto BO = dyn_cast<BinaryOperator>(II)) {
127+
if (auto BO = dyn_cast<BinaryOperator>(&II)) {
126128
if (isa<PossiblyExactOperator>(BO) && BO->isExact())
127129
BO->setIsExact(false);
128130
}
@@ -133,12 +135,68 @@ bool SPIRVRegularizeLLVM::regularize() {
133135
"range",
134136
};
135137
for (auto &MDName : MDs) {
136-
if (II->getMetadata(MDName)) {
137-
II->setMetadata(MDName, nullptr);
138+
if (II.getMetadata(MDName)) {
139+
II.setMetadata(MDName, nullptr);
138140
}
139141
}
142+
if (auto Cmpxchg = dyn_cast<AtomicCmpXchgInst>(&II)) {
143+
Value *Ptr = Cmpxchg->getPointerOperand();
144+
// To get memory scope argument we might use Cmpxchg->getSyncScopeID()
145+
// but LLVM's cmpxchg instruction is not aware of OpenCL (or SPIR-V)
146+
// memory scope enumeration. And assuming the produced SPIR-V module
147+
// will be consumed in an OpenCL environment, we can use the same
148+
// memory scope as OpenCL atomic functions that do not have
149+
// memory_scope argument, i.e. memory_scope_device. See the OpenCL C
150+
// specification p6.13.11. Atomic Functions
151+
Value *MemoryScope = getInt32(M, spv::ScopeDevice);
152+
auto SuccessOrder = static_cast<OCLMemOrderKind>(
153+
llvm::toCABI(Cmpxchg->getSuccessOrdering()));
154+
auto FailureOrder = static_cast<OCLMemOrderKind>(
155+
llvm::toCABI(Cmpxchg->getFailureOrdering()));
156+
Value *EqualSem = getInt32(M, OCLMemOrderMap::map(SuccessOrder));
157+
Value *UnequalSem = getInt32(M, OCLMemOrderMap::map(FailureOrder));
158+
Value *Val = Cmpxchg->getNewValOperand();
159+
Value *Comparator = Cmpxchg->getCompareOperand();
160+
161+
llvm::Value *Args[] = {Ptr, MemoryScope, EqualSem,
162+
UnequalSem, Val, Comparator};
163+
auto *Res = addCallInstSPIRV(M, "__spirv_AtomicCompareExchange",
164+
Cmpxchg->getCompareOperand()->getType(),
165+
Args, nullptr, &II, "cmpxchg.res");
166+
// cmpxchg LLVM instruction returns a pair: the original value and
167+
// a flag indicating success (true) or failure (false).
168+
// OpAtomicCompareExchange SPIR-V instruction returns only the
169+
// original value. So we replace all uses of the original value
170+
// extracted from the pair with the result of OpAtomicCompareExchange
171+
// instruction. And we replace all uses of the flag with result of an
172+
// OpIEqual instruction. The OpIEqual instruction returns true if the
173+
// original value equals the comparator, which matches the
174+
// semantics of cmpxchg.
175+
for (User *U : Cmpxchg->users()) {
176+
if (auto *Extract = dyn_cast<ExtractValueInst>(U)) {
177+
if (Extract->getIndices()[0] == 0) {
178+
Extract->replaceAllUsesWith(Res);
179+
} else if (Extract->getIndices()[0] == 1) {
180+
auto *Cmp = new ICmpInst(Extract, CmpInst::ICMP_EQ, Res,
181+
Comparator, "cmpxchg.success");
182+
Extract->replaceAllUsesWith(Cmp);
183+
} else {
184+
llvm_unreachable("Unxpected cmpxchg pattern");
185+
}
186+
assert(Extract->user_empty());
187+
Extract->dropAllReferences();
188+
ToErase.push_back(Extract);
189+
}
190+
}
191+
if (Cmpxchg->user_empty())
192+
ToErase.push_back(Cmpxchg);
193+
}
140194
}
141195
}
196+
for (Instruction *V : ToErase) {
197+
assert(V->user_empty());
198+
V->eraseFromParent();
199+
}
142200
}
143201

144202
std::string Err;
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o - | FileCheck %s --check-prefix=CHECK-SPIRV
4+
; RUN: spirv-val %t.spv
5+
6+
; CHECK-SPIRV: TypeInt [[Int:[0-9]+]] 32 0
7+
; CHECK-SPIRV: Constant [[Int]] [[MemScope_Device:[0-9]+]] 1
8+
; CHECK-SPIRV: Constant [[Int]] [[MemSemEqual_SeqCst:[0-9]+]] 16
9+
; CHECK-SPIRV: Constant [[Int]] [[MemSemUnequal_Acquire:[0-9]+]] 2
10+
11+
; CHECK-SPIRV: FunctionParameter {{[0-9]+}} [[Pointer:[0-9]+]]
12+
; CHECK-SPIRV: FunctionParameter {{[0-9]+}} [[Value_ptr:[0-9]+]]
13+
; CHECK-SPIRV: FunctionParameter {{[0-9]+}} [[Comparator:[0-9]+]]
14+
15+
; CHECK-SPIRV: Load [[Int]] [[Value:[0-9]+]] [[Value_ptr]]
16+
; CHECK-SPIRV: AtomicCompareExchange [[Int]] [[Res:[0-9]+]] [[Pointer]] [[MemScope_Device]]
17+
; CHECK-SPIRV-SAME: [[MemSemEqual_SeqCst]] [[MemSemUnequal_Acquire]] [[Value]] [[Comparator]]
18+
; CHECK-SPIRV: IEqual {{[0-9]+}} [[Success:[0-9]+]] [[Res]] [[Comparator]]
19+
; CHECK-SPIRV: BranchConditional [[Success]]
20+
21+
; CHECK-SPIRV: Store [[Value_ptr]] [[Res]]
22+
23+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
24+
target triple = "spir"
25+
26+
; Function Attrs: nounwind
27+
define dso_local spir_func void @test(i32* %ptr, i32* %value_ptr, i32 %comparator) local_unnamed_addr #0 {
28+
entry:
29+
%0 = load i32, i32* %value_ptr, align 4
30+
%1 = cmpxchg i32* %ptr, i32 %comparator, i32 %0 seq_cst acquire
31+
%2 = extractvalue { i32, i1 } %1, 1
32+
br i1 %2, label %cmpxchg.continue, label %cmpxchg.store_expected
33+
34+
cmpxchg.store_expected: ; preds = %entry
35+
%3 = extractvalue { i32, i1 } %1, 0
36+
store i32 %3, i32* %value_ptr, align 4
37+
br label %cmpxchg.continue
38+
39+
cmpxchg.continue: ; preds = %cmpxchg.store_expected, %entry
40+
ret void
41+
}
42+
43+
attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
44+
attributes #1 = { nounwind }
45+
46+
!llvm.module.flags = !{!0}
47+
!llvm.ident = !{!1}
48+
49+
!0 = !{i32 1, !"wchar_size", i32 4}
50+
!1 = !{!"clang version 11.0.0 (https://github.com/llvm/llvm-project.git cfebd7774229885e7ec88ae9ef1f4ae819cce1d2)"}

0 commit comments

Comments
 (0)