Skip to content

Commit 92d9b78

Browse files
committed
[NVVMReflect] Force dead branch elimination in NVVMReflect
Summary: The `__nvvm_reflect` function is used to guard invalid code that varies between architectures. One problem with this feature is that if it is used without optimizations, it will leave invalid code in the module that will then make it to the backend. The `__nvvm_reflect` pass is already mandatory, so it should do some trivial branch removal to ensure that constants are handled correctly. This dead branch elimination only works in the trivial case of a compare on a branch and does not touch any conditionals that were not related to the `__nvvm_reflect` call in order to preserve `O0` semantics as much as possible. This should allow the following to work on NVPTX targets:
```c
int foo() { if (__nvvm_reflect("__CUDA_ARCH") >= 700) asm("valid;\n"); }
```
1 parent 347ab99 commit 92d9b78

File tree

3 files changed

+163
-1
lines changed

3 files changed

+163
-1
lines changed

llvm/lib/Target/NVPTX/NVVMReflect.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "NVPTX.h"
2222
#include "llvm/ADT/SmallVector.h"
23+
#include "llvm/Analysis/ConstantFolding.h"
2324
#include "llvm/IR/Constants.h"
2425
#include "llvm/IR/DerivedTypes.h"
2526
#include "llvm/IR/Function.h"
@@ -36,6 +37,8 @@
3637
#include "llvm/Support/raw_os_ostream.h"
3738
#include "llvm/Support/raw_ostream.h"
3839
#include "llvm/Transforms/Scalar.h"
40+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
41+
#include "llvm/Transforms/Utils/Local.h"
3942
#include <sstream>
4043
#include <string>
4144
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
@@ -87,6 +90,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
8790
}
8891

8992
SmallVector<Instruction *, 4> ToRemove;
93+
SmallVector<ICmpInst *, 4> ToSimplify;
9094

9195
// Go through the calls in this function. Each call to __nvvm_reflect or
9296
// llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
@@ -171,13 +175,70 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
171175
} else if (ReflectArg == "__CUDA_ARCH") {
172176
ReflectVal = SmVersion * 10;
173177
}
178+
179+
// If the immediate user is a simple comparison we want to simplify it.
180+
for (User *U : Call->users())
181+
if (ICmpInst *I = dyn_cast<ICmpInst>(U))
182+
ToSimplify.push_back(I);
183+
174184
Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
175185
ToRemove.push_back(Call);
176186
}
177187

178188
for (Instruction *I : ToRemove)
179189
I->eraseFromParent();
180190

