Skip to content

Commit 6ac42c8

Browse files
committed
[GlobalISel] convergence control tokens and intrinsics
In the IR translator, convert the LLVM token type to LLT::token(), which is an alias for the s0 type. These show up as implicit uses on convergent operations. Differential Revision: https://reviews.llvm.org/D158147
1 parent 2e25926 commit 6ac42c8

File tree

14 files changed

+270
-45
lines changed

14 files changed

+270
-45
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ class CallLowering {
117117
/// vreg that the swifterror should be copied into after the call.
118118
Register SwiftErrorVReg;
119119

120+
/// Valid if the call is a controlled convergent operation.
121+
Register ConvergenceCtrlToken;
122+
120123
/// Original IR callsite corresponding to this call, if available.
121124
const CallBase *CB = nullptr;
122125

@@ -584,6 +587,7 @@ class CallLowering {
584587
bool lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &Call,
585588
ArrayRef<Register> ResRegs,
586589
ArrayRef<ArrayRef<Register>> ArgRegs, Register SwiftErrorVReg,
590+
Register ConvergenceCtrlToken,
587591
std::function<unsigned()> GetCalleeReg) const;
588592

589593
/// For targets which want to use big-endian can enable it with

llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,10 @@ class IRTranslator : public MachineFunctionPass {
579579
return false;
580580
}
581581

582+
bool translateConvergenceControlIntrinsic(const CallInst &CI,
583+
Intrinsic::ID ID,
584+
MachineIRBuilder &MIRBuilder);
585+
582586
/// @}
583587

584588
// Builder for machine instruction a la IRBuilder.
@@ -697,6 +701,23 @@ class IRTranslator : public MachineFunctionPass {
697701
return Regs[0];
698702
}
699703

704+
Register getOrCreateConvergenceTokenVReg(const Value &Token) {
705+
assert(Token.getType()->isTokenTy());
706+
auto &Regs = *VMap.getVRegs(Token);
707+
if (!Regs.empty()) {
708+
assert(Regs.size() == 1 &&
709+
"Expected a single register for convergence tokens.");
710+
return Regs[0];
711+
}
712+
713+
auto Reg = MRI->createGenericVirtualRegister(LLT::token());
714+
Regs.push_back(Reg);
715+
auto &Offsets = *VMap.getOffsets(Token);
716+
if (Offsets.empty())
717+
Offsets.push_back(0);
718+
return Reg;
719+
}
720+
700721
/// Allocate some vregs and offsets in the VMap. Then populate just the
701722
/// offsets while leaving the vregs empty.
702723
ValueToVRegInfo::VRegListT &allocateVRegs(const Value &Val);

llvm/include/llvm/CodeGenTypes/LowLevelType.h

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ class LLT {
4545
/*AddressSpace=*/0};
4646
}
4747

