Skip to content

Commit c7fdd8c

Browse files
committed
Restore "Implement convergence control in MIR using SelectionDAG (llvm#71785)"
Original commit 7988973. Perviously reverted in commit a2afcd5. LLVM function calls carry convergence control tokens as operand bundles, where the tokens themselves are produced by convergence control intrinsics. This patch implements convergence control tokens in MIR as follows: 1. Introduce target-independent ISD opcodes and MIR opcodes for convergence control intrinsics. 2. Model token values as untyped virtual registers in MIR. The change also introduces an additional ISD opcode CONVERGENCECTRL_GLUE and a corresponding machine opcode with the same spelling. This glues the convergence control token to SDNodes that represent calls to intrinsics. The glued token is later translated to an implicit argument in the MIR. The lowering of calls to user-defined functions is target-specific. On AMDGPU, the convergence control operand bundle at a non-intrinsic call is translated to an explicit argument to the SI_CALL_ISEL instruction. Post-selection adjustment converts this explicit argument to an implicit argument on the SI_CALL instruction.
1 parent 63725ab commit c7fdd8c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+832
-161
lines changed

llvm/include/llvm/ADT/GenericConvergenceVerifier.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ template <typename ContextT> class GenericConvergenceVerifier {
6060
NoConvergence
6161
} ConvergenceKind = NoConvergence;
6262

