Skip to content

Commit ad027c0

Browse files
committed
[RFC] implement convergence control in MIR using SelectionDAG
LLVM function calls carry convergence control tokens as operand bundles, where the tokens themselves are produced by convergence control intrinsics. This patch implements convergence control tokens in MIR as follows: 1. Introduce target-independent ISD opcodes and MIR opcodes for convergence control intrinsics. 2. Model token values as untyped virtual registers in MIR. The actual lowering of controlled convergent operations (including function calls) and convergence control intrinsics is target-specific. On AMDGPU, the convergence control operand bundle at a non-intrinsic call is translated to an explicit argument to the SI_CALL_ISEL instruction. Post-selection adjustment converts this explicit argument to an implicit argument on the SI_CALL instruction. For intrinsics, AMDGPU introduces an AMDGPUISD opcode CONVERGENCECTRL_GLUE and a corresponding machine opcode with the same spelling. This is used as a glued argument to SDNodes that is later translated to an implicit argument in the MIR. All convergent intrinsics on AMDGPU need to set hasPostISelHook in their description.
1 parent 8f78dd4 commit ad027c0

File tree

71 files changed

+895
-204
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+895
-204
lines changed

llvm/include/llvm/ADT/GenericConvergenceVerifier.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ template <typename ContextT> class GenericConvergenceVerifier {
3232

3333
void initialize(raw_ostream *OS,
3434
function_ref<void(const Twine &Message)> FailureCB,
35-
const FunctionT &F) {
35+
const FunctionT &F, bool TokensAllowed_) {
3636
clear();
3737
this->OS = OS;
3838
this->FailureCB = FailureCB;
3939
Context = ContextT(&F);
40+
TokensAllowed = TokensAllowed_;
4041
}
4142

4243
void clear();
@@ -52,6 +53,7 @@ template <typename ContextT> class GenericConvergenceVerifier {
5253
DominatorTreeT *DT;
5354
CycleInfoT CI;
5455
ContextT Context;
56+
bool TokensAllowed;
5557

5658
/// Whether the current function has convergencectrl operand bundles.
5759
enum {
@@ -60,6 +62,10 @@ template <typename ContextT> class GenericConvergenceVerifier {
6062
NoConvergence
6163
} ConvergenceKind = NoConvergence;
6264

65+
/// The control token operation performed by a convergence control Intrinsic
66+
/// in LLVM IR, or by a CONVERGENCECTRL* instruction in MIR
67+
enum ConvOpKind { CONV_ANCHOR, CONV_ENTRY, CONV_LOOP, CONV_NONE };
68+
6369
// Cache token uses found so far. Note that we track the unique definitions
6470
// and not the token values.
6571
DenseMap<const InstructionT *, const InstructionT *> Tokens;
@@ -68,6 +74,7 @@ template <typename ContextT> class GenericConvergenceVerifier {
6874

6975
static bool isInsideConvergentFunction(const InstructionT &I);
7076
static bool isConvergent(const InstructionT &I);
77+
static ConvOpKind getConvOp(const InstructionT &I);
7178
const InstructionT *findAndCheckConvergenceTokenUsed(const InstructionT &I);
7279

7380
void reportFailure(const Twine &Message, ArrayRef<Printable> Values);

llvm/include/llvm/CodeGen/FunctionLoweringInfo.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,15 +215,7 @@ class FunctionLoweringInfo {
215215

216216
Register CreateRegs(Type *Ty, bool isDivergent = false);
217217

218-
Register InitializeRegForValue(const Value *V) {
219-
// Tokens never live in vregs.
220-
if (V->getType()->isTokenTy())
221-
return 0;
222-
Register &R = ValueMap[V];
223-
assert(R == 0 && "Already initialized this value register!");
224-
assert(VirtReg2Value.empty());
225-
return R = CreateRegs(V);
226-
}
218+
Register InitializeRegForValue(const Value *V);
227219

228220
/// GetLiveOutRegInfo - Gets LiveOutInfo for a register, returning NULL if the
229221
/// register is a PHI destination and the PHI's LiveOutInfo is not valid.

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,10 @@ enum NodeType {
13781378
#define BEGIN_REGISTER_VP_SDNODE(VPSDID, ...) VPSDID,
13791379
#include "llvm/IR/VPIntrinsics.def"
13801380

1381+
CONVERGENCECTRL_ANCHOR,
1382+
CONVERGENCECTRL_ENTRY,
1383+
CONVERGENCECTRL_LOOP,
1384+
13811385
/// BUILTIN_OP_END - This must be the last enum value in this list.
13821386
/// The target-specific pre-isel opcode values start here.
13831387
BUILTIN_OP_END
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- MachineConvergenceVerifier.h - Verify convergenctrl ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
/// \file
9+
///
10+
/// This file declares the MIR specialization of the GenericConvergenceVerifier
11+
/// template.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
16+
#define LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
17+
18+
#include "llvm/ADT/GenericConvergenceVerifier.h"
19+
#include "llvm/CodeGen/MachineSSAContext.h"
20+
21+
namespace llvm {
22+
23+
using MachineConvergenceVerifier =
24+
GenericConvergenceVerifier<MachineSSAContext>;
25+
26+
} // namespace llvm
27+
28+
#endif // LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H

llvm/include/llvm/CodeGen/SelectionDAGISel.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,10 @@ class SelectionDAGISel : public MachineFunctionPass {
435435
void Select_ARITH_FENCE(SDNode *N);
436436
void Select_MEMBARRIER(SDNode *N);
437437

438+
void Select_CONVERGENCECTRL_ANCHOR(SDNode *N);
439+
void Select_CONVERGENCECTRL_ENTRY(SDNode *N);
440+
void Select_CONVERGENCECTRL_LOOP(SDNode *N);
441+
438442
void pushStackMapLiveVariable(SmallVectorImpl<SDValue> &Ops, SDValue Operand,
439443
SDLoc DL);
440444
void Select_STACKMAP(SDNode *N);

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4409,6 +4409,7 @@ class TargetLowering : public TargetLoweringBase {
44094409
SmallVector<ISD::InputArg, 32> Ins;
44104410
SmallVector<SDValue, 4> InVals;
44114411
const ConstantInt *CFIType = nullptr;
4412+
SDValue ConvergenceControlToken;
44124413

44134414
CallLoweringInfo(SelectionDAG &DAG)
44144415
: RetSExt(false), RetZExt(false), IsVarArg(false), IsInReg(false),
@@ -4542,6 +4543,11 @@ class TargetLowering : public TargetLoweringBase {
45424543
return *this;
45434544
}
45444545

4546+
CallLoweringInfo &setConvergenceControlToken(SDValue Token) {
4547+
ConvergenceControlToken = Token;
4548+
return *this;
4549+
}
4550+
45454551
ArgListTy &getArgs() {
45464552
return Args;
45474553
}
@@ -4928,9 +4934,9 @@ class TargetLowering : public TargetLoweringBase {
49284934

49294935
// Targets may override this function to collect operands from the CallInst
49304936
// and for example, lower them into the SelectionDAG operands.
4931-
virtual void CollectTargetIntrinsicOperands(const CallInst &I,
4932-
SmallVectorImpl<SDValue> &Ops,
4933-
SelectionDAG &DAG) const;
4937+
virtual void CollectTargetIntrinsicOperands(
4938+
const CallInst &I, SmallVectorImpl<SDValue> &Ops, SelectionDAG &DAG,
4939+
function_ref<SDValue(const Value *)> getValue) const;
49344940

49354941
//===--------------------------------------------------------------------===//
49364942
// Div utility functions
@@ -5362,8 +5368,9 @@ class TargetLowering : public TargetLoweringBase {
53625368
/// the 'hasPostISelHook' flag. These instructions must be adjusted after
53635369
/// instruction selection by target hooks. e.g. To fill in optional defs for
53645370
/// ARM 's' setting instructions.
5365-
virtual void AdjustInstrPostInstrSelection(MachineInstr &MI,
5366-
SDNode *Node) const;
5371+
virtual void
5372+
AdjustInstrPostInstrSelection(MachineInstr &MI, SDNode *Node,
5373+
function_ref<Register(SDValue)> getVR) const;
53675374

53685375
/// If this function returns true, SelectionDAGBuilder emits a
53695376
/// LOAD_STACK_GUARD node when it is lowering Intrinsic::stackprotector.

llvm/include/llvm/IR/GenericConvergenceVerifierImpl.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ template <class ContextT> void GenericConvergenceVerifier<ContextT>::clear() {
5252
Tokens.clear();
5353
CI.clear();
5454
ConvergenceKind = NoConvergence;
55+
TokensAllowed = false;
5556
}
5657

5758
template <class ContextT>
@@ -61,12 +62,16 @@ void GenericConvergenceVerifier<ContextT>::visit(const BlockT &BB) {
6162

6263
template <class ContextT>
6364
void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
64-
auto ID = ContextT::getIntrinsicID(I);
65+
ConvOpKind ConvOp = getConvOp(I);
66+
if (!TokensAllowed) {
67+
Check(ConvOp == CONV_NONE, "Convergence control requires SSA.",
68+
{Context.print(&I)});
69+
return;
70+
}
6571
auto *TokenDef = findAndCheckConvergenceTokenUsed(I);
66-
bool IsCtrlIntrinsic = true;
6772

68-
switch (ID) {
69-
case Intrinsic::experimental_convergence_entry:
73+
switch (ConvOp) {
74+
case CONV_ENTRY:
7075
Check(isInsideConvergentFunction(I),
7176
"Entry intrinsic can occur only in a convergent function.",
7277
{Context.print(&I)});
@@ -78,13 +83,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
7883
"same basic block.",
7984
{Context.print(&I)});
8085
LLVM_FALLTHROUGH;
81-
case Intrinsic::experimental_convergence_anchor:
86+
case CONV_ANCHOR:
8287
Check(!TokenDef,
8388
"Entry or anchor intrinsic cannot have a convergencectrl token "
8489
"operand.",
8590
{Context.print(&I)});
8691
break;
87-
case Intrinsic::experimental_convergence_loop:
92+
case CONV_LOOP:
8893
Check(TokenDef, "Loop intrinsic must have a convergencectrl token operand.",
8994
{Context.print(&I)});
9095
Check(!SeenFirstConvOp,
@@ -93,14 +98,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
9398
{Context.print(&I)});
9499
break;
95100
default:
96-
IsCtrlIntrinsic = false;
97101
break;
98102
}
99103

100104
if (isConvergent(I))
101105
SeenFirstConvOp = true;
102106

103-
if (TokenDef || IsCtrlIntrinsic) {
107+
if (TokenDef || ConvOp != CONV_NONE) {
104108
Check(isConvergent(I),
105109
"Convergence control token can only be used in a convergent call.",
106110
{Context.print(&I)});
@@ -161,8 +165,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
161165
return;
162166
}
163167

164-
Check(ContextT::getIntrinsicID(*User) ==
165-
Intrinsic::experimental_convergence_loop,
168+
Check(getConvOp(*User) == CONV_LOOP,
166169
"Convergence token used by an instruction other than "
167170
"llvm.experimental.convergence.loop in a cycle that does "
168171
"not contain the token's definition.",
@@ -199,7 +202,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
199202
for (auto &I : *BB) {
200203
if (auto *Token = Tokens.lookup(&I))
201204
checkToken(Token, &I, LiveTokens);
202-
if (isConvergenceControlIntrinsic(ContextT::getIntrinsicID(I)))
205+
if (getConvOp(I) != CONV_NONE)
203206
LiveTokens.push_back(&I);
204207
}
205208

llvm/include/llvm/Support/TargetOpcodes.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,10 @@ HANDLE_TARGET_OPCODE(MEMBARRIER)
230230
// using.
231231
HANDLE_TARGET_OPCODE(JUMP_TABLE_DEBUG_INFO)
232232

233+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_ENTRY)
234+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_ANCHOR)
235+
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_LOOP)
236+
233237
/// The following generic opcodes are not supposed to appear after ISel.
234238
/// This is something we might want to relax, but for now, this is convenient
235239
/// to produce diagnostics.

llvm/include/llvm/Target/Target.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,34 @@ def JUMP_TABLE_DEBUG_INFO : StandardPseudoInstruction {
14371437
let isMeta = true;
14381438
}
14391439

1440+
def CONVERGENCECTRL_ANCHOR : StandardPseudoInstruction {
1441+
let OutOperandList = (outs unknown:$dst);
1442+
let InOperandList = (ins);
1443+
let AsmString = "";
1444+
let hasSideEffects = false;
1445+
let Size = 0;
1446+
let isMeta = true;
1447+
let isConvergent = true;
1448+
}
1449+
def CONVERGENCECTRL_ENTRY : StandardPseudoInstruction {
1450+
let OutOperandList = (outs unknown:$dst);
1451+
let InOperandList = (ins);
1452+
let AsmString = "";
1453+
let hasSideEffects = false;
1454+
let Size = 0;
1455+
let isMeta = true;
1456+
let isConvergent = true;
1457+
}
1458+
def CONVERGENCECTRL_LOOP : StandardPseudoInstruction {
1459+
let OutOperandList = (outs unknown:$dst);
1460+
let InOperandList = (ins unknown:$src);
1461+
let AsmString = "";
1462+
let hasSideEffects = false;
1463+
let Size = 0;
1464+
let isMeta = true;
1465+
let isConvergent = true;
1466+
}
1467+
14401468
// Generic opcodes used in GlobalISel.
14411469
include "llvm/Target/GenericOpcodes.td"
14421470

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,14 @@ def assertsext : SDNode<"ISD::AssertSext", SDT_assert>;
779779
def assertzext : SDNode<"ISD::AssertZext", SDT_assert>;
780780
def assertalign : SDNode<"ISD::AssertAlign", SDT_assert>;
781781

782+
def convergencectrl_anchor : SDNode<"ISD::CONVERGENCECTRL_ANCHOR",
783+
SDTypeProfile<1, 0, [SDTCisVT<0,untyped>]>>;
784+
def convergencectrl_entry : SDNode<"ISD::CONVERGENCECTRL_ENTRY",
785+
SDTypeProfile<1, 0, [SDTCisVT<0,untyped>]>>;
786+
def convergencectrl_loop : SDNode<"ISD::CONVERGENCECTRL_LOOP",
787+
SDTypeProfile<1, 1,
788+
[SDTCisVT<0,untyped>, SDTCisVT<1,untyped>]>>;
789+
782790
//===----------------------------------------------------------------------===//
783791
// Selection DAG Condition Codes
784792

llvm/lib/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ add_llvm_component_library(LLVMCodeGen
122122
MachineBranchProbabilityInfo.cpp
123123
MachineCFGPrinter.cpp
124124
MachineCombiner.cpp
125+
MachineConvergenceVerifier.cpp
125126
MachineCopyPropagation.cpp
126127
MachineCSE.cpp
127128
MachineCheckDebugify.cpp
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
//===- ConvergenceVerifier.cpp - Verify convergence control -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "llvm/CodeGen/MachineConvergenceVerifier.h"
11+
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
12+
#include "llvm/CodeGen/MachineDominators.h"
13+
#include "llvm/CodeGen/MachineRegisterInfo.h"
14+
#include "llvm/CodeGen/MachineSSAContext.h"
15+
#include "llvm/IR/GenericConvergenceVerifierImpl.h"
16+
17+
using namespace llvm;
18+
19+
template <>
20+
auto GenericConvergenceVerifier<MachineSSAContext>::getConvOp(
21+
const MachineInstr &MI) -> ConvOpKind {
22+
switch (MI.getOpcode()) {
23+
default:
24+
return CONV_NONE;
25+
case TargetOpcode::CONVERGENCECTRL_ENTRY:
26+
return CONV_ENTRY;
27+
case TargetOpcode::CONVERGENCECTRL_ANCHOR:
28+
return CONV_ANCHOR;
29+
case TargetOpcode::CONVERGENCECTRL_LOOP:
30+
return CONV_LOOP;
31+
}
32+
}
33+
34+
template <>
35+
const MachineInstr *
36+
GenericConvergenceVerifier<MachineSSAContext>::findAndCheckConvergenceTokenUsed(
37+
const MachineInstr &MI) {
38+
const MachineRegisterInfo &MRI = Context.getFunction()->getRegInfo();
39+
const MachineInstr *TokenDef = nullptr;
40+
41+
for (const MachineOperand &MO : MI.uses()) {
42+
if (!MO.isReg())
43+
continue;
44+
Register OpReg = MO.getReg();
45+
if (!OpReg.isVirtual())
46+
continue;
47+
48+
const MachineInstr *Def = MRI.getVRegDef(OpReg);
49+
if (!Def)
50+
continue;
51+
if (getConvOp(*Def) == CONV_NONE)
52+
continue;
53+
54+
CheckOrNull(
55+
MI.isConvergent(),
56+
"Convergence control tokens can only be used by convergent operations.",
57+
{Context.print(OpReg), Context.print(&MI)});
58+
59+
CheckOrNull(!TokenDef,
60+
"An operation can use at most one convergence control token.",
61+
{Context.print(OpReg), Context.print(&MI)});
62+
63+
TokenDef = Def;
64+
}
65+
66+
if (TokenDef)
67+
Tokens[&MI] = TokenDef;
68+
69+
return TokenDef;
70+
}
71+
72+
template <>
73+
bool GenericConvergenceVerifier<MachineSSAContext>::isInsideConvergentFunction(
74+
const MachineInstr &MI) {
75+
// The class MachineFunction does not have any property to indicate whether it
76+
// is convergent. Trivially return true so that the check always passes.
77+
return true;
78+
}
79+
80+
template <>
81+
bool GenericConvergenceVerifier<MachineSSAContext>::isConvergent(
82+
const MachineInstr &MI) {
83+
return MI.isConvergent();
84+
}
85+
86+
template class llvm::GenericConvergenceVerifier<MachineSSAContext>;

0 commit comments

Comments
 (0)