48+
/// Get a low-level token; just a scalar with zero bits (or no size).
49+
static constexpr LLT token() {
50+
return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true,
51+
ElementCount::getFixed(0), /*SizeInBits=*/0,
52+
/*AddressSpace=*/0};
53+
}
54+
4855
/// Get a low-level pointer in the given address space.
4956
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits) {
5057
assert(SizeInBits > 0 && "invalid pointer size");
@@ -134,17 +141,14 @@ class LLT {
134141

135142
explicit LLT(MVT VT);
136143

137-
constexpr bool isValid() const { return IsScalar || IsPointer || IsVector; }
138-
144+
constexpr bool isValid() const { return IsScalar || RawData != 0; }
139145
constexpr bool isScalar() const { return IsScalar; }
140-
141-
constexpr bool isPointer() const { return IsPointer && !IsVector; }
142-
143-
constexpr bool isPointerVector() const { return IsPointer && IsVector; }
144-
145-
constexpr bool isPointerOrPointerVector() const { return IsPointer; }
146-
147-
constexpr bool isVector() const { return IsVector; }
146+
constexpr bool isVector() const { return isValid() && IsVector; }
147+
constexpr bool isPointer() const {
148+
return isValid() && IsPointer && !IsVector;
149+
}
150+
constexpr bool isPointerVector() const { return IsPointer && isVector(); }
151+
constexpr bool isPointerOrPointerVector() const { return IsPointer && isValid(); }
148152

149153
/// Returns the number of elements in a vector LLT. Must only be called on
150154
/// vector types.
@@ -314,6 +318,28 @@ class LLT {
314318
/// described in static const *Field variables. Each of these variables
315319
/// is a 2-element array, with the first element describing the bitfield size
316320
/// and the second element describing the bitfield offset.
321+
///
322+
/// +--------+---------+--------+----------+----------------------+
323+
/// |isScalar|isPointer|isVector| RawData |Notes |
324+
/// +--------+---------+--------+----------+----------------------+
325+
/// | 0 | 0 | 0 | 0 |Invalid |
326+
/// +--------+---------+--------+----------+----------------------+
327+
/// | 0 | 0 | 1 | 0 |Tombstone Key |
328+
/// +--------+---------+--------+----------+----------------------+
329+
/// | 0 | 1 | 0 | 0 |Empty Key |
330+
/// +--------+---------+--------+----------+----------------------+
331+
/// | 1 | 0 | 0 | 0 |Token |
332+
/// +--------+---------+--------+----------+----------------------+
333+
/// | 1 | 0 | 0 | non-zero |Scalar |
334+
/// +--------+---------+--------+----------+----------------------+
335+
/// | 0 | 1 | 0 | non-zero |Pointer |
336+
/// +--------+---------+--------+----------+----------------------+
337+
/// | 0 | 0 | 1 | non-zero |Vector of non-pointer |
338+
/// +--------+---------+--------+----------+----------------------+
339+
/// | 0 | 1 | 1 | non-zero |Vector of pointer |
340+
/// +--------+---------+--------+----------+----------------------+
341+
///
342+
/// Everything else is reserved.
317343
typedef int BitFieldInfo[2];
318344
///
319345
/// This is how the bitfields are packed per Kind:

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/CodeGen/MachineRegisterInfo.h"
2222
#include "llvm/CodeGen/TargetLowering.h"
2323
#include "llvm/IR/DataLayout.h"
24+
#include "llvm/IR/IntrinsicInst.h"
2425
#include "llvm/IR/LLVMContext.h"
2526
#include "llvm/IR/Module.h"
2627
#include "llvm/Target/TargetMachine.h"
@@ -87,10 +88,20 @@ void CallLowering::addArgFlagsFromAttributes(ISD::ArgFlagsTy &Flags,
8788
});
8889
}
8990

91+
static bool hasConvergenceEntryToken(const CallBase &CB) {
92+
auto Bundle = CB.getOperandBundle(LLVMContext::OB_convergencectrl);
93+
if (!Bundle)
94+
return true;
95+
auto *Token = Bundle->Inputs[0].get();
96+
auto *Def = cast<IntrinsicInst>(Token);
97+
return Def->getIntrinsicID() == Intrinsic::experimental_convergence_entry;
98+
}
99+
90100
bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
91101
ArrayRef<Register> ResRegs,
92102
ArrayRef<ArrayRef<Register>> ArgRegs,
93103
Register SwiftErrorVReg,
104+
Register ConvergenceCtrlToken,
94105
std::function<unsigned()> GetCalleeReg) const {
95106
CallLoweringInfo Info;
96107
const DataLayout &DL = MIRBuilder.getDataLayout();
@@ -121,6 +132,8 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
121132
CanBeTailCalled = false;
122133
}
123134

135+
if (!hasConvergenceEntryToken(CB))
136+
CanBeTailCalled = false;
124137

