Skip to content

Add Dead Block Elimination to NVVMReflect #144171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 96 additions & 17 deletions llvm/lib/Target/NVPTX/NVVMReflect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//===----------------------------------------------------------------------===//

#include "NVPTX.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ConstantFolding.h"
Expand Down Expand Up @@ -59,7 +60,10 @@ class NVVMReflect {
StringMap<unsigned> ReflectMap;
bool handleReflectFunction(Module &M, StringRef ReflectName);
void populateReflectMap(Module &M);
void foldReflectCall(CallInst *Call, Constant *NewValue);
void replaceReflectCalls(
SmallVector<std::pair<CallInst *, Constant *>, 8> &ReflectReplacements,
const DataLayout &DL);
SetVector<BasicBlock *> findTransitivelyDeadBlocks(BasicBlock *DeadBB);

public:
// __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
Expand Down Expand Up @@ -138,6 +142,8 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
assert(F->getReturnType()->isIntegerTy() &&
"_reflect's return type should be integer");

SmallVector<std::pair<CallInst *, Constant *>, 8> ReflectReplacements;

const bool Changed = !F->use_empty();
for (User *U : make_early_inc_range(F->users())) {
// Reflect function calls look like:
Expand Down Expand Up @@ -178,38 +184,111 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
<< "(" << ReflectArg << ") with value " << ReflectVal
<< "\n");
auto *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
foldReflectCall(Call, NewValue);
Call->eraseFromParent();
ReflectReplacements.push_back({Call, NewValue});
}

// Remove the __nvvm_reflect function from the module
replaceReflectCalls(ReflectReplacements, M.getDataLayout());
F->eraseFromParent();
return Changed;
}

void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
/// Find all blocks that become dead transitively from an initial dead block.
/// Returns the complete set including the original dead block and any blocks
/// that lose all their predecessors due to the deletion cascade.
SetVector<BasicBlock *>
NVVMReflect::findTransitivelyDeadBlocks(BasicBlock *DeadBB) {
SmallVector<BasicBlock *, 8> Worklist({DeadBB});
SetVector<BasicBlock *> DeadBlocks;
while (!Worklist.empty()) {
auto *BB = Worklist.pop_back_val();
DeadBlocks.insert(BB);

for (BasicBlock *Succ : successors(BB))
if (pred_size(Succ) == 1 && DeadBlocks.insert(Succ))
Worklist.push_back(Succ);
}
return DeadBlocks;
}

/// Replace calls to __nvvm_reflect with corresponding constant values. Then
/// clean up through constant folding and propagation and dead block
/// elimination.
///
/// The purpose of this cleanup is not optimization because that could be
/// handled by later passes
/// (i.e. SCCP, SimplifyCFG, etc.), but for correctness. Reflect calls are most
/// commonly used to query the arch number and select a valid instruction for
/// the arch. Therefore, you need to eliminate blocks that become dead because
/// they may contain invalid instructions for the arch. The purpose of the
/// cleanup is to do the minimal amount of work to leave the code in a valid
/// state.
void NVVMReflect::replaceReflectCalls(
SmallVector<std::pair<CallInst *, Constant *>, 8> &ReflectReplacements,
const DataLayout &DL) {
SmallVector<Instruction *, 8> Worklist;
// Replace an instruction with a constant and add all users of the instruction
// to the worklist
SetVector<BasicBlock *> DeadBlocks;

// Replace an instruction with a constant and add all users to the worklist,
// then delete the instruction
auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
for (auto *U : I->users())
if (auto *UI = dyn_cast<Instruction>(U))
Worklist.push_back(UI);
I->replaceAllUsesWith(C);
if (isInstructionTriviallyDead(I))
I->eraseFromParent();
};

ReplaceInstructionWithConst(Call, NewValue);
for (auto &[Call, NewValue] : ReflectReplacements)
ReplaceInstructionWithConst(Call, NewValue);

auto &DL = Call->getModule()->getDataLayout();
while (!Worklist.empty()) {
auto *I = Worklist.pop_back_val();
if (auto *C = ConstantFoldInstruction(I, DL)) {
ReplaceInstructionWithConst(I, C);
if (isInstructionTriviallyDead(I))
I->eraseFromParent();
} else if (I->isTerminator()) {
ConstantFoldTerminator(I->getParent());
// Alternate between constant folding/propagation and dead block elimination.
// Terminator folding may create new dead blocks. When those dead blocks are
// deleted, their live successors may have PHIs that can be simplified, which
// may yield more work for folding/propagation.
while (true) {
// Iterate folding and propagating constants until the worklist is empty.
while (!Worklist.empty()) {
auto *I = Worklist.pop_back_val();
if (auto *C = ConstantFoldInstruction(I, DL)) {
ReplaceInstructionWithConst(I, C);
} else if (I->isTerminator()) {
BasicBlock *BB = I->getParent();
SmallVector<BasicBlock *, 8> Succs(successors(BB));
// Some blocks may become dead if the terminator is folded because
// a conditional branch is turned into a direct branch.
if (ConstantFoldTerminator(BB)) {
for (BasicBlock *Succ : Succs) {
if (pred_empty(Succ) &&
Succ != &Succ->getParent()->getEntryBlock()) {
SetVector<BasicBlock *> TransitivelyDead =
findTransitivelyDeadBlocks(Succ);
DeadBlocks.insert(TransitivelyDead.begin(),
TransitivelyDead.end());
}
}
}
}
}
// No more constants to fold and no more dead blocks
// to create more work. We're done.
if (DeadBlocks.empty())
break;
// PHI nodes of live successors of dead blocks get eliminated when the dead
// blocks are eliminated. Their users can now be simplified further, so add
// them to the worklist.
for (BasicBlock *DeadBB : DeadBlocks)
for (BasicBlock *Succ : successors(DeadBB))
if (!DeadBlocks.contains(Succ))
for (PHINode &PHI : Succ->phis())
for (auto *U : PHI.users())
if (auto *UI = dyn_cast<Instruction>(U))
Worklist.push_back(UI);
// Delete all dead blocks
for (BasicBlock *DeadBB : DeadBlocks)
DeleteDeadBlock(DeadBB);

DeadBlocks.clear();
}
}

Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

; RUN: cat %s > %t.noftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 0}' >> %t.noftz
; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_0 --check-prefix=CHECK

; RUN: cat %s > %t.ftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.ftz
; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_1 --check-prefix=CHECK

@str = private unnamed_addr addrspace(4) constant [11 x i8] c"__CUDA_FTZ\00"
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

; RUN: cat %s > %t.noftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 0}' >> %t.noftz
; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_0 --check-prefix=CHECK

; RUN: cat %s > %t.ftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.ftz
; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_1 --check-prefix=CHECK

@str = private unnamed_addr addrspace(4) constant [11 x i8] c"__CUDA_FTZ\00"
Expand Down
Loading