Skip to content

Commit abcdfb4

Browse files
committed
Implement convergence control in MIR using SelectionDAG
LLVM function calls carry convergence control tokens as operand bundles, where the tokens themselves are produced by convergence control intrinsics. This patch implements convergence control tokens in MIR as follows: 1. Introduce target-independent ISD opcodes and MIR opcodes for convergence control intrinsics. 2. Model token values as untyped virtual registers in MIR. The change also introduces an additional ISD opcode CONVERGENCECTRL_GLUE and a corresponding machine opcode with the same spelling. This glues the convergence control token to SDNodes that represent calls to intrinsics. The glued token is later translated to an implicit argument in the MIR. The lowering of calls to user-defined functions is target-specific. On AMDGPU, the convergence control operand bundle at a non-intrinsic call is translated to an explicit argument to the SI_CALL_ISEL instruction. Post-selection adjustment converts this explicit argument to an implicit argument on the SI_CALL instruction.
1 parent 96c5b8c commit abcdfb4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+831
-162
lines changed

llvm/include/llvm/ADT/GenericConvergenceVerifier.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ template <typename ContextT> class GenericConvergenceVerifier {
3232

3333
void initialize(raw_ostream *OS,
3434
function_ref<void(const Twine &Message)> FailureCB,
35-
const FunctionT &F) {
35+
const FunctionT &F, bool _IsSSA) {
3636
clear();
3737
this->OS = OS;
3838
this->FailureCB = FailureCB;
3939
Context = ContextT(&F);
40+
IsSSA = _IsSSA;
4041
}
4142

4243
void clear();
@@ -52,6 +53,7 @@ template <typename ContextT> class GenericConvergenceVerifier {
5253
DominatorTreeT *DT;
5354
CycleInfoT CI;
5455
ContextT Context;
56+
bool IsSSA;
5557

5658
/// Whether the current function has convergencectrl operand bundles.
5759
enum {
@@ -60,6 +62,10 @@ template <typename ContextT> class GenericConvergenceVerifier {
6062
NoConvergence
6163
} ConvergenceKind = NoConvergence;
6264

65+
/// The control token operation performed by a convergence control Intrinsic
66+
/// in LLVM IR, or by a CONVERGENCECTRL* instruction in MIR
67+
enum ConvOpKind { CONV_ANCHOR, CONV_ENTRY, CONV_LOOP, CONV_NONE };
68+
6369
// Cache token uses found so far. Note that we track the unique definitions
6470
// and not the token values.
6571
DenseMap<const InstructionT *, const InstructionT *> Tokens;
@@ -68,6 +74,7 @@ template <typename ContextT> class GenericConvergenceVerifier {
6874

6975
static bool isInsideConvergentFunction(const InstructionT &I);
7076
static bool isConvergent(const InstructionT &I);
77+
static ConvOpKind getConvOp(const InstructionT &I);
7178
const InstructionT *findAndCheckConvergenceTokenUsed(const InstructionT &I);
7279

7380
void reportFailure(const Twine &Message, ArrayRef<Printable> Values);

llvm/include/llvm/CodeGen/FunctionLoweringInfo.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,15 +215,7 @@ class FunctionLoweringInfo {
215215

216216
Register CreateRegs(Type *Ty, bool isDivergent = false);
217217

218-
Register InitializeRegForValue(const Value *V) {
219-
// Tokens never live in vregs.
220-
if (V->getType()->isTokenTy())
221-
return 0;
222-
Register &R = ValueMap[V];
223-
assert(R == 0 && "Already initialized this value register!");
224-
assert(VirtReg2Value.empty());
225-
return R = CreateRegs(V);
226-
}
218+
Register InitializeRegForValue(const Value *V);
227219

228220
/// GetLiveOutRegInfo - Gets LiveOutInfo for a register, returning NULL if the
229221
/// register is a PHI destination and the PHI's LiveOutInfo is not valid.

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,15 @@ enum NodeType {
13841384
#define BEGIN_REGISTER_VP_SDNODE(VPSDID, ...) VPSDID,
13851385
#include "llvm/IR/VPIntrinsics.def"
13861386

1387+
// The `llvm.experimental.convergence.*` intrinsics.
1388+
CONVERGENCECTRL_ANCHOR,
1389+
CONVERGENCECTRL_ENTRY,
1390+
CONVERGENCECTRL_LOOP,
1391+
// This does not correspond to any convergence control intrinsic. It used to
1392+
// glue a convergence control token to a convergent operation in the DAG,
1393+
// which is later translated to an implicit use in the MIR.
1394+
CONVERGENCECTRL_GLUE,
1395+
13871396
/// BUILTIN_OP_END - This must be the last enum value in this list.
13881397
/// The target-specific pre-isel opcode values start here.
13891398
BUILTIN_OP_END
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- MachineConvergenceVerifier.h - Verify convergenctrl ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
/// \file
9+
///
10+
/// This file declares the MIR specialization of the GenericConvergenceVerifier
11+
/// template.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
16+
#define LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
17+
18+
#include "llvm/ADT/GenericConvergenceVerifier.h"
19+
#include "llvm/CodeGen/MachineSSAContext.h"
20+
21+
namespace llvm {
22+
23+
using MachineConvergenceVerifier =
24+
GenericConvergenceVerifier<MachineSSAContext>;
25+
26+
} // namespace llvm
27+
28+
#endif // LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H

llvm/include/llvm/CodeGen/SelectionDAGISel.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,10 @@ class SelectionDAGISel : public MachineFunctionPass {
459459
void Select_ARITH_FENCE(SDNode *N);
460460
void Select_MEMBARRIER(SDNode *N);
461461

462+
void Select_CONVERGENCECTRL_ANCHOR(SDNode *N);
463+
void Select_CONVERGENCECTRL_ENTRY(SDNode *N);
464+
void Select_CONVERGENCECTRL_LOOP(SDNode *N);
465+
462466
void pushStackMapLiveVariable(SmallVectorImpl<SDValue> &Ops, SDValue Operand,
463467
SDLoc DL);
464468
void Select_STACKMAP(SDNode *N);

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4401,6 +4401,7 @@ class TargetLowering : public TargetLoweringBase {
44014401
SmallVector<ISD::InputArg, 32> Ins;
44024402
SmallVector<SDValue, 4> InVals;
44034403
const ConstantInt *CFIType = nullptr;
4404+
SDValue ConvergenceControlToken;
44044405

44054406
CallLoweringInfo(SelectionDAG &DAG)
44064407
: RetSExt(false), RetZExt(false), IsVarArg(false), IsInReg(false),
@@ -4534,6 +4535,11 @@ class TargetLowering : public TargetLoweringBase {
45344535
return *this;
45354536
}
45364537

4538+
CallLoweringInfo &setConvergenceControlToken(SDValue Token) {
4539+
ConvergenceControlToken = Token;
4540+
return *this;
4541+
}
4542+
45374543
ArgListTy &getArgs() {
45384544
return Args;
45394545
}

llvm/include/llvm/IR/GenericConvergenceVerifierImpl.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ template <class ContextT> void GenericConvergenceVerifier<ContextT>::clear() {
5252
Tokens.clear();
5353
CI.clear();
5454
ConvergenceKind = NoConvergence;
55+
IsSSA = false;
5556
}
5657

5758
template <class ContextT>
@@ -61,12 +62,16 @@ void GenericConvergenceVerifier<ContextT>::visit(const BlockT &BB) {
6162

6263
template <class ContextT>
6364
void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
64-
auto ID = ContextT::getIntrinsicID(I);
65+
ConvOpKind ConvOp = getConvOp(I);
66+
if (!IsSSA) {
67+
Check(ConvOp == CONV_NONE, "Convergence control requires SSA.",
68+
{Context.print(&I)});
69+
return;
70+
}
6571
auto *TokenDef = findAndCheckConvergenceTokenUsed(I);
66-
bool IsCtrlIntrinsic = true;
6772

68-
switch (ID) {
69-
case Intrinsic::experimental_convergence_entry:
73+
switch (ConvOp) {
74+
case CONV_ENTRY:
7075
Check(isInsideConvergentFunction(I),
7176
"Entry intrinsic can occur only in a convergent function.",
7277
{Context.print(&I)});
@@ -78,13 +83,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
7883
"same basic block.",
7984
{Context.print(&I)});
8085
LLVM_FALLTHROUGH;
81-
case Intrinsic::experimental_convergence_anchor:
86+
case CONV_ANCHOR:
8287
Check(!TokenDef,
8388
"Entry or anchor intrinsic cannot have a convergencectrl token "
8489
"operand.",
8590
{Context.print(&I)});
8691
break;
87-
case Intrinsic::experimental_convergence_loop:
92+
case CONV_LOOP:
8893
Check(TokenDef, "Loop intrinsic must have a convergencectrl token operand.",
8994
{Context.print(&I)});
9095
Check(!SeenFirstConvOp,
@@ -93,14 +98,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
9398
{Context.print(&I)});
9499
break;
95100
default:
96-
IsCtrlIntrinsic = false;
97101
break;
98102
}
99103

100104
if (isConvergent(I))
101105
SeenFirstConvOp = true;
102106

103-
if (TokenDef || IsCtrlIntrinsic) {
107+
if (TokenDef || ConvOp != CONV_NONE) {
104108
Check(isConvergent(I),
105109
"Convergence control token can only be used in a convergent call.",
106110
{Context.print(&I)});
@@ -161,8 +165,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
161165
return;
162166
}
163167

164-
Check(ContextT::getIntrinsicID(*User) ==
165-
Intrinsic::experimental_convergence_loop,
168+
Check(getConvOp(*User) == CONV_LOOP,
166169
"Convergence token used by an instruction other than "
167170
"llvm.experimental.convergence.loop in a cycle that does "
168171
"not contain the token's definition.",
@@ -199,7 +202,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
199202
for (auto &I : *BB) {
200203
if (auto *Token = Tokens.lookup(&I))
201204
checkToken(Token, &I, LiveTokens);
202-
if (isConvergenceControlIntrinsic(ContextT::getIntrinsicID(I)))
205+
if (getConvOp(I) != CONV_NONE)
203206
LiveTokens.push_back(&I);
204207
}
205208

llvm/include/llvm/Support/TargetOpcodes.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,11 @@ HANDLE_TARGET_OPCODE(MEMBARRIER)
225225
// using.
226226
HANDLE_TARGET_OPCODE(JUMP_TABLE_DEBUG_INFO)
227227

228+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_ENTRY)
229+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_ANCHOR)
230+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_LOOP)
231+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_GLUE)
232+
228233
/// The following generic opcodes are not supposed to appear after ISel.
229234
/// This is something we might want to relax, but for now, this is convenient
230235
/// to produce diagnostics.

llvm/include/llvm/Target/Target.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,25 @@ def JUMP_TABLE_DEBUG_INFO : StandardPseudoInstruction {
14831483
let isMeta = true;
14841484
}
14851485

1486+
let hasSideEffects = false, isMeta = true, isConvergent = true in {
1487+
def CONVERGENCECTRL_ANCHOR : StandardPseudoInstruction {
1488+
let OutOperandList = (outs unknown:$dst);
1489+
let InOperandList = (ins);
1490+
}
1491+
def CONVERGENCECTRL_ENTRY : StandardPseudoInstruction {
1492+
let OutOperandList = (outs unknown:$dst);
1493+
let InOperandList = (ins);
1494+
}
1495+
def CONVERGENCECTRL_LOOP : StandardPseudoInstruction {
1496+
let OutOperandList = (outs unknown:$dst);
1497+
let InOperandList = (ins unknown:$src);
1498+
}
1499+
def CONVERGENCECTRL_GLUE : StandardPseudoInstruction {
1500+
let OutOperandList = (outs);
1501+
let InOperandList = (ins unknown:$src);
1502+
}
1503+
}
1504+
14861505
// Generic opcodes used in GlobalISel.
14871506
include "llvm/Target/GenericOpcodes.td"
14881507

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,16 @@ def assertsext : SDNode<"ISD::AssertSext", SDT_assert>;
782782
def assertzext : SDNode<"ISD::AssertZext", SDT_assert>;
783783
def assertalign : SDNode<"ISD::AssertAlign", SDT_assert>;
784784

785+
def convergencectrl_anchor : SDNode<"ISD::CONVERGENCECTRL_ANCHOR",
786+
SDTypeProfile<1, 0, [SDTCisVT<0,untyped>]>>;
787+
def convergencectrl_entry : SDNode<"ISD::CONVERGENCECTRL_ENTRY",
788+
SDTypeProfile<1, 0, [SDTCisVT<0,untyped>]>>;
789+
def convergencectrl_loop : SDNode<"ISD::CONVERGENCECTRL_LOOP",
790+
SDTypeProfile<1, 1,
791+
[SDTCisVT<0,untyped>, SDTCisVT<1,untyped>]>>;
792+
def convergencectrl_glue : SDNode<"ISD::CONVERGENCECTRL_GLUE",
793+
SDTypeProfile<0, 1, [SDTCisVT<0, untyped>]>>;
794+
785795
//===----------------------------------------------------------------------===//
786796
// Selection DAG Condition Codes
787797

llvm/lib/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ add_llvm_component_library(LLVMCodeGen
109109
MachineBranchProbabilityInfo.cpp
110110
MachineCFGPrinter.cpp
111111
MachineCombiner.cpp
112+
MachineConvergenceVerifier.cpp
112113
MachineCopyPropagation.cpp
113114
MachineCSE.cpp
114115
MachineCheckDebugify.cpp
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
//===- ConvergenceVerifier.cpp - Verify convergence control -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "llvm/CodeGen/MachineConvergenceVerifier.h"
11+
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
12+
#include "llvm/CodeGen/MachineDominators.h"
13+
#include "llvm/CodeGen/MachineRegisterInfo.h"
14+
#include "llvm/CodeGen/MachineSSAContext.h"
15+
#include "llvm/IR/GenericConvergenceVerifierImpl.h"
16+
17+
using namespace llvm;
18+
19+
template <>
20+
auto GenericConvergenceVerifier<MachineSSAContext>::getConvOp(
21+
const MachineInstr &MI) -> ConvOpKind {
22+
switch (MI.getOpcode()) {
23+
default:
24+
return CONV_NONE;
25+
case TargetOpcode::CONVERGENCECTRL_ENTRY:
26+
return CONV_ENTRY;
27+
case TargetOpcode::CONVERGENCECTRL_ANCHOR:
28+
return CONV_ANCHOR;
29+
case TargetOpcode::CONVERGENCECTRL_LOOP:
30+
return CONV_LOOP;
31+
}
32+
}
33+
34+
template <>
35+
const MachineInstr *
36+
GenericConvergenceVerifier<MachineSSAContext>::findAndCheckConvergenceTokenUsed(
37+
const MachineInstr &MI) {
38+
const MachineRegisterInfo &MRI = Context.getFunction()->getRegInfo();
39+
const MachineInstr *TokenDef = nullptr;
40+
41+
for (const MachineOperand &MO : MI.uses()) {
42+
if (!MO.isReg())
43+
continue;
44+
Register OpReg = MO.getReg();
45+
if (!OpReg.isVirtual())
46+
continue;
47+
48+
const MachineInstr *Def = MRI.getVRegDef(OpReg);
49+
if (!Def)
50+
continue;
51+
if (getConvOp(*Def) == CONV_NONE)
52+
continue;
53+
54+
CheckOrNull(
55+
MI.isConvergent(),
56+
"Convergence control tokens can only be used by convergent operations.",
57+
{Context.print(OpReg), Context.print(&MI)});
58+
59+
CheckOrNull(!TokenDef,
60+
"An operation can use at most one convergence control token.",
61+
{Context.print(OpReg), Context.print(&MI)});
62+
63+
TokenDef = Def;
64+
}
65+
66+
if (TokenDef)
67+
Tokens[&MI] = TokenDef;
68+
69+
return TokenDef;
70+
}
71+
72+
template <>
73+
bool GenericConvergenceVerifier<MachineSSAContext>::isInsideConvergentFunction(
74+
const MachineInstr &MI) {
75+
// The class MachineFunction does not have any property to indicate whether it
76+
// is convergent. Trivially return true so that the check always passes.
77+
return true;
78+
}
79+
80+
template <>
81+
bool GenericConvergenceVerifier<MachineSSAContext>::isConvergent(
82+
const MachineInstr &MI) {
83+
return MI.isConvergent();
84+
}
85+
86+
template class llvm::GenericConvergenceVerifier<MachineSSAContext>;

0 commit comments

Comments
 (0)