Skip to content

[NVPTX] Fix code generation for trap-unreachable. #67478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3226,14 +3226,9 @@ void SelectionDAGBuilder::visitUnreachable(const UnreachableInst &I) {

// We may be able to ignore unreachable behind a noreturn call.
if (DAG.getTarget().Options.NoTrapAfterNoreturn) {
const BasicBlock &BB = *I.getParent();
if (&I != &BB.front()) {
BasicBlock::const_iterator PredI =
std::prev(BasicBlock::const_iterator(&I));
if (const CallInst *Call = dyn_cast<CallInst>(&*PredI)) {
if (Call->doesNotReturn())
return;
}
if (const CallInst *Call = dyn_cast_or_null<CallInst>(I.getPrevNode())) {
if (Call->doesNotReturn())
return;
}
}

Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ MachineFunctionPass *createNVPTXReplaceImageHandlesPass();
FunctionPass *createNVPTXImageOptimizerPass();
FunctionPass *createNVPTXLowerArgsPass();
FunctionPass *createNVPTXLowerAllocaPass();
FunctionPass *createNVPTXLowerUnreachablePass();
FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
bool NoTrapAfterNoreturn);
MachineFunctionPass *createNVPTXPeephole();
MachineFunctionPass *createNVPTXProxyRegErasurePass();

Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -3545,7 +3545,9 @@ def Callseq_End :
[(callseq_end timm:$amt1, timm:$amt2)]>;

// trap instruction
def trapinst : NVPTXInst<(outs), (ins), "trap;", [(trap)]>;
// Emit an `exit` as well to convey to ptxas that `trap` exits the CFG.
// This won't be necessary in a future version of ptxas.
def trapinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>;

// Call prototype wrapper
def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
Expand Down
48 changes: 39 additions & 9 deletions llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@
// `bar.sync` instruction happen divergently.
//
// To work around this, we add an `exit` instruction before every `unreachable`,
// as `ptxas` understands that exit terminates the CFG. Note that `trap` is not
// equivalent, and only future versions of `ptxas` will model it like `exit`.
// as `ptxas` understands that exit terminates the CFG. We do only do this if
// `unreachable` is not lowered to `trap`, which has the same effect (although
// with current versions of `ptxas` only because it is emited as `trap; exit;`).
//
//===----------------------------------------------------------------------===//

Expand All @@ -83,14 +84,19 @@ void initializeNVPTXLowerUnreachablePass(PassRegistry &);

namespace {
class NVPTXLowerUnreachable : public FunctionPass {
StringRef getPassName() const override;
bool runOnFunction(Function &F) override;
bool isLoweredToTrap(const UnreachableInst &I) const;

public:
static char ID; // Pass identification, replacement for typeid
NVPTXLowerUnreachable() : FunctionPass(ID) {}
StringRef getPassName() const override {
return "add an exit instruction before every unreachable";
}
NVPTXLowerUnreachable(bool TrapUnreachable, bool NoTrapAfterNoreturn)
: FunctionPass(ID), TrapUnreachable(TrapUnreachable),
NoTrapAfterNoreturn(NoTrapAfterNoreturn) {}

private:
bool TrapUnreachable;
bool NoTrapAfterNoreturn;
};
} // namespace

Expand All @@ -99,12 +105,33 @@ char NVPTXLowerUnreachable::ID = 1;
INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable",
"Lower Unreachable", false, false)

StringRef NVPTXLowerUnreachable::getPassName() const {
return "add an exit instruction before every unreachable";
}

// =============================================================================
// Returns whether a `trap` intrinsic should be emitted before I.
//
// This is a copy of the logic in SelectionDAGBuilder::visitUnreachable().
// =============================================================================
bool NVPTXLowerUnreachable::isLoweredToTrap(const UnreachableInst &I) const {
if (!TrapUnreachable)
return false;
if (!NoTrapAfterNoreturn)
return true;
const CallInst *Call = dyn_cast_or_null<CallInst>(I.getPrevNode());
return Call && Call->doesNotReturn();
}

// =============================================================================
// Main function for this pass.
// =============================================================================
bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
// Early out iff isLoweredToTrap() always returns true.
if (TrapUnreachable && !NoTrapAfterNoreturn)
return false;

LLVMContext &C = F.getContext();
FunctionType *ExitFTy = FunctionType::get(Type::getVoidTy(C), false);
Expand All @@ -114,13 +141,16 @@ bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
for (auto &BB : F)
for (auto &I : BB) {
if (auto unreachableInst = dyn_cast<UnreachableInst>(&I)) {
Changed = true;
if (isLoweredToTrap(*unreachableInst))
continue; // trap is emitted as `trap; exit;`.
CallInst::Create(ExitFTy, Exit, "", unreachableInst);
Changed = true;
}
}
return Changed;
}

FunctionPass *llvm::createNVPTXLowerUnreachablePass() {
return new NVPTXLowerUnreachable();
FunctionPass *llvm::createNVPTXLowerUnreachablePass(bool TrapUnreachable,
bool NoTrapAfterNoreturn) {
return new NVPTXLowerUnreachable(TrapUnreachable, NoTrapAfterNoreturn);
}
12 changes: 3 additions & 9 deletions llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,6 @@ static cl::opt<bool> UseShortPointersOpt(
"Use 32-bit pointers for accessing const/local/shared address spaces."),
cl::init(false), cl::Hidden);

// FIXME: intended as a temporary debugging aid. Should be removed before it
// makes it into the LLVM-17 release.
static cl::opt<bool>
ExitOnUnreachable("nvptx-exit-on-unreachable",
cl::desc("Lower 'unreachable' as 'exit' instruction."),
cl::init(true), cl::Hidden);

namespace llvm {

void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
Expand Down Expand Up @@ -410,8 +403,9 @@ void NVPTXPassConfig::addIRPasses() {
addPass(createSROAPass());
}

if (ExitOnUnreachable)
addPass(createNVPTXLowerUnreachablePass());
const auto &Options = getNVPTXTargetMachine().Options;
addPass(createNVPTXLowerUnreachablePass(Options.TrapUnreachable,
Options.NoTrapAfterNoreturn));
}

bool NVPTXPassConfig::addInstSelector() {
Expand Down
15 changes: 12 additions & 3 deletions llvm/test/CodeGen/NVPTX/unreachable.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs \
; RUN: | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-NOTRAP
; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs \
; RUN: | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-NOTRAP
; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs -trap-unreachable \
; RUN: | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-TRAP
; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs -trap-unreachable \
; RUN: | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-TRAP
; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}

Expand All @@ -11,7 +17,10 @@ define void @kernel_func() {
; CHECK: call.uni
; CHECK: throw,
call void @throw()
; CHECK: exit
; CHECK-TRAP-NOT: exit;
; CHECK-TRAP: trap;
; CHECK-NOTRAP-NOT: trap;
; CHECK: exit;
unreachable
}

Expand Down