Skip to content

Commit 9284a33

Browse files
[Mutator] Migrate atomics builtins to the mutator interface. (#1614)
The old mutateCallInst* functions and the new mutator doesn't quite get some of the value names correct, which necessitated some test changes that were looking for particular value names.
1 parent 139b08b commit 9284a33

File tree

7 files changed

+285
-452
lines changed

7 files changed

+285
-452
lines changed

lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 50 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -485,28 +485,23 @@ void OCLToSPIRVBase::visitCallAsyncWorkGroupCopy(CallInst *CI,
485485
}
486486

487487
CallInst *OCLToSPIRVBase::visitCallAtomicCmpXchg(CallInst *CI) {
488-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
489-
Value *Expected = nullptr;
490488
CallInst *NewCI = nullptr;
491-
mutateCallInstOCL(
492-
M, CI,
493-
[&](CallInst *CI, std::vector<Value *> &Args, Type *&RetTy) {
494-
Expected = Args[1]; // temporary save second argument.
495-
RetTy = Args[2]->getType();
496-
Args[1] = new LoadInst(RetTy, Args[1], "exp", false, CI);
497-
assert(Args[1]->getType()->isIntegerTy() &&
498-
Args[2]->getType()->isIntegerTy() &&
499-
"In SPIR-V 1.0 arguments of OpAtomicCompareExchange must be "
500-
"an integer type scalars");
501-
return kOCLBuiltinName::AtomicCmpXchgStrong;
502-
},
503-
[&](CallInst *NCI) -> Instruction * {
504-
NewCI = NCI;
505-
Instruction *Store = new StoreInst(NCI, Expected, NCI->getNextNode());
506-
return new ICmpInst(Store->getNextNode(), CmpInst::ICMP_EQ, NCI,
507-
NCI->getArgOperand(1));
508-
},
509-
&Attrs);
489+
{
490+
auto Mutator = mutateCallInst(CI, kOCLBuiltinName::AtomicCmpXchgStrong);
491+
Value *Expected = Mutator.getArg(1);
492+
Type *MemTy = Mutator.getArg(2)->getType();
493+
assert(MemTy->isIntegerTy() &&
494+
"In SPIR-V 1.0 arguments of OpAtomicCompareExchange must be "
495+
"an integer type scalars");
496+
Mutator.mapArg(1, [=](IRBuilder<> &Builder, Value *V) {
497+
return Builder.CreateLoad(MemTy, V, "exp");
498+
});
499+
Mutator.changeReturnType(MemTy, [&](IRBuilder<> &Builder, CallInst *NCI) {
500+
NewCI = NCI;
501+
Builder.CreateStore(NCI, Expected);
502+
return Builder.CreateICmpEQ(NCI, NCI->getArgOperand(1));
503+
});
504+
}
510505
return NewCI;
511506
}
512507

@@ -570,17 +565,10 @@ void OCLToSPIRVBase::visitCallMemFence(CallInst *CI, StringRef DemangledName) {
570565
void OCLToSPIRVBase::transMemoryBarrier(CallInst *CI,
571566
AtomicWorkItemFenceLiterals Lit) {
572567
assert(CI->getCalledFunction() && "Unexpected indirect call");
573-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
574-
mutateCallInstSPIRV(
575-
M, CI,
576-
[=](CallInst *, std::vector<Value *> &Args) {
577-
Args.resize(2);
578-
Args[0] = addInt32(map<Scope>(std::get<2>(Lit)));
579-
Args[1] = addInt32(
580-
mapOCLMemSemanticToSPIRV(std::get<0>(Lit), std::get<1>(Lit)));
581-
return getSPIRVFuncName(OpMemoryBarrier);
582-
},
583-
&Attrs);
568+
mutateCallInst(CI, OpMemoryBarrier)
569+
.setArgs({addInt32(map<Scope>(std::get<2>(Lit))),
570+
addInt32(mapOCLMemSemanticToSPIRV(std::get<0>(Lit),
571+
std::get<1>(Lit)))});
584572
}
585573

586574
void OCLToSPIRVBase::visitCallAtomicLegacy(CallInst *CI, StringRef MangledName,
@@ -747,25 +735,18 @@ void OCLToSPIRVBase::transAtomicBuiltin(CallInst *CI,
747735

748736
void OCLToSPIRVBase::visitCallBarrier(CallInst *CI) {
749737
auto Lit = getBarrierLiterals(CI);
750-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
751-
mutateCallInstSPIRV(
752-
M, CI,
753-
[=](CallInst *, std::vector<Value *> &Args) {
754-
Args.resize(3);
755-
// Execution scope
756-
Args[0] = addInt32(map<Scope>(std::get<2>(Lit)));
757-
// Memory scope
758-
Args[1] = addInt32(map<Scope>(std::get<1>(Lit)));
759-
// Use sequential consistent memory order by default.
760-
// But if the flags argument is set to 0, we use
761-
// None(Relaxed) memory order.
762-
unsigned MemFenceFlag = std::get<0>(Lit);
763-
OCLMemOrderKind MemOrder = MemFenceFlag ? OCLMO_seq_cst : OCLMO_relaxed;
764-
Args[2] = addInt32(mapOCLMemSemanticToSPIRV(
765-
MemFenceFlag, MemOrder)); // Memory semantics
766-
return getSPIRVFuncName(OpControlBarrier);
767-
},
768-
&Attrs);
738+
// Use sequential consistent memory order by default.
739+
// But if the flags argument is set to 0, we use
740+
// None(Relaxed) memory order.
741+
unsigned MemFenceFlag = std::get<0>(Lit);
742+
OCLMemOrderKind MemOrder = MemFenceFlag ? OCLMO_seq_cst : OCLMO_relaxed;
743+
mutateCallInst(CI, OpControlBarrier)
744+
.setArgs({// Execution scope
745+
addInt32(map<Scope>(std::get<2>(Lit))),
746+
// Memory scope
747+
addInt32(map<Scope>(std::get<1>(Lit))),
748+
// Memory semantics
749+
addInt32(mapOCLMemSemanticToSPIRV(MemFenceFlag, MemOrder))});
769750
}
770751

771752
void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
@@ -1300,19 +1281,12 @@ void OCLToSPIRVBase::visitCallVecLoadStore(CallInst *CI, StringRef MangledName,
13001281
}
13011282

13021283
void OCLToSPIRVBase::visitCallGetFence(CallInst *CI, StringRef DemangledName) {
1303-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
13041284
Op OC = OpNop;
13051285
OCLSPIRVBuiltinMap::find(DemangledName.str(), &OC);
1306-
std::string SPIRVName = getSPIRVFuncName(OC);
1307-
mutateCallInstSPIRV(
1308-
M, CI,
1309-
[=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
1310-
return SPIRVName;
1311-
},
1312-
[=](CallInst *NewCI) -> Instruction * {
1313-
return BinaryOperator::CreateLShr(NewCI, getInt32(M, 8), "", CI);
1314-
},
1315-
&Attrs);
1286+
mutateCallInst(CI, OC).changeReturnType(
1287+
CI->getType(), [](IRBuilder<> &Builder, CallInst *NewCI) {
1288+
return Builder.CreateLShr(NewCI, Builder.getInt32(8));
1289+
});
13161290
}
13171291

13181292
void OCLToSPIRVBase::visitCallDot(CallInst *CI) {
@@ -1877,32 +1851,26 @@ void OCLToSPIRVBase::visitSubgroupAVCBuiltinCallWithSampler(
18771851
void OCLToSPIRVBase::visitCallSplitBarrierINTEL(CallInst *CI,
18781852
StringRef DemangledName) {
18791853
auto Lit = getBarrierLiterals(CI);
1880-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
18811854
Op OpCode =
18821855
StringSwitch<Op>(DemangledName)
18831856
.Case("intel_work_group_barrier_arrive", OpControlBarrierArriveINTEL)
18841857
.Case("intel_work_group_barrier_wait", OpControlBarrierWaitINTEL)
18851858
.Default(OpNop);
18861859

1887-
mutateCallInstSPIRV(
1888-
M, CI,
1889-
[=](CallInst *, std::vector<Value *> &Args) {
1890-
Args.resize(3);
1891-
// Execution scope
1892-
Args[0] = addInt32(map<Scope>(std::get<2>(Lit)));
1893-
// Memory scope
1894-
Args[1] = addInt32(map<Scope>(std::get<1>(Lit)));
1895-
// Memory semantics
1896-
// OpControlBarrierArriveINTEL -> Release,
1897-
// OpControlBarrierWaitINTEL -> Acquire
1898-
unsigned MemFenceFlag = std::get<0>(Lit);
1899-
OCLMemOrderKind MemOrder = OpCode == OpControlBarrierArriveINTEL
1900-
? OCLMO_release
1901-
: OCLMO_acquire;
1902-
Args[2] = addInt32(mapOCLMemSemanticToSPIRV(MemFenceFlag, MemOrder));
1903-
return getSPIRVFuncName(OpCode);
1904-
},
1905-
&Attrs);
1860+
// Map memory semantics as follows:
1861+
// OpControlBarrierArriveINTEL -> Release,
1862+
// OpControlBarrierWaitINTEL -> Acquire
1863+
unsigned MemFenceFlag = std::get<0>(Lit);
1864+
OCLMemOrderKind MemOrder =
1865+
OpCode == OpControlBarrierArriveINTEL ? OCLMO_release : OCLMO_acquire;
1866+
mutateCallInst(CI, OpCode)
1867+
.removeArgs(0, CI->arg_size())
1868+
// Execution scope
1869+
.appendArg(addInt32(map<Scope>(std::get<2>(Lit))))
1870+
// Memory scope
1871+
.appendArg(addInt32(map<Scope>(std::get<1>(Lit))))
1872+
// Memory semantics
1873+
.appendArg(addInt32(mapOCLMemSemanticToSPIRV(MemFenceFlag, MemOrder)));
19061874
}
19071875

19081876
void OCLToSPIRVBase::visitCallLdexp(CallInst *CI, StringRef MangledName,

lib/SPIRV/SPIRVInternal.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,8 @@ void move(std::vector<T> &V, size_t Begin, size_t End, size_t Target) {
550550
}
551551

552552
/// Find position of first pointer type value in a vector.
553-
inline size_t findFirstPtr(const std::vector<Value *> &Args) {
553+
template <typename Container>
554+
inline unsigned findFirstPtr(const Container &Args) {
554555
auto PtArg = std::find_if(Args.begin(), Args.end(), [](Value *V) {
555556
return V->getType()->isPointerTy();
556557
});

lib/SPIRV/SPIRVToOCL.h

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -203,17 +203,17 @@ class SPIRVToOCLBase : public InstVisitor<SPIRVToOCLBase>,
203203

204204
/// Transform __spirv_OpAtomicCompareExchange and
205205
/// __spirv_OpAtomicCompareExchangeWeak
206-
virtual Instruction *visitCallSPIRVAtomicCmpExchg(CallInst *CI) = 0;
206+
virtual void visitCallSPIRVAtomicCmpExchg(CallInst *CI) = 0;
207207

208208
/// Transform __spirv_OpAtomicIIncrement/OpAtomicIDecrement to:
209209
/// - OCL2.0: atomic_fetch_add_explicit/atomic_fetch_sub_explicit
210210
/// - OCL1.2: atomic_inc/atomic_dec
211-
virtual Instruction *visitCallSPIRVAtomicIncDec(CallInst *CI, Op OC) = 0;
211+
virtual void visitCallSPIRVAtomicIncDec(CallInst *CI, Op OC) = 0;
212212

213213
/// Transform __spirv_Atomic* to atomic_*.
214214
/// __spirv_Atomic*(atomic_op, scope, sema, ops, ...) =>
215215
/// atomic_*(atomic_op, ops, ..., order(sema), map(scope))
216-
virtual Instruction *visitCallSPIRVAtomicBuiltin(CallInst *CI, Op OC) = 0;
216+
virtual void visitCallSPIRVAtomicBuiltin(CallInst *CI, Op OC) = 0;
217217

218218
/// Transform __spirv_MemoryBarrier to:
219219
/// - OCL2.0: atomic_work_item_fence.__spirv_MemoryBarrier(scope, sema) =>
@@ -246,7 +246,7 @@ class SPIRVToOCLBase : public InstVisitor<SPIRVToOCLBase>,
246246

247247
/// Transform __spirv_Opcode to ocl-version specific builtin name
248248
/// using separate maps for OpenCL 1.2 and OpenCL 2.0
249-
virtual Instruction *mutateAtomicName(CallInst *CI, Op OC) = 0;
249+
virtual void mutateAtomicName(CallInst *CI, Op OC) = 0;
250250

251251
// Transform FP atomic opcode to corresponding OpenCL function name
252252
virtual std::string mapFPAtomicName(Op OC) = 0;
@@ -324,35 +324,35 @@ class SPIRVToOCL12Base : public SPIRVToOCLBase {
324324

325325
/// Transform __spirv_OpAtomic functions. It firstly conduct generic
326326
/// mutations for all builtins and then mutate some of them seperately
327-
Instruction *visitCallSPIRVAtomicBuiltin(CallInst *CI, Op OC) override;
327+
void visitCallSPIRVAtomicBuiltin(CallInst *CI, Op OC) override;
328328

329329
/// Transform __spirv_OpAtomicIIncrement / OpAtomicIDecrement to
330330
/// atomic_inc / atomic_dec
331-
Instruction *visitCallSPIRVAtomicIncDec(CallInst *CI, Op OC) override;
331+
void visitCallSPIRVAtomicIncDec(CallInst *CI, Op OC) override;
332332

333333
/// Transform __spirv_OpAtomicUMin/SMin/UMax/SMax into
334334
/// atomic_min/atomic_max, as there is no distinction in OpenCL 1.2
335335
/// between signed and unsigned version of those functions
336-
Instruction *visitCallSPIRVAtomicUMinUMax(CallInst *CI, Op OC);
336+
void visitCallSPIRVAtomicUMinUMax(CallInst *CI, Op OC);
337337

338338
/// Transform __spirv_OpAtomicLoad to atomic_add(*ptr, 0)
339-
Instruction *visitCallSPIRVAtomicLoad(CallInst *CI);
339+
void visitCallSPIRVAtomicLoad(CallInst *CI);
340340

341341
/// Transform __spirv_OpAtomicStore to atomic_xchg(*ptr, value)
342-
Instruction *visitCallSPIRVAtomicStore(CallInst *CI);
342+
void visitCallSPIRVAtomicStore(CallInst *CI);
343343

344344
/// Transform __spirv_OpAtomicFlagClear to atomic_xchg(*ptr, 0)
345345
/// with ignoring the result
346-
Instruction *visitCallSPIRVAtomicFlagClear(CallInst *CI);
346+
void visitCallSPIRVAtomicFlagClear(CallInst *CI);
347347

348348
/// Transform __spirv_OpAtomicFlagTestAndTest to
349349
/// (bool)atomic_xchg(*ptr, 1)
350-
Instruction *visitCallSPIRVAtomicFlagTestAndSet(CallInst *CI);
350+
void visitCallSPIRVAtomicFlagTestAndSet(CallInst *CI);
351351

352352
/// Transform __spirv_OpAtomicCompareExchange/Weak into atomic_cmpxchg
353353
/// OpAtomicCompareExchangeWeak is not "weak" at all, but instead has
354354
/// the same semantics as OpAtomicCompareExchange.
355-
Instruction *visitCallSPIRVAtomicCmpExchg(CallInst *CI) override;
355+
void visitCallSPIRVAtomicCmpExchg(CallInst *CI) override;
356356

357357
/// Trigger assert, since OpenCL 1.2 doesn't support enqueue_kernel
358358
void visitCallSPIRVEnqueueKernel(CallInst *CI, Op OC) override;
@@ -361,7 +361,7 @@ class SPIRVToOCL12Base : public SPIRVToOCLBase {
361361
CallInst *mutateCommonAtomicArguments(CallInst *CI, Op OC) override;
362362

363363
/// Transform atomic builtin name into correct ocl-dependent name
364-
Instruction *mutateAtomicName(CallInst *CI, Op OC) override;
364+
void mutateAtomicName(CallInst *CI, Op OC) override;
365365

366366
// Transform FP atomic opcode to corresponding OpenCL function name
367367
std::string mapFPAtomicName(Op OC) override;
@@ -419,11 +419,11 @@ class SPIRVToOCL20Base : public SPIRVToOCLBase {
419419
/// Transform __spirv_Atomic* to atomic_*.
420420
/// __spirv_Atomic*(atomic_op, scope, sema, ops, ...) =>
421421
/// atomic_*(generic atomic_op, ops, ..., order(sema), map(scope))
422-
Instruction *visitCallSPIRVAtomicBuiltin(CallInst *CI, Op OC) override;
422+
void visitCallSPIRVAtomicBuiltin(CallInst *CI, Op OC) override;
423423

424424
/// Transform __spirv_OpAtomicIIncrement / OpAtomicIDecrement to
425425
/// atomic_fetch_add_explicit / atomic_fetch_sub_explicit
426-
Instruction *visitCallSPIRVAtomicIncDec(CallInst *CI, Op OC) override;
426+
void visitCallSPIRVAtomicIncDec(CallInst *CI, Op OC) override;
427427

428428
/// Transform __spirv_EnqueueKernel to __enqueue_kernel
429429
void visitCallSPIRVEnqueueKernel(CallInst *CI, Op OC) override;
@@ -432,7 +432,7 @@ class SPIRVToOCL20Base : public SPIRVToOCLBase {
432432
CallInst *mutateCommonAtomicArguments(CallInst *CI, Op OC) override;
433433

434434
/// Transform atomic builtin name into correct ocl-dependent name
435-
Instruction *mutateAtomicName(CallInst *CI, Op OC) override;
435+
void mutateAtomicName(CallInst *CI, Op OC) override;
436436

437437
// Transform FP atomic opcode to corresponding OpenCL function name
438438
std::string mapFPAtomicName(Op OC) override;
@@ -441,7 +441,7 @@ class SPIRVToOCL20Base : public SPIRVToOCLBase {
441441
/// atomic_compare_exchange_strong_explicit
442442
/// OpAtomicCompareExchangeWeak is not "weak" at all, but instead has
443443
/// the same semantics as OpAtomicCompareExchange.
444-
Instruction *visitCallSPIRVAtomicCmpExchg(CallInst *CI) override;
444+
void visitCallSPIRVAtomicCmpExchg(CallInst *CI) override;
445445
};
446446

447447
class SPIRVToOCL20Pass : public llvm::PassInfoMixin<SPIRVToOCL20Pass>,

0 commit comments

Comments
 (0)