Skip to content

Commit 6510fa0

Browse files
SC llvm teamSC llvm team
authored andcommitted
Merged main:5b7a7ec5a210 into amd-gfx:5546a043a84f
Local branch amd-gfx 5546a04 Merged main:18461dc45483 into amd-gfx:2f63a363aacc Remote branch main 5b7a7ec [NVPTX] Fix code generation for `trap-unreachable`. (llvm#67478)
2 parents 5546a04 + 5b7a7ec commit 6510fa0

File tree

7 files changed

+63
-32
lines changed

7 files changed

+63
-32
lines changed

llvm/include/llvm/Config/llvm-config.h.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
/* Indicate that this is LLVM compiled from the amd-gfx branch. */
1818
#define LLVM_HAVE_BRANCH_AMD_GFX
19-
#define LLVM_MAIN_REVISION 476469
19+
#define LLVM_MAIN_REVISION 476470
2020

2121
/* Define if LLVM_ENABLE_DUMP is enabled */
2222
#cmakedefine LLVM_ENABLE_DUMP

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3226,14 +3226,9 @@ void SelectionDAGBuilder::visitUnreachable(const UnreachableInst &I) {
32263226

32273227
// We may be able to ignore unreachable behind a noreturn call.
32283228
if (DAG.getTarget().Options.NoTrapAfterNoreturn) {
3229-
const BasicBlock &BB = *I.getParent();
3230-
if (&I != &BB.front()) {
3231-
BasicBlock::const_iterator PredI =
3232-
std::prev(BasicBlock::const_iterator(&I));
3233-
if (const CallInst *Call = dyn_cast<CallInst>(&*PredI)) {
3234-
if (Call->doesNotReturn())
3235-
return;
3236-
}
3229+
if (const CallInst *Call = dyn_cast_or_null<CallInst>(I.getPrevNode())) {
3230+
if (Call->doesNotReturn())
3231+
return;
32373232
}
32383233
}
32393234

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ MachineFunctionPass *createNVPTXReplaceImageHandlesPass();
4747
FunctionPass *createNVPTXImageOptimizerPass();
4848
FunctionPass *createNVPTXLowerArgsPass();
4949
FunctionPass *createNVPTXLowerAllocaPass();
50-
FunctionPass *createNVPTXLowerUnreachablePass();
50+
FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
51+
bool NoTrapAfterNoreturn);
5152
MachineFunctionPass *createNVPTXPeephole();
5253
MachineFunctionPass *createNVPTXProxyRegErasurePass();
5354

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3563,7 +3563,9 @@ def Callseq_End :
35633563
[(callseq_end timm:$amt1, timm:$amt2)]>;
35643564

35653565
// trap instruction
3566-
def trapinst : NVPTXInst<(outs), (ins), "trap;", [(trap)]>;
3566+
// Emit an `exit` as well to convey to ptxas that `trap` exits the CFG.
3567+
// This won't be necessary in a future version of ptxas.
3568+
def trapinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>;
35673569

35683570
// Call prototype wrapper
35693571
def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;

llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@
6363
// `bar.sync` instruction happen divergently.
6464
//
6565
// To work around this, we add an `exit` instruction before every `unreachable`,
66-
// as `ptxas` understands that exit terminates the CFG. Note that `trap` is not
67-
// equivalent, and only future versions of `ptxas` will model it like `exit`.
66+
// as `ptxas` understands that exit terminates the CFG. We only do this if
67+
// `unreachable` is not lowered to `trap`, which has the same effect (although
68+
// with current versions of `ptxas` only because it is emitted as `trap; exit;`).
6869
//
6970
//===----------------------------------------------------------------------===//
7071

@@ -83,14 +84,19 @@ void initializeNVPTXLowerUnreachablePass(PassRegistry &);
8384

8485
namespace {
8586
class NVPTXLowerUnreachable : public FunctionPass {
87+
StringRef getPassName() const override;
8688
bool runOnFunction(Function &F) override;
89+
bool isLoweredToTrap(const UnreachableInst &I) const;
8790

8891
public:
8992
static char ID; // Pass identification, replacement for typeid
90-
NVPTXLowerUnreachable() : FunctionPass(ID) {}
91-
StringRef getPassName() const override {
92-
return "add an exit instruction before every unreachable";
93-
}
93+
NVPTXLowerUnreachable(bool TrapUnreachable, bool NoTrapAfterNoreturn)
94+
: FunctionPass(ID), TrapUnreachable(TrapUnreachable),
95+
NoTrapAfterNoreturn(NoTrapAfterNoreturn) {}
96+
97+
private:
98+
bool TrapUnreachable;
99+
bool NoTrapAfterNoreturn;
94100
};
95101
} // namespace
96102

@@ -99,12 +105,33 @@ char NVPTXLowerUnreachable::ID = 1;
99105
INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable",
100106
"Lower Unreachable", false, false)
101107

108+
// Returns the descriptive name of this pass.
StringRef NVPTXLowerUnreachable::getPassName() const {
  return StringRef("add an exit instruction before every unreachable");
}
111+
112+
// =============================================================================
113+
// Returns whether a `trap` intrinsic should be emitted before I.
114+
//
115+
// This is a copy of the logic in SelectionDAGBuilder::visitUnreachable().
116+
// =============================================================================
117+
bool NVPTXLowerUnreachable::isLoweredToTrap(const UnreachableInst &I) const {
118+
if (!TrapUnreachable)
119+
return false;
120+
if (!NoTrapAfterNoreturn)
121+
return true;
122+
const CallInst *Call = dyn_cast_or_null<CallInst>(I.getPrevNode());
123+
return Call && Call->doesNotReturn();
124+
}
125+
102126
// =============================================================================
103127
// Main function for this pass.
104128
// =============================================================================
105129
bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
106130
if (skipFunction(F))
107131
return false;
132+
// Early out iff isLoweredToTrap() always returns true.
133+
if (TrapUnreachable && !NoTrapAfterNoreturn)
134+
return false;
108135

109136
LLVMContext &C = F.getContext();
110137
FunctionType *ExitFTy = FunctionType::get(Type::getVoidTy(C), false);
@@ -114,13 +141,16 @@ bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
114141
for (auto &BB : F)
115142
for (auto &I : BB) {
116143
if (auto unreachableInst = dyn_cast<UnreachableInst>(&I)) {
117-
Changed = true;
144+
if (isLoweredToTrap(*unreachableInst))
145+
continue; // trap is emitted as `trap; exit;`.
118146
CallInst::Create(ExitFTy, Exit, "", unreachableInst);
147+
Changed = true;
119148
}
120149
}
121150
return Changed;
122151
}
123152