63+
/// The control token operation performed by a convergence control Intrinsic
64+
/// in LLVM IR, or by a CONVERGENCECTRL* instruction in MIR
65+
enum ConvOpKind { CONV_ANCHOR, CONV_ENTRY, CONV_LOOP, CONV_NONE };
66+
6367
// Cache token uses found so far. Note that we track the unique definitions
6468
// and not the token values.
6569
DenseMap<const InstructionT *, const InstructionT *> Tokens;
@@ -68,6 +72,8 @@ template <typename ContextT> class GenericConvergenceVerifier {
6872

6973
static bool isInsideConvergentFunction(const InstructionT &I);
7074
static bool isConvergent(const InstructionT &I);
75+
static ConvOpKind getConvOp(const InstructionT &I);
76+
void checkConvergenceTokenProduced(const InstructionT &I);
7177
const InstructionT *findAndCheckConvergenceTokenUsed(const InstructionT &I);
7278

7379
void reportFailure(const Twine &Message, ArrayRef<Printable> Values);

llvm/include/llvm/CodeGen/FunctionLoweringInfo.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,15 +215,7 @@ class FunctionLoweringInfo {
215215

216216
Register CreateRegs(Type *Ty, bool isDivergent = false);
217217

218-
Register InitializeRegForValue(const Value *V) {
219-
// Tokens never live in vregs.
220-
if (V->getType()->isTokenTy())
221-
return 0;
222-
Register &R = ValueMap[V];
223-
assert(R == 0 && "Already initialized this value register!");
224-
assert(VirtReg2Value.empty());
225-
return R = CreateRegs(V);
226-
}
218+
Register InitializeRegForValue(const Value *V);
227219

228220
/// GetLiveOutRegInfo - Gets LiveOutInfo for a register, returning NULL if the
229221
/// register is a PHI destination and the PHI's LiveOutInfo is not valid.

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,15 @@ enum NodeType {
13861386
#define BEGIN_REGISTER_VP_SDNODE(VPSDID, ...) VPSDID,
13871387
#include "llvm/IR/VPIntrinsics.def"
13881388

1389+
// The `llvm.experimental.convergence.*` intrinsics.
1390+
CONVERGENCECTRL_ANCHOR,
1391+
CONVERGENCECTRL_ENTRY,
1392+
CONVERGENCECTRL_LOOP,
1393+
// This does not correspond to any convergence control intrinsic. It used to
1394+
// glue a convergence control token to a convergent operation in the DAG,
1395+
// which is later translated to an implicit use in the MIR.
1396+
CONVERGENCECTRL_GLUE,
1397+
13891398
/// BUILTIN_OP_END - This must be the last enum value in this list.
13901399
/// The target-specific pre-isel opcode values start here.
13911400
BUILTIN_OP_END
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- MachineConvergenceVerifier.h - Verify convergenctrl ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
/// \file
9+
///
10+
/// This file declares the MIR specialization of the GenericConvergenceVerifier
11+
/// template.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
16+
#define LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
17+
18+
#include "llvm/ADT/GenericConvergenceVerifier.h"
19+
#include "llvm/CodeGen/MachineSSAContext.h"
20+
21+
namespace llvm {
22+
23+
using MachineConvergenceVerifier =
24+
GenericConvergenceVerifier<MachineSSAContext>;
25+
26+
} // namespace llvm
27+
28+
#endif // LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H

llvm/include/llvm/CodeGen/SelectionDAGISel.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,10 @@ class SelectionDAGISel : public MachineFunctionPass {
459459
void Select_ARITH_FENCE(SDNode *N);
460460
void Select_MEMBARRIER(SDNode *N);
461461

462+
void Select_CONVERGENCECTRL_ANCHOR(SDNode *N);
463+
void Select_CONVERGENCECTRL_ENTRY(SDNode *N);
464+
void Select_CONVERGENCECTRL_LOOP(SDNode *N);
465+
462466
void pushStackMapLiveVariable(SmallVectorImpl<SDValue> &Ops, SDValue Operand,
463467
SDLoc DL);
464468
void Select_STACKMAP(SDNode *N);

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4402,6 +4402,7 @@ class TargetLowering : public TargetLoweringBase {
44024402
SmallVector<ISD::InputArg, 32> Ins;
44034403
SmallVector<SDValue, 4> InVals;
44044404
const ConstantInt *CFIType = nullptr;
4405+
SDValue ConvergenceControlToken;
44054406

44064407
CallLoweringInfo(SelectionDAG &DAG)
44074408
: RetSExt(false), RetZExt(false), IsVarArg(false), IsInReg(false),
@@ -4535,6 +4536,11 @@ class TargetLowering : public TargetLoweringBase {
45354536
return *this;
45364537
}
45374538

4539+
CallLoweringInfo &setConvergenceControlToken(SDValue Token) {
4540+
ConvergenceControlToken = Token;
4541+
return *this;
4542+
}
4543+
45384544
ArgListTy &getArgs() {
45394545
return Args;
45404546
}

llvm/include/llvm/IR/GenericConvergenceVerifierImpl.h

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ void GenericConvergenceVerifier<ContextT>::visit(const BlockT &BB) {
6161

6262
template <class ContextT>
6363
void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
64-
auto ID = ContextT::getIntrinsicID(I);
65-
auto *TokenDef = findAndCheckConvergenceTokenUsed(I);
66-
bool IsCtrlIntrinsic = true;
64+
ConvOpKind ConvOp = getConvOp(I);
6765

68-
switch (ID) {
69-
case Intrinsic::experimental_convergence_entry:
66+
auto *TokenDef = findAndCheckConvergenceTokenUsed(I);
67+
switch (ConvOp) {
68+
case CONV_ENTRY:
7069
Check(isInsideConvergentFunction(I),
7170
"Entry intrinsic can occur only in a convergent function.",
7271
{Context.print(&I)});
@@ -78,13 +77,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
7877
"same basic block.",
7978
{Context.print(&I)});
8079
LLVM_FALLTHROUGH;
81-
case Intrinsic::experimental_convergence_anchor:
80+
case CONV_ANCHOR:
8281
Check(!TokenDef,
8382
"Entry or anchor intrinsic cannot have a convergencectrl token "
8483
"operand.",
8584
{Context.print(&I)});
8685
break;
87-
case Intrinsic::experimental_convergence_loop:
86+
case CONV_LOOP:
8887
Check(TokenDef, "Loop intrinsic must have a convergencectrl token operand.",
8988
{Context.print(&I)});
9089
Check(!SeenFirstConvOp,
@@ -93,14 +92,16 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
9392
{Context.print(&I)});
9493
break;
9594
default:
96-
IsCtrlIntrinsic = false;
9795
break;
9896
}
9997

98+
if (ConvOp != CONV_NONE)
99+
checkConvergenceTokenProduced(I);
100+
100101
if (isConvergent(I))
101102
SeenFirstConvOp = true;
102103

103-
if (TokenDef || IsCtrlIntrinsic) {
104+
if (TokenDef || ConvOp != CONV_NONE) {
104105
Check(isConvergent(I),
105106
"Convergence control token can only be used in a convergent call.",
106107
{Context.print(&I)});
@@ -143,6 +144,10 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
143144

144145
auto checkToken = [&](const InstructionT *Token, const InstructionT *User,
145146
SmallVectorImpl<const InstructionT *> &LiveTokens) {
147+
Check(DT.dominates(Token->getParent(), User->getParent()),
148+
"Convergence control token must dominate all its uses.",
149+
{Context.print(Token), Context.print(User)});
150+
146151
Check(llvm::is_contained(LiveTokens, Token),
147152
"Convergence region is not well-nested.",
148153
{Context.print(Token), Context.print(User)});
@@ -161,8 +166,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
161166
return;
162167
}
163168

164-
Check(ContextT::getIntrinsicID(*User) ==
165-
Intrinsic::experimental_convergence_loop,
169+
Check(getConvOp(*User) == CONV_LOOP,
166170
"Convergence token used by an instruction other than "
167171
"llvm.experimental.convergence.loop in a cycle that does "
168172
"not contain the token's definition.",
@@ -199,7 +203,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
199203
for (auto &I : *BB) {
200204
if (auto *Token = Tokens.lookup(&I))
201205
checkToken(Token, &I, LiveTokens);
202-
if (isConvergenceControlIntrinsic(ContextT::getIntrinsicID(I)))
206+
if (getConvOp(I) != CONV_NONE)
203207
LiveTokens.push_back(&I);
204208
}
205209

llvm/include/llvm/Support/TargetOpcodes.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,11 @@ HANDLE_TARGET_OPCODE(MEMBARRIER)
225225
// using.
226226
HANDLE_TARGET_OPCODE(JUMP_TABLE_DEBUG_INFO)
227227

228+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_ENTRY)
229+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_ANCHOR)
230+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_LOOP)
231+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_GLUE)
232+
228233
/// The following generic opcodes are not supposed to appear after ISel.
229234
/// This is something we might want to relax, but for now, this is convenient
230235
/// to produce diagnostics.

llvm/include/llvm/Target/Target.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,25 @@ def JUMP_TABLE_DEBUG_INFO : StandardPseudoInstruction {
14831483
let isMeta = true;
14841484
}
14851485

1486+
let hasSideEffects = false, isMeta = true, isConvergent = true in {
1487+
def CONVERGENCECTRL_ANCHOR : StandardPseudoInstruction {
1488+
let OutOperandList = (outs unknown:$dst);
1489+
let InOperandList = (ins);
1490+
}
1491+
def CONVERGENCECTRL_ENTRY : StandardPseudoInstruction {
1492+
let OutOperandList = (outs unknown:$dst);
1493+
let InOperandList = (ins);
1494+
}
1495+
def CONVERGENCECTRL_LOOP : StandardPseudoInstruction {
1496+
let OutOperandList = (outs unknown:$dst);
1497+
let InOperandList = (ins unknown:$src);
1498+
}
1499+
def CONVERGENCECTRL_GLUE : StandardPseudoInstruction {
1500+
let OutOperandList = (outs);
1501+
let InOperandList = (ins unknown:$src);
1502+
}
1503+
}
1504+
14861505
// Generic opcodes used in GlobalISel.
14871506
include "llvm/Target/GenericOpcodes.td"
14881507

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,16 @@ def assertsext : SDNode<"ISD::AssertSext", SDT_assert>;
789789
def assertzext : SDNode<"ISD::AssertZext", SDT_assert>;
790790
def assertalign : SDNode<"ISD::AssertAlign", SDT_assert>;
791791

792+
def convergencectrl_anchor : SDNode<"ISD::CONVERGENCECTRL_ANCHOR",
793+
SDTypeProfile<1, 0, [SDTCisVT<0,untyped>]>>;
794+
def convergencectrl_entry : SDNode<"ISD::CONVERGENCECTRL_ENTRY",
795+
SDTypeProfile<1, 0, [SDTCisVT<0,untyped>]>>;
796+
def convergencectrl_loop : SDNode<"ISD::CONVERGENCECTRL_LOOP",
797+
SDTypeProfile<1, 1,
798+
[SDTCisVT<0,untyped>, SDTCisVT<1,untyped>]>>;
799+
def convergencectrl_glue : SDNode<"ISD::CONVERGENCECTRL_GLUE",
800+
SDTypeProfile<0, 1, [SDTCisVT<0, untyped>]>>;
801+
792802
//===----------------------------------------------------------------------===//
793803
// Selection DAG Condition Codes
794804

llvm/lib/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ add_llvm_component_library(LLVMCodeGen
110110
MachineBranchProbabilityInfo.cpp
111111
MachineCFGPrinter.cpp
112112
MachineCombiner.cpp
113+
MachineConvergenceVerifier.cpp
113114
MachineCopyPropagation.cpp
114115
MachineCSE.cpp
115116
MachineCheckDebugify.cpp
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
//===- ConvergenceVerifier.cpp - Verify convergence control -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "llvm/CodeGen/MachineConvergenceVerifier.h"
11+
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
12+
#include "llvm/CodeGen/MachineDominators.h"
13+
#include "llvm/CodeGen/MachineRegisterInfo.h"
14+
#include "llvm/CodeGen/MachineSSAContext.h"
15+
#include "llvm/IR/GenericConvergenceVerifierImpl.h"
16+
17+
using namespace llvm;
18+
19+
template <>
20+
auto GenericConvergenceVerifier<MachineSSAContext>::getConvOp(
21+
const MachineInstr &MI) -> ConvOpKind {
22+
switch (MI.getOpcode()) {
23+
default:
24+
return CONV_NONE;
25+
case TargetOpcode::CONVERGENCECTRL_ENTRY:
26+
return CONV_ENTRY;
27+
case TargetOpcode::CONVERGENCECTRL_ANCHOR:
28+
return CONV_ANCHOR;
29+
case TargetOpcode::CONVERGENCECTRL_LOOP:
30+
return CONV_LOOP;
31+
}
32+
}
33+
34+
template <>
35+
void GenericConvergenceVerifier<
36+
MachineSSAContext>::checkConvergenceTokenProduced(const MachineInstr &MI) {
37+
Check(!MI.hasImplicitDef(),
38+
"Convergence control tokens are defined explicitly.",
39+
{Context.print(&MI)});
40+
const MachineOperand &Def = MI.getOperand(0);
41+
const MachineRegisterInfo &MRI = Context.getFunction()->getRegInfo();
42+
Check(MRI.getUniqueVRegDef(Def.getReg()),
43+
"Convergence control tokens must have unique definitions.",
44+
{Context.print(&MI)});
45+
}
46+
47+
template <>
48+
const MachineInstr *
49+
GenericConvergenceVerifier<MachineSSAContext>::findAndCheckConvergenceTokenUsed(
50+
const MachineInstr &MI) {
51+
const MachineRegisterInfo &MRI = Context.getFunction()->getRegInfo();
52+
const MachineInstr *TokenDef = nullptr;
53+
54+
for (const MachineOperand &MO : MI.uses()) {
55+
if (!MO.isReg())
56+
continue;
57+
Register OpReg = MO.getReg();
58+
if (!OpReg.isVirtual())
59+
continue;
60+
61+
const MachineInstr *Def = MRI.getUniqueVRegDef(OpReg);
62+
if (!Def)
63+
continue;
64+
if (getConvOp(*Def) == CONV_NONE)
65+
continue;
66+
67+
CheckOrNull(
68+
MI.isConvergent(),
69+
"Convergence control tokens can only be used by convergent operations.",
70+
{Context.print(OpReg), Context.print(&MI)});
71+
72+
CheckOrNull(!TokenDef,
73+
"An operation can use at most one convergence control token.",
74+
{Context.print(OpReg), Context.print(&MI)});
75+
76+
TokenDef = Def;
77+
}
78+
79+
if (TokenDef)
80+
Tokens[&MI] = TokenDef;
81+
82+
return TokenDef;
83+
}
84+
85+
template <>
86+
bool GenericConvergenceVerifier<MachineSSAContext>::isInsideConvergentFunction(
87+
const MachineInstr &MI) {
88+
// The class MachineFunction does not have any property to indicate whether it
89+
// is convergent. Trivially return true so that the check always passes.
90+
return true;
91+
}
92+
93+
template <>
94+
bool GenericConvergenceVerifier<MachineSSAContext>::isConvergent(
95+
const MachineInstr &MI) {
96+
return MI.isConvergent();
97+
}
98+
99+
template class llvm::GenericConvergenceVerifier<MachineSSAContext>;

0 commit comments

Comments
 (0)