125138
// First step is to marshall all the function's parameters into the correct
126139
// physregs and memory locations. Gather the sequence of argument types that
@@ -187,6 +200,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
187200
Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees);
188201
Info.CallConv = CallConv;
189202
Info.SwiftErrorVReg = SwiftErrorVReg;
203+
Info.ConvergenceCtrlToken = ConvergenceCtrlToken;
190204
Info.IsMustTailCall = CB.isMustTailCall();
191205
Info.IsTailCall = CanBeTailCalled;
192206
Info.IsVarArg = IsVarArg;

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ ArrayRef<Register> IRTranslator::getOrCreateVRegs(const Value &Val) {
213213
auto *VRegs = VMap.getVRegs(Val);
214214
auto *Offsets = VMap.getOffsets(Val);
215215

216-
assert(Val.getType()->isSized() &&
217-
"Don't know how to create an empty vreg");
216+
if (!Val.getType()->isTokenTy())
217+
assert(Val.getType()->isSized() &&
218+
"Don't know how to create an empty vreg");
218219

219220
SmallVector<LLT, 4> SplitTys;
220221
computeValueLLTs(*DL, *Val.getType(), SplitTys,
@@ -2036,6 +2037,37 @@ bool IRTranslator::translateIfEntryValueArgument(bool isDeclare, Value *Val,
20362037
return true;
20372038
}
20382039

2040+
static unsigned getConvOpcode(Intrinsic::ID ID) {
2041+
switch (ID) {
2042+
default:
2043+
llvm_unreachable("Unexpected intrinsic");
2044+
return 0;
2045+
case Intrinsic::experimental_convergence_anchor:
2046+
return TargetOpcode::CONVERGENCECTRL_ANCHOR;
2047+
case Intrinsic::experimental_convergence_entry:
2048+
return TargetOpcode::CONVERGENCECTRL_ENTRY;
2049+
case Intrinsic::experimental_convergence_loop:
2050+
return TargetOpcode::CONVERGENCECTRL_LOOP;
2051+
}
2052+
}
2053+
2054+
bool IRTranslator::translateConvergenceControlIntrinsic(
2055+
const CallInst &CI, Intrinsic::ID ID, MachineIRBuilder &MIRBuilder) {
2056+
MachineInstrBuilder MIB = MIRBuilder.buildInstr(getConvOpcode(ID));
2057+
Register OutputReg = getOrCreateConvergenceTokenVReg(CI);
2058+
MIB.addDef(OutputReg);
2059+
2060+
if (ID == Intrinsic::experimental_convergence_loop) {
2061+
auto Bundle = CI.getOperandBundle(LLVMContext::OB_convergencectrl);
2062+
assert(Bundle && "Expected a convergence control token.");
2063+
Register InputReg =
2064+
getOrCreateConvergenceTokenVReg(*Bundle->Inputs[0].get());
2065+
MIB.addUse(InputReg);
2066+
}
2067+
2068+
return true;
2069+
}
2070+
20392071
bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
20402072
MachineIRBuilder &MIRBuilder) {
20412073
if (auto *MI = dyn_cast<AnyMemIntrinsic>(&CI)) {
@@ -2477,7 +2509,10 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
24772509
#include "llvm/IR/ConstrainedOps.def"
24782510
return translateConstrainedFPIntrinsic(cast<ConstrainedFPIntrinsic>(CI),
24792511
MIRBuilder);
2480-
2512+
case Intrinsic::experimental_convergence_anchor:
2513+
case Intrinsic::experimental_convergence_entry:
2514+
case Intrinsic::experimental_convergence_loop:
2515+
return translateConvergenceControlIntrinsic(CI, ID, MIRBuilder);
24812516
}
24822517
return false;
24832518
}
@@ -2528,12 +2563,18 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
25282563
}
25292564
}
25302565

2566+
Register ConvergenceCtrlToken = 0;
2567+
if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_convergencectrl)) {
2568+
const auto &Token = *Bundle->Inputs[0].get();
2569+
ConvergenceCtrlToken = getOrCreateConvergenceTokenVReg(Token);
2570+
}
2571+
25312572
// We don't set HasCalls on MFI here yet because call lowering may decide to
25322573
// optimize into tail calls. Instead, we defer that to selection where a final
25332574
// scan is done to check if any instructions are calls.
2534-
bool Success =
2535-
CLI->lowerCall(MIRBuilder, CB, Res, Args, SwiftErrorVReg,
2536-
[&]() { return getOrCreateVReg(*CB.getCalledOperand()); });
2575+
bool Success = CLI->lowerCall(
2576+
MIRBuilder, CB, Res, Args, SwiftErrorVReg, ConvergenceCtrlToken,
2577+
[&]() { return getOrCreateVReg(*CB.getCalledOperand()); });
25372578

25382579
// Check if we just inserted a tail call.
25392580
if (Success) {
@@ -2647,6 +2688,14 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
26472688
MF->getMachineMemOperand(MPI, Info.flags, MemTy, Alignment, CI.getAAMetadata()));
26482689
}
26492690

