Skip to content

Commit 1b156e0

Browse files
committed
[GlobalISel] convergence control tokens and intrinsics
Support for lowering convergence control to GMIR:

- Introduce new G_CONVERGENCECTRL_* opcodes.
- In the IR translator, convert the LLVM token type to LLT::token(), which is an alias for the s0 type. These show up as implicit uses on convergent operations.
- In the machine verifier, enforce the static rules of convergence control.

Note that the lowering of the new GMIR opcodes is entirely target-specific. It is generally expected that the backend will use convergence control to change the CFG of each function, and then discard the tokens. This is currently a work in progress for AMDGPU.

Differential Revision: https://reviews.llvm.org/D158147
1 parent 24161bc commit 1b156e0

26 files changed

+602
-51
lines changed

llvm/include/llvm/ADT/GenericConvergenceVerifier.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ template <typename ContextT> class GenericConvergenceVerifier {
6060
NoConvergence
6161
} ConvergenceKind = NoConvergence;
6262

63+
/// The control token operation performed by a convergence control intrinsic in LLVM IR,
64+
/// or by a G_CONVERGENCECTRL* instruction in GMIR.
65+
// CONV_NONE is the value for an instruction that is not a convergence control
// operation; the other three name the anchor, entry and loop operations.
enum ConvOpKind { CONV_ANCHOR, CONV_ENTRY, CONV_LOOP, CONV_NONE };
66+
6367
// Cache token uses found so far. Note that we track the unique definitions
6468
// and not the token values.
6569
DenseMap<const InstructionT *, const InstructionT *> Tokens;
@@ -68,6 +72,7 @@ template <typename ContextT> class GenericConvergenceVerifier {
6872

6973
static bool isInsideConvergentFunction(const InstructionT &I);
7074
static bool isConvergent(const InstructionT &I);
75+
static ConvOpKind getConvOp(const InstructionT &I);
7176
const InstructionT *findAndCheckConvergenceTokenUsed(const InstructionT &I);
7277

7378
void reportFailure(const Twine &Message, ArrayRef<Printable> Values);

llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ class CallLowering {
117117
/// vreg that the swifterror should be copied into after the call.
118118
Register SwiftErrorVReg;
119119

120+
/// Valid if the call is a controlled convergent operation.
121+
Register ConvergenceCtrlToken;
122+
120123
/// Original IR callsite corresponding to this call, if available.
121124
const CallBase *CB = nullptr;
122125

@@ -583,6 +586,7 @@ class CallLowering {
583586
bool lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &Call,
584587
ArrayRef<Register> ResRegs,
585588
ArrayRef<ArrayRef<Register>> ArgRegs, Register SwiftErrorVReg,
589+
Register ConvergenceCtrlToken,
586590
std::function<unsigned()> GetCalleeReg) const;
587591

588592
/// For targets which want to use big-endian can enable it with

llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,10 @@ class IRTranslator : public MachineFunctionPass {
554554
return false;
555555
}
556556

557+
bool translateConvergenceControlIntrinsic(const CallInst &CI,
558+
Intrinsic::ID ID,
559+
MachineIRBuilder &MIRBuilder);
560+
557561
/// @}
558562

559563
// Builder for machine instruction a la IRBuilder.
@@ -671,6 +675,23 @@ class IRTranslator : public MachineFunctionPass {
671675
return Regs[0];
672676
}
673677

678+
/// Return the virtual register recorded in VMap for the convergence control
/// token \p Token, creating one of type LLT::token() if none is recorded yet.
/// \p Token must have token type.
Register getOrCreateConvergenceTokenVReg(const Value &Token) {
679+
assert(Token.getType()->isTokenTy());
680+
auto &Regs = *VMap.getVRegs(Token);
681+
// Reuse the register already recorded for this token, if there is one.
if (!Regs.empty()) {
682+
assert(Regs.size() == 1 &&
683+
"Expected a single register for convergence tokens.");
684+
return Regs[0];
685+
}
686+
687+
// First request for this token: create a token-typed vreg and record it.
auto Reg = MRI->createGenericVirtualRegister(LLT::token());
688+
Regs.push_back(Reg);
689+
// Give the token a single offset entry of 0 if it has no offsets yet.
auto &Offsets = *VMap.getOffsets(Token);
690+
if (Offsets.empty())
691+
Offsets.push_back(0);
692+
return Reg;
693+
}
694+
674695
/// Allocate some vregs and offsets in the VMap. Then populate just the
675696
/// offsets while leaving the vregs empty.
676697
ValueToVRegInfo::VRegListT &allocateVRegs(const Value &Val);

llvm/include/llvm/CodeGen/LowLevelType.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ class LLT {
4545
/*AddressSpace=*/0};
4646
}
4747

48+
/// Get a low-level token; just a scalar with zero bits (or no size).
49+
// Constructed with only the isScalar flag set, a fixed element count of 0,
// a size of 0 bits and address space 0.
static constexpr LLT token() {
50+
return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true,
51+
ElementCount::getFixed(0), /*SizeInBits=*/0,
52+
/*AddressSpace=*/0};
53+
}
54+
4855
/// Get a low-level pointer in the given address space.
4956
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits) {
5057
assert(SizeInBits > 0 && "invalid pointer size");
@@ -304,6 +311,28 @@ class LLT {
304311
/// described in static const *Field variables. Each of these variables
305312
/// is a 2-element array, with the first element describing the bitfield size
306313
/// and the second element describing the bitfield offset.
314+
///
315+
/// +--------+---------+--------+----------+----------------------+
316+
/// |isScalar|isPointer|isVector| RawData |Notes |
317+
/// +--------+---------+--------+----------+----------------------+
318+
/// | 0 | 0 | 0 | 0 |Invalid |
319+
/// +--------+---------+--------+----------+----------------------+
320+
/// | 0 | 0 | 1 | 0 |Tombstone Key |
321+
/// +--------+---------+--------+----------+----------------------+
322+
/// | 0 | 1 | 0 | 0 |Empty Key |
323+
/// +--------+---------+--------+----------+----------------------+
324+
/// | 1 | 0 | 0 | 0 |Token |
325+
/// +--------+---------+--------+----------+----------------------+
326+
/// | 1 | 0 | 0 | non-zero |Scalar |
327+
/// +--------+---------+--------+----------+----------------------+
328+
/// | 0 | 1 | 0 | non-zero |Pointer |
329+
/// +--------+---------+--------+----------+----------------------+
330+
/// | 0 | 0 | 1 | non-zero |Vector of non-pointer |
331+
/// +--------+---------+--------+----------+----------------------+
332+
/// | 0 | 1 | 1 | non-zero |Vector of pointer |
333+
/// +--------+---------+--------+----------+----------------------+
334+
///
335+
/// Everything else is reserved.
307336
typedef int BitFieldInfo[2];
308337
///
309338
/// This is how the bitfields are packed per Kind:

llvm/include/llvm/CodeGen/MachineConvergenceVerifier.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- MachineConvergenceVerifier.h - Verify convergencectrl ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
/// \file
9+
///
10+
/// This file declares the GMIR specialization of the
11+
/// GenericConvergenceVerifier template.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
16+
#define LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
17+
18+
#include "llvm/ADT/GenericConvergenceVerifier.h"
19+
#include "llvm/CodeGen/MachineSSAContext.h"
20+
21+
namespace llvm {
22+
23+
/// The generic convergence verifier instantiated with MachineSSAContext as
/// its context type.
using MachineConvergenceVerifier =
24+
GenericConvergenceVerifier<MachineSSAContext>;
25+
26+
} // namespace llvm
27+
28+
#endif // LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H

llvm/include/llvm/IR/GenericConvergenceVerifierImpl.h

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,6 @@ using namespace llvm;
4949
} \
5050
} while (false)
5151