124-
FunctionPass *llvm::createNVPTXLowerUnreachablePass() {
125-
return new NVPTXLowerUnreachable();
153+
FunctionPass *llvm::createNVPTXLowerUnreachablePass(bool TrapUnreachable,
154+
bool NoTrapAfterNoreturn) {
155+
return new NVPTXLowerUnreachable(TrapUnreachable, NoTrapAfterNoreturn);
126156
}

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,6 @@ static cl::opt<bool> UseShortPointersOpt(
6363
"Use 32-bit pointers for accessing const/local/shared address spaces."),
6464
cl::init(false), cl::Hidden);
6565

66-
// FIXME: intended as a temporary debugging aid. Should be removed before it
67-
// makes it into the LLVM-17 release.
68-
static cl::opt<bool>
69-
ExitOnUnreachable("nvptx-exit-on-unreachable",
70-
cl::desc("Lower 'unreachable' as 'exit' instruction."),
71-
cl::init(true), cl::Hidden);
72-
7366
namespace llvm {
7467

7568
void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
@@ -410,8 +403,9 @@ void NVPTXPassConfig::addIRPasses() {
410403
addPass(createSROAPass());
411404
}
412405

413-
if (ExitOnUnreachable)
414-
addPass(createNVPTXLowerUnreachablePass());
406+
const auto &Options = getNVPTXTargetMachine().Options;
407+
addPass(createNVPTXLowerUnreachablePass(Options.TrapUnreachable,
408+
Options.NoTrapAfterNoreturn));
415409
}
416410

417411
bool NVPTXPassConfig::addInstSelector() {

llvm/test/CodeGen/NVPTX/unreachable.ll

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1-
; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
2-
; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
1+
; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs \
2+
; RUN: | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-NOTRAP
3+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs \
4+
; RUN: | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-NOTRAP
5+
; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs -trap-unreachable \
6+
; RUN: | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-TRAP
7+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs -trap-unreachable \
8+
; RUN: | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-TRAP
39
; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
410
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
511

@@ -11,7 +17,10 @@ define void @kernel_func() {
1117
; CHECK: call.uni
1218
; CHECK: throw,
1319
call void @throw()
14-
; CHECK: exit
20+
; CHECK-TRAP-NOT: exit;
21+
; CHECK-TRAP: trap;
22+
; CHECK-NOTRAP-NOT: trap;
23+
; CHECK: exit;
1524
unreachable
1625
}
1726

0 commit comments

Comments
 (0)