2691+
if (CI.isConvergent()) {
2692+
if (auto Bundle = CI.getOperandBundle(LLVMContext::OB_convergencectrl)) {
2693+
auto *Token = Bundle->Inputs[0].get();
2694+
Register TokenReg = getOrCreateVReg(*Token);
2695+
MIB.addUse(TokenReg, RegState::Implicit);
2696+
}
2697+
}
2698+
26502699
return true;
26512700
}
26522701

llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,14 @@ bool InlineAsmLowering::lowerInlineAsm(
538538
}
539539
}
540540

541+
if (auto Bundle = Call.getOperandBundle(LLVMContext::OB_convergencectrl)) {
542+
auto *Token = Bundle->Inputs[0].get();
543+
ArrayRef<Register> SourceRegs = GetOrCreateVRegs(*Token);
544+
assert(SourceRegs.size() == 1 &&
545+
"Expected the control token to fit into a single virtual register");
546+
Inst.addUse(SourceRegs[0], RegState::Implicit);
547+
}
548+
541549
if (const MDNode *SrcLoc = Call.getMetadata("srcloc"))
542550
Inst.addMetadata(SrcLoc);
543551

llvm/lib/CodeGen/LowLevelTypeUtils.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) {
3939
return LLT::scalar(SizeInBits);
4040
}
4141

42+
if (Ty.isTokenTy()) {
43+
return LLT::token();
44+
}
45+
4246
return LLT();
4347
}
4448

llvm/lib/CodeGen/MIRParser/MIParser.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,10 +1919,13 @@ bool MIParser::parseLowLevelType(StringRef::iterator Loc, LLT &Ty) {
19191919

19201920
if (Token.range().front() == 's') {
19211921
auto ScalarSize = APSInt(Token.range().drop_front()).getZExtValue();
1922-
if (!verifyScalarSize(ScalarSize))
1923-
return error("invalid size for scalar type");
1924-
1925-
Ty = LLT::scalar(ScalarSize);
1922+
if (ScalarSize) {
1923+
if (!verifyScalarSize(ScalarSize))
1924+
return error("invalid size for scalar type");
1925+
Ty = LLT::scalar(ScalarSize);
1926+
} else {
1927+
Ty = LLT::token();
1928+
}
19261929
lex();
19271930
return false;
19281931
} else if (Token.range().front() == 'p') {
@@ -1980,7 +1983,7 @@ bool MIParser::parseLowLevelType(StringRef::iterator Loc, LLT &Ty) {
19801983
if (Token.range().front() == 's') {
19811984
auto ScalarSize = APSInt(Token.range().drop_front()).getZExtValue();
19821985
if (!verifyScalarSize(ScalarSize))
1983-
return error("invalid size for scalar type");
1986+
return error("invalid size for scalar element in vector");
19841987
Ty = LLT::scalar(ScalarSize);
19851988
} else if (Token.range().front() == 'p') {
19861989
const DataLayout &DL = MF.getDataLayout();

llvm/lib/IR/ConvergenceVerifier.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,14 @@ GenericConvergenceVerifier<SSAContext>::findAndCheckConvergenceTokenUsed(
7575

7676
template <>
7777
bool GenericConvergenceVerifier<SSAContext>::isInsideConvergentFunction(
78-
const InstructionT &I) {
78+
const Instruction &I) {
7979
auto *F = I.getFunction();
8080
return F->isConvergent();
8181
}
8282

8383
template <>
8484
bool GenericConvergenceVerifier<SSAContext>::isConvergent(
85-
const InstructionT &I) {
85+
const Instruction &I) {
8686
if (auto *CB = dyn_cast<CallBase>(&I)) {
8787
return CB->isConvergent();
8888
}

llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,9 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
14831483

14841484
const SIMachineFunctionInfo *MFI = MF.getInfo<SIMachineFunctionInfo>();
14851485

1486+
if (Info.ConvergenceCtrlToken) {
1487+
MIB.addUse(Info.ConvergenceCtrlToken, RegState::Implicit);
1488+
}
14861489
handleImplicitCallArguments(MIRBuilder, MIB, ST, *MFI, Info.CallConv,
14871490
ImplicitArgRegs);
14881491

0 commit comments

Comments
 (0)