52-
// Returns true iff IntrinsicID is one of the three convergence control
// intrinsics (anchor, entry or loop). These are deleted lines in this diff.
static bool isConvergenceControlIntrinsic(unsigned IntrinsicID) {
53-
switch (IntrinsicID) {
54-
default:
55-
return false;
56-
case Intrinsic::experimental_convergence_anchor:
57-
case Intrinsic::experimental_convergence_entry:
58-
case Intrinsic::experimental_convergence_loop:
59-
return true;
60-
}
61-
}
62-
6352
namespace llvm {
6453
template <class ContextT> void GenericConvergenceVerifier<ContextT>::clear() {
6554
Tokens.clear();
@@ -74,12 +63,11 @@ void GenericConvergenceVerifier<ContextT>::visit(const BlockT &BB) {
7463

7564
template <class ContextT>
7665
void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
77-
auto ID = ContextT::getIntrinsicID(I);
66+
auto ConvOp = getConvOp(I);
7867
auto *TokenDef = findAndCheckConvergenceTokenUsed(I);
79-
bool IsCtrlIntrinsic = true;
8068

81-
switch (ID) {
82-
case Intrinsic::experimental_convergence_entry:
69+
switch (ConvOp) {
70+
case CONV_ENTRY:
8371
Check(isInsideConvergentFunction(I),
8472
"Entry intrinsic can occur only in a convergent function.",
8573
{Context.print(&I)});
@@ -91,13 +79,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
9179
"same basic block.",
9280
{Context.print(&I)});
9381
LLVM_FALLTHROUGH;
94-
case Intrinsic::experimental_convergence_anchor:
82+
case CONV_ANCHOR:
9583
Check(!TokenDef,
9684
"Entry or anchor intrinsic cannot have a convergencectrl token "
9785
"operand.",
9886
{Context.print(&I)});
9987
break;
100-
case Intrinsic::experimental_convergence_loop:
88+
case CONV_LOOP:
10189
Check(TokenDef, "Loop intrinsic must have a convergencectrl token operand.",
10290
{Context.print(&I)});
10391
Check(!SeenFirstConvOp,
@@ -106,14 +94,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
10694
{Context.print(&I)});
10795
break;
10896
default:
109-
IsCtrlIntrinsic = false;
11097
break;
11198
}
11299

113100
if (isConvergent(I))
114101
SeenFirstConvOp = true;
115102

116-
if (TokenDef || IsCtrlIntrinsic) {
103+
if (TokenDef || ConvOp != CONV_NONE) {
117104
Check(isConvergent(I),
118105
"Convergence control token can only be used in a convergent call.",
119106
{Context.print(&I)});
@@ -174,8 +161,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
174161
return;
175162
}
176163

177-
Check(ContextT::getIntrinsicID(*User) ==
178-
Intrinsic::experimental_convergence_loop,
164+
Check(getConvOp(*User) == CONV_LOOP,
179165
"Convergence token used by an instruction other than "
180166
"llvm.experimental.convergence.loop in a cycle that does "
181167
"not contain the token's definition.",
@@ -212,7 +198,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
212198
for (auto &I : *BB) {
213199
if (auto *Token = Tokens.lookup(&I))
214200
checkToken(Token, &I, LiveTokens);
215-
if (isConvergenceControlIntrinsic(ContextT::getIntrinsicID(I)))
201+
if (getConvOp(I) != CONV_NONE)
216202
LiveTokens.push_back(&I);
217203
}
218204

llvm/include/llvm/Support/TargetOpcodes.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,10 @@ HANDLE_TARGET_OPCODE(G_INTRINSIC_CONVERGENT)
436436
/// Generic intrinsic use (with side effects).
437437
HANDLE_TARGET_OPCODE(G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS)
438438

439+
HANDLE_TARGET_OPCODE(G_CONVERGENCECTRL_ENTRY)
440+
HANDLE_TARGET_OPCODE(G_CONVERGENCECTRL_ANCHOR)
441+
HANDLE_TARGET_OPCODE(G_CONVERGENCECTRL_LOOP)
442+
439443
/// Generic extension allowing rubbish in high bits.
440444
HANDLE_TARGET_OPCODE(G_ANYEXT)
441445

llvm/include/llvm/Target/GenericOpcodes.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,34 @@ def G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS : GenericInstruction {
12761276
let isConvergent = true;
12771277
}
12781278

1279+
//------------------------------------------------------------------------------
1280+
// Convergence control operations.
1281+
//------------------------------------------------------------------------------
1282+
1283+
// Capture the set of threads that are converged on entry to a function.
1284+
// Takes no input operands and defines a single result, $dst. Marked
// convergent and free of side effects.
def G_CONVERGENCECTRL_ENTRY : GenericInstruction {
1285+
let InOperandList = (ins);
1286+
let OutOperandList = (outs type0:$dst);
1287+
let isConvergent = true;
1288+
let hasSideEffects = false;
1289+
}
1290+
1291+
// Capture an implementation-defined subset of converged threads.
1292+
// Takes no input operands and defines a single result, $dst. Marked
// convergent and free of side effects.
def G_CONVERGENCECTRL_ANCHOR : GenericInstruction {
1293+
let InOperandList = (ins);
1294+
let OutOperandList = (outs type0:$dst);
1295+
let isConvergent = true;
1296+
let hasSideEffects = false;
1297+
}
1298+
1299+
// Capture the convergence of threads in a cycle.
1300+
// Takes one input operand, $src, and defines a result, $dst; both are
// declared as type0. Marked convergent and free of side effects.
def G_CONVERGENCECTRL_LOOP : GenericInstruction {
1301+
let InOperandList = (ins type0:$src);
1302+
let OutOperandList = (outs type0:$dst);
1303+
let isConvergent = true;
1304+
let hasSideEffects = false;
1305+
}
1306+
12791307
//------------------------------------------------------------------------------
12801308
// Branches.
12811309
//------------------------------------------------------------------------------

llvm/lib/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ add_llvm_component_library(LLVMCodeGen
121121
MachineBranchProbabilityInfo.cpp
122122
MachineCFGPrinter.cpp
123123
MachineCombiner.cpp
124+
MachineConvergenceVerifier.cpp
124125
MachineCopyPropagation.cpp
125126
MachineCSE.cpp
126127
MachineCheckDebugify.cpp

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/CodeGen/MachineRegisterInfo.h"
2222
#include "llvm/CodeGen/TargetLowering.h"
2323
#include "llvm/IR/DataLayout.h"
24+
#include "llvm/IR/IntrinsicInst.h"
2425
#include "llvm/IR/LLVMContext.h"
2526
#include "llvm/IR/Module.h"
2627
#include "llvm/Target/TargetMachine.h"
@@ -87,10 +88,20 @@ void CallLowering::addArgFlagsFromAttributes(ISD::ArgFlagsTy &Flags,
8788
});
8889
}
8990

91+
/// Returns true if \p CB carries no convergencectrl operand bundle, or if the
/// token in that bundle is produced by the experimental_convergence_entry
/// intrinsic. Returns false for a token produced by any other intrinsic.
static bool hasConvergenceEntryToken(const CallBase &CB) {
92+
auto Bundle = CB.getOperandBundle(LLVMContext::OB_convergencectrl);
93+
// A call without a convergencectrl bundle also yields true.
if (!Bundle)
94+
return true;
95+
// Only the first input of the bundle is inspected.
auto *Token = Bundle->Inputs[0].get();
96+
// NOTE(review): cast<> presumes the token is always defined by an intrinsic
// call -- confirm no other kind of token definition can reach this point.
auto *Def = cast<IntrinsicInst>(Token);
97+
return Def->getIntrinsicID() == Intrinsic::experimental_convergence_entry;
98+
}
99+
90100
bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
91101
ArrayRef<Register> ResRegs,
92102
ArrayRef<ArrayRef<Register>> ArgRegs,
93103
Register SwiftErrorVReg,
104+
Register ConvergenceCtrlToken,
94105
std::function<unsigned()> GetCalleeReg) const {
95106
CallLoweringInfo Info;
96107
const DataLayout &DL = MIRBuilder.getDataLayout();
@@ -121,6 +132,8 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
121132
CanBeTailCalled = false;
122133
}
123134

135+
if (!hasConvergenceEntryToken(CB))
136+
CanBeTailCalled = false;
124137

125138
// First step is to marshall all the function's parameters into the correct
126139
// physregs and memory locations. Gather the sequence of argument types that
@@ -176,6 +189,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
176189
Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees);
177190
Info.CallConv = CallConv;
178191
Info.SwiftErrorVReg = SwiftErrorVReg;
192+
Info.ConvergenceCtrlToken = ConvergenceCtrlToken;
179193
Info.IsMustTailCall = CB.isMustTailCall();
180194
Info.IsTailCall = CanBeTailCalled;
181195
Info.IsVarArg = IsVarArg;

0 commit comments

Comments
 (0)