191+
// The code guarded by __nvvm_reflect may be invalid for the target machine.
192+
// We need to do some basic dead code elimination to trim invalid code before
193+
// it reaches the backend at all optimization levels.
194+
SmallVector<BranchInst *> Simplified;
195+
for (ICmpInst *Cmp : ToSimplify) {
196+
Constant *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
197+
Constant *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
198+
199+
if (!LHS || !RHS)
200+
continue;
201+
202+
// If the comparison is a compile time constant we simply propagate it.
203+
Constant *C = ConstantFoldCompareInstOperands(
204+
Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout());
205+
206+
if (!C)
207+
continue;
208+
209+
for (User *U : Cmp->users())
210+
if (BranchInst *I = dyn_cast<BranchInst>(U))
211+
Simplified.push_back(I);
212+
213+
Cmp->replaceAllUsesWith(C);
214+
Cmp->eraseFromParent();
215+
}
216+
217+
// Each instruction here is a conditional branch off of a constant true or
218+
// false value. Simply replace it with an unconditional branch to the
219+
// appropriate basic block and delete the rest if it is trivially dead.
220+
DenseSet<Instruction *> Removed;
221+
for (BranchInst *Branch : Simplified) {
222+
if (Removed.contains(Branch))
223+
continue;
224+
225+
ConstantInt *C = dyn_cast<ConstantInt>(Branch->getCondition());
226+
if (!C || (!C->isOne() && !C->isZero()))
227+
continue;
228+
229+
BasicBlock *TrueBB =
230+
C->isOne() ? Branch->getSuccessor(0) : Branch->getSuccessor(1);
231+
BasicBlock *FalseBB =
232+
C->isOne() ? Branch->getSuccessor(1) : Branch->getSuccessor(0);
233+
234+
ReplaceInstWithInst(Branch, BranchInst::Create(TrueBB));
235+
if (FalseBB->use_empty() && FalseBB->hasNPredecessors(0) &&
236+
FalseBB->getFirstNonPHIOrDbg()) {
237+
Removed.insert(FalseBB->getFirstNonPHIOrDbg());
238+
changeToUnreachable(FalseBB->getFirstNonPHIOrDbg());
239+
}
240+
}
241+
181242
return ToRemove.size() > 0;
182243
}
183244

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_52 -mattr=+ptx64 -O0 | FileCheck %s --check-prefix=SM_52
2+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 -O0 | FileCheck %s --check-prefix=SM_70
3+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx72 -O0 | FileCheck %s --check-prefix=SM_90
4+
5+
@.str = private unnamed_addr constant [12 x i8] c"__CUDA_ARCH\00"
6+
7+
declare i32 @__nvvm_reflect(ptr)
8+
9+
; SM_52: .visible .func (.param .b32 func_retval0) foo()
10+
; SM_52: mov.b32 %[[REG:.+]], 3;
11+
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
12+
; SM_52-NEXT: ret;
13+
;
14+
; SM_70: .visible .func (.param .b32 func_retval0) foo()
15+
; SM_70: mov.b32 %[[REG:.+]], 2;
16+
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
17+
; SM_70-NEXT: ret;
18+
;
19+
; SM_90: .visible .func (.param .b32 func_retval0) foo()
20+
; SM_90: mov.b32 %[[REG:.+]], 1;
21+
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
22+
; SM_90-NEXT: ret;
23+
define i32 @foo() {
24+
entry:
25+
%call = call i32 @__nvvm_reflect(ptr @.str)
26+
%cmp = icmp uge i32 %call, 900
27+
br i1 %cmp, label %if.then, label %if.else
28+
29+
if.then:
30+
br label %return
31+
32+
if.else:
33+
%call1 = call i32 @__nvvm_reflect(ptr @.str)
34+
%cmp2 = icmp uge i32 %call1, 700
35+
br i1 %cmp2, label %if.then3, label %if.else4
36+
37+
if.then3:
38+
br label %return
39+
40+
if.else4:
41+
%call5 = call i32 @__nvvm_reflect(ptr @.str)
42+
%cmp6 = icmp uge i32 %call5, 520
43+
br i1 %cmp6, label %if.then7, label %if.else8
44+
45+
if.then7:
46+
br label %return
47+
48+
if.else8:
49+
br label %return
50+
51+
return:
52+
%retval.0 = phi i32 [ 1, %if.then ], [ 2, %if.then3 ], [ 3, %if.then7 ], [ 4, %if.else8 ]
53+
ret i32 %retval.0
54+
}
55+
56+
; SM_52: .visible .func (.param .b32 func_retval0) bar()
57+
; SM_52: mov.b32 %[[REG:.+]], 2;
58+
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
59+
; SM_52-NEXT: ret;
60+
;
61+
; SM_70: .visible .func (.param .b32 func_retval0) bar()
62+
; SM_70: mov.b32 %[[REG:.+]], 1;
63+
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
64+
; SM_70-NEXT: ret;
65+
;
66+
; SM_90: .visible .func (.param .b32 func_retval0) bar()
67+
; SM_90: mov.b32 %[[REG:.+]], 1;
68+
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
69+
; SM_90-NEXT: ret;
70+
define i32 @bar() {
71+
entry:
72+
%call = call i32 @__nvvm_reflect(ptr @.str)
73+
%cmp = icmp uge i32 %call, 700
74+
br i1 %cmp, label %if.then, label %if.else
75+
76+
if.then:
77+
br label %if.end
78+
79+
if.else:
80+
br label %if.end
81+
82+
if.end:
83+
%x = phi i32 [ 1, %if.then ], [ 2, %if.else ]
84+
ret i32 %x
85+
}
86+
87+
; SM_52-NOT: valid;
88+
; SM_70: valid;
89+
; SM_90: valid;
90+
define void @baz() {
91+
entry:
92+
%call = call i32 @__nvvm_reflect(ptr @.str)
93+
%cmp = icmp uge i32 %call, 700
94+
br i1 %cmp, label %if.then, label %if.end
95+
96+
if.then:
97+
call void asm sideeffect "valid;\0A", ""()
98+
br label %if.end
99+
100+
if.end:
101+
ret void
102+
}

llvm/test/CodeGen/NVPTX/nvvm-reflect-arch.ll

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,3 @@ define i32 @foo(float %a, float %b) {
1818
; SM35: ret i32 350
1919
ret i32 %reflect
2020
}
21-

0 commit comments

Comments
 (0)