[NVVMReflect] Force dead branch elimination in NVVMReflect #81189

Merged
merged 3 commits into from
Feb 8, 2024
5 changes: 5 additions & 0 deletions llvm/docs/NVPTXUsage.rst
@@ -296,6 +296,11 @@ pipeline, immediately after the link stage. The ``internalize`` pass is also
recommended to remove unused math functions from the resulting PTX. For an
input IR module ``module.bc``, the following compilation flow is recommended:

The ``NVVMReflect`` pass will attempt to remove dead code even without
optimizations. This allows potentially incompatible instructions to be avoided
at all optimization levels. This currently only works for simple conditionals
like the above example.

1. Save list of external functions in ``module.bc``
2. Link ``module.bc`` with ``libdevice.compute_XX.YY.bc``
3. Internalize all functions not in list from (1)
62 changes: 62 additions & 0 deletions llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -20,6 +20,7 @@

#include "NVPTX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
@@ -36,6 +37,8 @@
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include <sstream>
#include <string>
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
@@ -87,6 +90,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
  }

  SmallVector<Instruction *, 4> ToRemove;
  SmallVector<ICmpInst *, 4> ToSimplify;

  // Go through the calls in this function. Each call to __nvvm_reflect or
  // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
@@ -171,13 +175,71 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
    } else if (ReflectArg == "__CUDA_ARCH") {
      ReflectVal = SmVersion * 10;
    }

    // If the immediate user is a simple comparison we want to simplify it.
    // TODO: This currently does not handle switch instructions.
    for (User *U : Call->users())
      if (ICmpInst *I = dyn_cast<ICmpInst>(U))
        ToSimplify.push_back(I);

Comment on lines +181 to +184
Member

This is overly conservative, IMO. For example, it may not handle something like `switch (__nvvm_reflect("__CUDA_ARCH"))` or `if ((__nvvm_reflect("__CUDA_ARCH") / 10) == 70)`.

I think, ideally, we would walk up the use tree for as long as the current subtree evaluates to a constant, until we reach a switch/branch/select where we can pick the correct branch.

On the other hand, that may be overkill, and we could just document that this improvement works for `if` only. Even that would implicitly rely on the pass running very early in the pipeline, before LLVM gets a chance to convert that `if` into a `select`. An `if`-only solution may be OK for now, but it would be great to make sure its behavior is consistent for other forms of conditionals.

I'm OK with the patch in this form, but we should add a TODO outlining the still-missing pieces.

Maybe add a few test cases showing where it does not work at the moment.
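
The idea of walking up the use tree can be illustrated with a toy model that is independent of LLVM. The sketch below is only an assumption-laden illustration: `Expr`, `reflect`, `constant`, `binary`, and `fold` are hypothetical names invented here, not LLVM types or APIs. The reflect call is a leaf that becomes `SmVersion * 10`, and folding proceeds through arithmetic and comparisons until the value that would feed a branch is known.

```cpp
#include <cassert>
#include <memory>
#include <utility>

struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

// Hypothetical toy expression node; this is not an LLVM type.
struct Expr {
  enum Kind { Reflect, Const, Div, Eq, Uge };
  Kind kind;
  unsigned value;   // meaningful for Const only
  ExprPtr lhs, rhs; // meaningful for Div, Eq, Uge
  Expr(Kind k, unsigned v, ExprPtr l, ExprPtr r)
      : kind(k), value(v), lhs(std::move(l)), rhs(std::move(r)) {}
};

ExprPtr reflect() {
  return std::make_shared<Expr>(Expr::Reflect, 0u, nullptr, nullptr);
}
ExprPtr constant(unsigned v) {
  return std::make_shared<Expr>(Expr::Const, v, nullptr, nullptr);
}
ExprPtr binary(Expr::Kind k, ExprPtr l, ExprPtr r) {
  return std::make_shared<Expr>(k, 0u, std::move(l), std::move(r));
}

// Fold an expression bottom-up once the reflect value is known.
// Returns false when a subtree cannot be evaluated (here: division by zero).
bool fold(const ExprPtr &e, unsigned smVersion, unsigned &out) {
  assert(e && "null expression");
  if (e->kind == Expr::Reflect) {
    out = smVersion * 10; // mirrors ReflectVal = SmVersion * 10 in the patch
    return true;
  }
  if (e->kind == Expr::Const) {
    out = e->value;
    return true;
  }
  unsigned l = 0, r = 0;
  if (!fold(e->lhs, smVersion, l) || !fold(e->rhs, smVersion, r))
    return false;
  switch (e->kind) {
  case Expr::Div:
    if (r == 0)
      return false;
    out = l / r;
    return true;
  case Expr::Eq:
    out = (l == r) ? 1u : 0u;
    return true;
  case Expr::Uge:
    out = (l >= r) ? 1u : 0u;
    return true;
  default:
    return false;
  }
}
```

With this model, the second example from the comment is `binary(Expr::Eq, binary(Expr::Div, reflect(), constant(10)), constant(70))`; calling `fold` on it with `smVersion` 70 yields 1 and with 52 yields 0, which is the constant a branch or select could then be resolved against.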

    Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
    ToRemove.push_back(Call);
  }

  for (Instruction *I : ToRemove)
    I->eraseFromParent();

  // The code guarded by __nvvm_reflect may be invalid for the target machine.
  // We need to do some basic dead code elimination to trim invalid code before
  // it reaches the backend at all optimization levels.
  SmallVector<BranchInst *> Simplified;
  for (ICmpInst *Cmp : ToSimplify) {
    Constant *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
    Constant *RHS = dyn_cast<Constant>(Cmp->getOperand(1));

    if (!LHS || !RHS)
      continue;

    // If the comparison is a compile time constant we simply propagate it.
    Constant *C = ConstantFoldCompareInstOperands(
        Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout());

    if (!C)
      continue;

    for (User *U : Cmp->users())
      if (BranchInst *I = dyn_cast<BranchInst>(U))
        Simplified.push_back(I);

    Cmp->replaceAllUsesWith(C);
    Cmp->eraseFromParent();
  }

  // Each instruction here is a conditional branch off of a constant true or
  // false value. Simply replace it with an unconditional branch to the
  // appropriate basic block and delete the rest if it is trivially dead.
  DenseSet<Instruction *> Removed;
  for (BranchInst *Branch : Simplified) {
Contributor

ConstantFoldTerminator? I also thought there was an API to just call simplifyCFG on a single block

Contributor

Oh, it's also in Local.h

Contributor Author

Missed that, thanks. I'll probably make an updated version that also handles the Switch case since this function makes that trivial.

Side note, how do we handle the ROCm-DL constants? I remember we have some mandatory constant prop + DCE in a similar case.

Contributor

Incompatible subtarget features on each function block inlining, and the incompatible functions pass deletes functions. It's a high-maintenance system to keep all of those pieces consistent, but in principle incompatible code never exists in the same function.
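
For readers less familiar with the transformation discussed in this thread, the branch-folding step can be modeled outside LLVM. This is a rough sketch under stated assumptions: `Block`, `countPreds`, and `foldBranch` are hypothetical names made up for illustration and are not LLVM API. It picks the successor selected by a constant condition and marks the other successor unreachable only when nothing else branches to it, loosely following what the loop in this hunk does with `ReplaceInstWithInst` and `changeToUnreachable`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy CFG block; this is not an LLVM type.
struct Block {
  // For a conditional branch: {taken-if-true, taken-if-false}.
  std::vector<std::size_t> succs;
  bool unreachable;
  Block() : unreachable(false) {}
};

// Count how many edges in the CFG point at `target`.
std::size_t countPreds(const std::vector<Block> &cfg, std::size_t target) {
  std::size_t n = 0;
  for (const Block &b : cfg)
    for (std::size_t s : b.succs)
      if (s == target)
        ++n;
  return n;
}

// Replace the two-way branch ending block `from` with a one-way branch to
// the successor selected by `cond`. If the other successor is left without
// predecessors, mark it unreachable and drop its outgoing edges.
void foldBranch(std::vector<Block> &cfg, std::size_t from, bool cond) {
  assert(cfg[from].succs.size() == 2 && "expected a conditional branch");
  std::size_t live = cond ? cfg[from].succs[0] : cfg[from].succs[1];
  std::size_t dead = cond ? cfg[from].succs[1] : cfg[from].succs[0];
  cfg[from].succs.assign(1, live);
  if (dead != live && countPreds(cfg, dead) == 0) {
    cfg[dead].unreachable = true;
    cfg[dead].succs.clear();
  }
}
```

Modeling `@baz` from the new test as three blocks (entry branching to `if.then` or `if.end`, and `if.then` falling through to `if.end`), folding the entry branch with a false condition leaves `if.then` without predecessors, so it is marked unreachable. With a true condition, `if.end` keeps a predecessor and is left alone.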

    if (Removed.contains(Branch))
      continue;

    ConstantInt *C = dyn_cast<ConstantInt>(Branch->getCondition());
    if (!C || (!C->isOne() && !C->isZero()))
      continue;

    BasicBlock *TrueBB =
        C->isOne() ? Branch->getSuccessor(0) : Branch->getSuccessor(1);
    BasicBlock *FalseBB =
        C->isOne() ? Branch->getSuccessor(1) : Branch->getSuccessor(0);

    ReplaceInstWithInst(Branch, BranchInst::Create(TrueBB));
    if (FalseBB->use_empty() && FalseBB->hasNPredecessors(0) &&
        FalseBB->getFirstNonPHIOrDbg()) {
      Removed.insert(FalseBB->getFirstNonPHIOrDbg());
      changeToUnreachable(FalseBB->getFirstNonPHIOrDbg());
    }
  }

  return ToRemove.size() > 0;
}

141 changes: 141 additions & 0 deletions llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
@@ -0,0 +1,141 @@
; RUN: llc < %s -march=nvptx64 -mcpu=sm_52 -mattr=+ptx64 -O0 | FileCheck %s --check-prefix=SM_52
; RUN: llc < %s -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 -O0 | FileCheck %s --check-prefix=SM_70
; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx72 -O0 | FileCheck %s --check-prefix=SM_90

@.str = private unnamed_addr constant [12 x i8] c"__CUDA_ARCH\00"

declare i32 @__nvvm_reflect(ptr)

; SM_52: .visible .func (.param .b32 func_retval0) foo()
; SM_52: mov.b32 %[[REG:.+]], 3;
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_52-NEXT: ret;
;
; SM_70: .visible .func (.param .b32 func_retval0) foo()
; SM_70: mov.b32 %[[REG:.+]], 2;
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_70-NEXT: ret;
;
; SM_90: .visible .func (.param .b32 func_retval0) foo()
; SM_90: mov.b32 %[[REG:.+]], 1;
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_90-NEXT: ret;
define i32 @foo() {
entry:
%call = call i32 @__nvvm_reflect(ptr @.str)
%cmp = icmp uge i32 %call, 900
br i1 %cmp, label %if.then, label %if.else

if.then:
br label %return

if.else:
%call1 = call i32 @__nvvm_reflect(ptr @.str)
%cmp2 = icmp uge i32 %call1, 700
br i1 %cmp2, label %if.then3, label %if.else4

if.then3:
br label %return

if.else4:
%call5 = call i32 @__nvvm_reflect(ptr @.str)
%cmp6 = icmp uge i32 %call5, 520
br i1 %cmp6, label %if.then7, label %if.else8

if.then7:
br label %return

if.else8:
br label %return

return:
%retval.0 = phi i32 [ 1, %if.then ], [ 2, %if.then3 ], [ 3, %if.then7 ], [ 4, %if.else8 ]
ret i32 %retval.0
}

; SM_52: .visible .func (.param .b32 func_retval0) bar()
; SM_52: mov.b32 %[[REG:.+]], 2;
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_52-NEXT: ret;
;
; SM_70: .visible .func (.param .b32 func_retval0) bar()
; SM_70: mov.b32 %[[REG:.+]], 1;
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_70-NEXT: ret;
;
; SM_90: .visible .func (.param .b32 func_retval0) bar()
; SM_90: mov.b32 %[[REG:.+]], 1;
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_90-NEXT: ret;
define i32 @bar() {
entry:
%call = call i32 @__nvvm_reflect(ptr @.str)
%cmp = icmp uge i32 %call, 700
br i1 %cmp, label %if.then, label %if.else

if.then:
br label %if.end

if.else:
br label %if.end

if.end:
%x = phi i32 [ 1, %if.then ], [ 2, %if.else ]
ret i32 %x
}

; SM_52-NOT: valid;
; SM_70: valid;
; SM_90: valid;
define void @baz() {
entry:
%call = call i32 @__nvvm_reflect(ptr @.str)
%cmp = icmp uge i32 %call, 700
br i1 %cmp, label %if.then, label %if.end

if.then:
call void asm sideeffect "valid;\0A", ""()
br label %if.end

if.end:
ret void
}

; SM_52: .visible .func (.param .b32 func_retval0) qux()
; SM_52: mov.u32 %[[REG1:.+]], %[[REG2:.+]];
; SM_52: st.param.b32 [func_retval0+0], %[[REG1:.+]];
; SM_52: ret;
; SM_70: .visible .func (.param .b32 func_retval0) qux()
; SM_70: mov.u32 %[[REG1:.+]], %[[REG2:.+]];
; SM_70: st.param.b32 [func_retval0+0], %[[REG1:.+]];
; SM_70: ret;
; SM_90: .visible .func (.param .b32 func_retval0) qux()
; SM_90: st.param.b32 [func_retval0+0], %[[REG1:.+]];
; SM_90: ret;
define i32 @qux() {
entry:
%call = call i32 @__nvvm_reflect(ptr noundef @.str)
%cmp = icmp uge i32 %call, 700
%conv = zext i1 %cmp to i32
switch i32 %conv, label %sw.default [
i32 900, label %sw.bb
i32 700, label %sw.bb1
i32 520, label %sw.bb2
]

sw.bb:
br label %return

sw.bb1:
br label %return

sw.bb2:
br label %return

sw.default:
br label %return

return:
%retval = phi i32 [ 4, %sw.default ], [ 3, %sw.bb2 ], [ 2, %sw.bb1 ], [ 1, %sw.bb ]
ret i32 %retval
}
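
The constants expected by the `foo` and `bar` CHECK lines in this test follow directly from `ReflectVal = SmVersion * 10`. As a quick cross-check, the sketch below re-implements the two if/else chains in plain C++; `reflectCudaArch`, `expectedFoo`, and `expectedBar` are hypothetical helper names used only for this illustration.

```cpp
#include <cassert>

// Value substituted for __nvvm_reflect("__CUDA_ARCH"), as in the patch:
// ReflectVal = SmVersion * 10.
unsigned reflectCudaArch(unsigned smVersion) {
  assert(smVersion > 0 && "expected a real SM version such as 52, 70, or 90");
  return smVersion * 10;
}

// Mirrors the chain of `icmp uge` comparisons in @foo.
int expectedFoo(unsigned smVersion) {
  unsigned arch = reflectCudaArch(smVersion);
  if (arch >= 900)
    return 1;
  if (arch >= 700)
    return 2;
  if (arch >= 520)
    return 3;
  return 4;
}

// Mirrors the single `icmp uge i32 %call, 700` in @bar.
int expectedBar(unsigned smVersion) {
  return reflectCudaArch(smVersion) >= 700 ? 1 : 2;
}
```

For `sm_52`, `sm_70`, and `sm_90`, `expectedFoo` gives 3, 2, and 1, and `expectedBar` gives 2, 1, and 1, matching the `mov.b32` immediates in the SM_52, SM_70, and SM_90 check prefixes.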
1 change: 0 additions & 1 deletion llvm/test/CodeGen/NVPTX/nvvm-reflect-arch.ll
@@ -18,4 +18,3 @@ define i32 @foo(float %a, float %b) {
; SM35: ret i32 350
ret i32 %reflect
}