-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[NVVMReflect] Force dead branch elimination in NVVMReflect #81189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
|
||
#include "NVPTX.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/Analysis/ConstantFolding.h" | ||
#include "llvm/IR/Constants.h" | ||
#include "llvm/IR/DerivedTypes.h" | ||
#include "llvm/IR/Function.h" | ||
|
@@ -36,6 +37,8 @@ | |
#include "llvm/Support/raw_os_ostream.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
#include "llvm/Transforms/Scalar.h" | ||
#include "llvm/Transforms/Utils/BasicBlockUtils.h" | ||
#include "llvm/Transforms/Utils/Local.h" | ||
#include <sstream> | ||
#include <string> | ||
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect" | ||
|
@@ -87,6 +90,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) { | |
} | ||
|
||
SmallVector<Instruction *, 4> ToRemove; | ||
SmallVector<ICmpInst *, 4> ToSimplify; | ||
|
||
// Go through the calls in this function. Each call to __nvvm_reflect or | ||
// llvm.nvvm.reflect should be a CallInst with a ConstantArray argument. | ||
|
@@ -171,13 +175,71 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) { | |
} else if (ReflectArg == "__CUDA_ARCH") { | ||
ReflectVal = SmVersion * 10; | ||
} | ||
|
||
// If the immediate user is a simple comparison we want to simplify it. | ||
// TODO: This currently does not handle switch instructions. | ||
for (User *U : Call->users()) | ||
if (ICmpInst *I = dyn_cast<ICmpInst>(U)) | ||
ToSimplify.push_back(I); | ||
|
||
Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal)); | ||
ToRemove.push_back(Call); | ||
} | ||
|
||
for (Instruction *I : ToRemove) | ||
I->eraseFromParent(); | ||
|
||
// The code guarded by __nvvm_reflect may be invalid for the target machine. | ||
// We need to do some basic dead code elimination to trim invalid code before | ||
// it reaches the backend at all optimization levels. | ||
SmallVector<BranchInst *> Simplified; | ||
for (ICmpInst *Cmp : ToSimplify) { | ||
Constant *LHS = dyn_cast<Constant>(Cmp->getOperand(0)); | ||
Constant *RHS = dyn_cast<Constant>(Cmp->getOperand(1)); | ||
|
||
if (!LHS || !RHS) | ||
continue; | ||
|
||
// If the comparison is a compile time constant we simply propagate it. | ||
Constant *C = ConstantFoldCompareInstOperands( | ||
Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout()); | ||
|
||
if (!C) | ||
continue; | ||
|
||
for (User *U : Cmp->users()) | ||
if (BranchInst *I = dyn_cast<BranchInst>(U)) | ||
Simplified.push_back(I); | ||
|
||
Cmp->replaceAllUsesWith(C); | ||
Cmp->eraseFromParent(); | ||
} | ||
|
||
// Each instruction here is a conditional branch off of a constant true or | ||
// false value. Simply replace it with an unconditional branch to the | ||
// appropriate basic block and delete the rest if it is trivially dead. | ||
DenseSet<Instruction *> Removed; | ||
for (BranchInst *Branch : Simplified) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. `ConstantFoldTerminator`? I also thought there was an API to just call simplifyCFG on a single block. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh, it's also in Local.h. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missed that, thanks. I'll probably make an updated version that also handles the Switch case, since this function makes that trivial. Side note: how do we handle the ROCm-DL constants? I remember we have some mandatory constant prop + DCE in a similar case. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The incompatible subtarget features on each function block inlining. The incompatible-functions pass deletes functions. It's a high-maintenance system to keep all of those pieces consistent, but in principle incompatible code never exists in the same function. |
||
if (Removed.contains(Branch)) | ||
continue; | ||
|
||
ConstantInt *C = dyn_cast<ConstantInt>(Branch->getCondition()); | ||
if (!C || (!C->isOne() && !C->isZero())) | ||
continue; | ||
|
||
BasicBlock *TrueBB = | ||
C->isOne() ? Branch->getSuccessor(0) : Branch->getSuccessor(1); | ||
BasicBlock *FalseBB = | ||
C->isOne() ? Branch->getSuccessor(1) : Branch->getSuccessor(0); | ||
|
||
ReplaceInstWithInst(Branch, BranchInst::Create(TrueBB)); | ||
if (FalseBB->use_empty() && FalseBB->hasNPredecessors(0) && | ||
FalseBB->getFirstNonPHIOrDbg()) { | ||
Removed.insert(FalseBB->getFirstNonPHIOrDbg()); | ||
changeToUnreachable(FalseBB->getFirstNonPHIOrDbg()); | ||
} | ||
} | ||
|
||
return ToRemove.size() > 0; | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
; RUN: llc < %s -march=nvptx64 -mcpu=sm_52 -mattr=+ptx64 -O0 | FileCheck %s --check-prefix=SM_52 | ||
; RUN: llc < %s -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 -O0 | FileCheck %s --check-prefix=SM_70 | ||
; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx72 -O0 | FileCheck %s --check-prefix=SM_90 | ||
|
||
@.str = private unnamed_addr constant [12 x i8] c"__CUDA_ARCH\00" | ||
|
||
declare i32 @__nvvm_reflect(ptr) | ||
|
||
; SM_52: .visible .func (.param .b32 func_retval0) foo() | ||
; SM_52: mov.b32 %[[REG:.+]], 3; | ||
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]]; | ||
; SM_52-NEXT: ret; | ||
; | ||
; SM_70: .visible .func (.param .b32 func_retval0) foo() | ||
; SM_70: mov.b32 %[[REG:.+]], 2; | ||
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]]; | ||
; SM_70-NEXT: ret; | ||
; | ||
; SM_90: .visible .func (.param .b32 func_retval0) foo() | ||
; SM_90: mov.b32 %[[REG:.+]], 1; | ||
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]]; | ||
; SM_90-NEXT: ret; | ||
; Nested if/else ladder: each level calls __nvvm_reflect on @.str
; ("__CUDA_ARCH") and compares the result (uge) against 900, then 700, then 520.
; The phi in %return selects 1, 2, 3 or 4 depending on which arm branched there.
define i32 @foo() { | ||
entry: | ||
%call = call i32 @__nvvm_reflect(ptr @.str) | ||
%cmp = icmp uge i32 %call, 900 | ||
br i1 %cmp, label %if.then, label %if.else | ||
|
||
if.then: | ||
br label %return | ||
|
||
; Reached only when the first comparison (>= 900) is false.
if.else: | ||
%call1 = call i32 @__nvvm_reflect(ptr @.str) | ||
%cmp2 = icmp uge i32 %call1, 700 | ||
br i1 %cmp2, label %if.then3, label %if.else4 | ||
|
||
if.then3: | ||
br label %return | ||
|
||
; Reached only when the first two comparisons (>= 900, >= 700) are false.
if.else4: | ||
%call5 = call i32 @__nvvm_reflect(ptr @.str) | ||
%cmp6 = icmp uge i32 %call5, 520 | ||
br i1 %cmp6, label %if.then7, label %if.else8 | ||
|
||
if.then7: | ||
br label %return | ||
|
||
if.else8: | ||
br label %return | ||
|
||
; Merge point: the incoming block determines the returned constant.
return: | ||
%retval.0 = phi i32 [ 1, %if.then ], [ 2, %if.then3 ], [ 3, %if.then7 ], [ 4, %if.else8 ] | ||
ret i32 %retval.0 | ||
} | ||
|
||
; SM_52: .visible .func (.param .b32 func_retval0) bar() | ||
; SM_52: mov.b32 %[[REG:.+]], 2; | ||
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]]; | ||
; SM_52-NEXT: ret; | ||
; | ||
; SM_70: .visible .func (.param .b32 func_retval0) bar() | ||
; SM_70: mov.b32 %[[REG:.+]], 1; | ||
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]]; | ||
; SM_70-NEXT: ret; | ||
; | ||
; SM_90: .visible .func (.param .b32 func_retval0) bar() | ||
; SM_90: mov.b32 %[[REG:.+]], 1; | ||
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]]; | ||
; SM_90-NEXT: ret; | ||
; Single if/else on the reflected "__CUDA_ARCH" value: returns 1 when the
; value is uge 700 (via %if.then) and 2 otherwise (via %if.else).
define i32 @bar() { | ||
entry: | ||
%call = call i32 @__nvvm_reflect(ptr @.str) | ||
%cmp = icmp uge i32 %call, 700 | ||
br i1 %cmp, label %if.then, label %if.else | ||
|
||
if.then: | ||
br label %if.end | ||
|
||
if.else: | ||
br label %if.end | ||
|
||
; Merge point: the phi picks the constant matching the arm that was taken.
if.end: | ||
%x = phi i32 [ 1, %if.then ], [ 2, %if.else ] | ||
ret i32 %x | ||
} | ||
|
||
; SM_52-NOT: valid; | ||
; SM_70: valid; | ||
; SM_90: valid; | ||
; Guards a side-effecting inline asm ("valid;") behind a comparison of the
; reflected "__CUDA_ARCH" value against 700; the asm lives only in %if.then.
define void @baz() { | ||
entry: | ||
%call = call i32 @__nvvm_reflect(ptr @.str) | ||
%cmp = icmp uge i32 %call, 700 | ||
br i1 %cmp, label %if.then, label %if.end | ||
|
||
; Taken only when the comparison is true; otherwise control skips to %if.end.
if.then: | ||
call void asm sideeffect "valid;\0A", ""() | ||
br label %if.end | ||
|
||
if.end: | ||
ret void | ||
} | ||
|
||
; SM_52: .visible .func (.param .b32 func_retval0) qux() | ||
; SM_52: mov.u32 %[[REG1:.+]], %[[REG2:.+]]; | ||
; SM_52: st.param.b32 [func_retval0+0], %[[REG1:.+]]; | ||
; SM_52: ret; | ||
; SM_70: .visible .func (.param .b32 func_retval0) qux() | ||
; SM_70: mov.u32 %[[REG1:.+]], %[[REG2:.+]]; | ||
; SM_70: st.param.b32 [func_retval0+0], %[[REG1:.+]]; | ||
; SM_70: ret; | ||
; SM_90: .visible .func (.param .b32 func_retval0) qux() | ||
; SM_90: st.param.b32 [func_retval0+0], %[[REG1:.+]]; | ||
; SM_90: ret; | ||
; Switch-based variant: the reflected "__CUDA_ARCH" value is compared (uge 700),
; the i1 result is zero-extended to i32, and that value drives the switch.
; NOTE(review): %conv is a zext of an i1, so it can only be 0 or 1 and none of
; the cases below (900, 700, 520) can match; confirm whether switching on
; %call directly was intended.
define i32 @qux() { | ||
entry: | ||
%call = call i32 @__nvvm_reflect(ptr noundef @.str) | ||
%cmp = icmp uge i32 %call, 700 | ||
%conv = zext i1 %cmp to i32 | ||
switch i32 %conv, label %sw.default [ | ||
i32 900, label %sw.bb | ||
i32 700, label %sw.bb1 | ||
i32 520, label %sw.bb2 | ||
] | ||
|
||
sw.bb: | ||
br label %return | ||
|
||
sw.bb1: | ||
br label %return | ||
|
||
sw.bb2: | ||
br label %return | ||
|
||
sw.default: | ||
br label %return | ||
|
||
; Merge point: 1, 2, 3 for the three cases, 4 for the default destination.
return: | ||
%retval = phi i32 [ 4, %sw.default ], [ 3, %sw.bb2 ], [ 2, %sw.bb1 ], [ 1, %sw.bb ] | ||
ret i32 %retval | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,4 +18,3 @@ define i32 @foo(float %a, float %b) { | |
; SM35: ret i32 350 | ||
ret i32 %reflect | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is overly conservative, IMO. E.g. it may not handle something like
`switch(__nvvm_reflect("__CUDA_ARCH"))`, or `if ((__nvvm_reflect("__CUDA_ARCH") / 10) == 70)`.
I think, ideally, we may want to iterate up the use tree, as long as the current subtree evaluates to a constant, until we reach a switch/branch/select where we can pick the correct branch.
On one hand it may be overkill, and we could just document that this improvement works for `if` only. Even that would implicitly rely on the pass happening very early in the pipeline, before LLVM gets to possibly convert that `if` into a `select`. An `if`-only solution may be OK for now, but it would be great to make sure its behavior is consistent for other forms of conditionals.
I'm OK with the patch in this form, but we should add a TODO outlining the still missing pieces.
Maybe add a few test cases showing where it does not work at the moment.