Skip to content

Commit ffabcbc

Browse files
committed
[NVVMReflect][Reland] Force dead branch elimination in NVVMReflect (#81189)
Summary: The `__nvvm_reflect` function is used to guard invalid code that varies between architectures. One problem with this feature is that if it is used without optimizations, it will leave invalid code in the module that will then make it to the backend. The `__nvvm_reflect` pass is already mandatory, so it should do some trivial branch removal to ensure that constants are handled correctly. This dead branch elimination only works in the trivial case of a compare on a branch and does not touch any conditionals that were not realted to the `__nvvm_reflect` call in order to preserve `O0` semantics as much as possible. This should allow the following to work on NVPTX targets ```c int foo() { if (__nvvm_reflect("__CUDA_ARCH") >= 700) asm("valid;\n"); } ``` Relanding after fixing a bug.
1 parent f608269 commit ffabcbc

File tree

4 files changed

+245
-1
lines changed

4 files changed

+245
-1
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,11 @@ pipeline, immediately after the link stage. The ``internalize`` pass is also
296296
recommended to remove unused math functions from the resulting PTX. For an
297297
input IR module ``module.bc``, the following compilation flow is recommended:
298298

299+
The ``NVVMReflect`` pass will attempt to remove dead code even without
300+
optimizations. This allows potentially incompatible instructions to be avoided
301+
at all optimizations levels. This currently only works for simple conditionals
302+
like the above example.
303+
299304
1. Save list of external functions in ``module.bc``
300305
2. Link ``module.bc`` with ``libdevice.compute_XX.YY.bc``
301306
3. Internalize all functions not in list from (1)

llvm/lib/Target/NVPTX/NVVMReflect.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "NVPTX.h"
2222
#include "llvm/ADT/SmallVector.h"
23+
#include "llvm/Analysis/ConstantFolding.h"
2324
#include "llvm/IR/Constants.h"
2425
#include "llvm/IR/DerivedTypes.h"
2526
#include "llvm/IR/Function.h"
@@ -36,6 +37,8 @@
3637
#include "llvm/Support/raw_os_ostream.h"
3738
#include "llvm/Support/raw_ostream.h"
3839
#include "llvm/Transforms/Scalar.h"
40+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
41+
#include "llvm/Transforms/Utils/Local.h"
3942
#include <sstream>
4043
#include <string>
4144
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
@@ -87,6 +90,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
8790
}
8891

8992
SmallVector<Instruction *, 4> ToRemove;
93+
SmallVector<ICmpInst *, 4> ToSimplify;
9094

9195
// Go through the calls in this function. Each call to __nvvm_reflect or
9296
// llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
@@ -171,13 +175,74 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
171175
} else if (ReflectArg == "__CUDA_ARCH") {
172176
ReflectVal = SmVersion * 10;
173177
}
178+
179+
// If the immediate user is a simple comparison we want to simplify it.
180+
// TODO: This currently does not handle switch instructions.
181+
for (User *U : Call->users())
182+
if (ICmpInst *I = dyn_cast<ICmpInst>(U))
183+
ToSimplify.push_back(I);
184+
174185
Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
175186
ToRemove.push_back(Call);
176187
}
177188

178189
for (Instruction *I : ToRemove)
179190
I->eraseFromParent();
180191

192+
// The code guarded by __nvvm_reflect may be invalid for the target machine.
193+
// We need to do some basic dead code elimination to trim invalid code before
194+
// it reaches the backend at all optimization levels.
195+
SmallVector<BranchInst *> Simplified;
196+
for (ICmpInst *Cmp : ToSimplify) {
197+
Constant *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
198+
Constant *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
199+
200+
if (!LHS || !RHS)
201+
continue;
202+
203+
// If the comparison is a compile time constant we simply propagate it.
204+
Constant *C = ConstantFoldCompareInstOperands(
205+
Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout());
206+
207+
if (!C)
208+
continue;
209+
210+
for (User *U : Cmp->users())
211+
if (BranchInst *I = dyn_cast<BranchInst>(U))
212+
Simplified.push_back(I);
213+
214+
Cmp->replaceAllUsesWith(C);
215+
Cmp->eraseFromParent();
216+
}
217+
218+
// Each instruction here is a conditional branch off of a constant true or
219+
// false value. Simply replace it with an unconditional branch to the
220+
// appropriate basic block and delete the rest if it is trivially dead.
221+
DenseSet<Instruction *> Removed;
222+
for (BranchInst *Branch : Simplified) {
223+
if (Removed.contains(Branch))
224+
continue;
225+
226+
ConstantInt *C = dyn_cast<ConstantInt>(Branch->getCondition());
227+
if (!C || (!C->isOne() && !C->isZero()))
228+
continue;
229+
230+
BasicBlock *TrueBB =
231+
C->isOne() ? Branch->getSuccessor(0) : Branch->getSuccessor(1);
232+
BasicBlock *FalseBB =
233+
C->isOne() ? Branch->getSuccessor(1) : Branch->getSuccessor(0);
234+
235+
// This transformation is only correct on simple edges.
236+
if (!FalseBB->hasNPredecessors(1))
237+
continue;
238+
239+
ReplaceInstWithInst(Branch, BranchInst::Create(TrueBB));
240+
if (FalseBB->use_empty() && !FalseBB->getFirstNonPHIOrDbg()) {
241+
Removed.insert(FalseBB->getFirstNonPHIOrDbg());
242+
changeToUnreachable(FalseBB->getFirstNonPHIOrDbg());
243+
}
244+
}
245+
181246
return ToRemove.size() > 0;
182247
}
183248

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_52 -mattr=+ptx64 -O0 | FileCheck %s --check-prefix=SM_52
2+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 -O0 | FileCheck %s --check-prefix=SM_70
3+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx72 -O0 | FileCheck %s --check-prefix=SM_90
4+
5+
@.str = private unnamed_addr constant [12 x i8] c"__CUDA_ARCH\00"
6+
@.str1 = constant [11 x i8] c"__CUDA_FTZ\00"
7+
8+
declare i32 @__nvvm_reflect(ptr)
9+
10+
; SM_52: .visible .func (.param .b32 func_retval0) foo()
11+
; SM_52: mov.b32 %[[REG:.+]], 3;
12+
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
13+
; SM_52-NEXT: ret;
14+
;
15+
; SM_70: .visible .func (.param .b32 func_retval0) foo()
16+
; SM_70: mov.b32 %[[REG:.+]], 2;
17+
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
18+
; SM_70-NEXT: ret;
19+
;
20+
; SM_90: .visible .func (.param .b32 func_retval0) foo()
21+
; SM_90: mov.b32 %[[REG:.+]], 1;
22+
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
23+
; SM_90-NEXT: ret;
24+
define i32 @foo() {
25+
entry:
26+
%call = call i32 @__nvvm_reflect(ptr @.str)
27+
%cmp = icmp uge i32 %call, 900
28+
br i1 %cmp, label %if.then, label %if.else
29+
30+
if.then:
31+
br label %return
32+
33+
if.else:
34+
%call1 = call i32 @__nvvm_reflect(ptr @.str)
35+
%cmp2 = icmp uge i32 %call1, 700
36+
br i1 %cmp2, label %if.then3, label %if.else4
37+
38+
if.then3:
39+
br label %return
40+
41+
if.else4:
42+
%call5 = call i32 @__nvvm_reflect(ptr @.str)
43+
%cmp6 = icmp uge i32 %call5, 520
44+
br i1 %cmp6, label %if.then7, label %if.else8
45+
46+
if.then7:
47+
br label %return
48+
49+
if.else8:
50+
br label %return
51+
52+
return:
53+
%retval.0 = phi i32 [ 1, %if.then ], [ 2, %if.then3 ], [ 3, %if.then7 ], [ 4, %if.else8 ]
54+
ret i32 %retval.0
55+
}
56+
57+
; SM_52: .visible .func (.param .b32 func_retval0) bar()
58+
; SM_52: mov.b32 %[[REG:.+]], 2;
59+
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
60+
; SM_52-NEXT: ret;
61+
;
62+
; SM_70: .visible .func (.param .b32 func_retval0) bar()
63+
; SM_70: mov.b32 %[[REG:.+]], 1;
64+
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
65+
; SM_70-NEXT: ret;
66+
;
67+
; SM_90: .visible .func (.param .b32 func_retval0) bar()
68+
; SM_90: mov.b32 %[[REG:.+]], 1;
69+
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
70+
; SM_90-NEXT: ret;
71+
define i32 @bar() {
72+
entry:
73+
%call = call i32 @__nvvm_reflect(ptr @.str)
74+
%cmp = icmp uge i32 %call, 700
75+
br i1 %cmp, label %if.then, label %if.else
76+
77+
if.then:
78+
br label %if.end
79+
80+
if.else:
81+
br label %if.end
82+
83+
if.end:
84+
%x = phi i32 [ 1, %if.then ], [ 2, %if.else ]
85+
ret i32 %x
86+
}
87+
88+
; SM_52-NOT: valid;
89+
; SM_70: valid;
90+
; SM_90: valid;
91+
define void @baz() {
92+
entry:
93+
%call = call i32 @__nvvm_reflect(ptr @.str)
94+
%cmp = icmp uge i32 %call, 700
95+
br i1 %cmp, label %if.then, label %if.end
96+
97+
if.then:
98+
call void asm sideeffect "valid;\0A", ""()
99+
br label %if.end
100+
101+
if.end:
102+
ret void
103+
}
104+
105+
; SM_52: .visible .func (.param .b32 func_retval0) qux()
106+
; SM_52: mov.u32 %[[REG1:.+]], %[[REG2:.+]];
107+
; SM_52: st.param.b32 [func_retval0+0], %[[REG1:.+]];
108+
; SM_52: ret;
109+
; SM_70: .visible .func (.param .b32 func_retval0) qux()
110+
; SM_70: mov.u32 %[[REG1:.+]], %[[REG2:.+]];
111+
; SM_70: st.param.b32 [func_retval0+0], %[[REG1:.+]];
112+
; SM_70: ret;
113+
; SM_90: .visible .func (.param .b32 func_retval0) qux()
114+
; SM_90: st.param.b32 [func_retval0+0], %[[REG1:.+]];
115+
; SM_90: ret;
116+
define i32 @qux() {
117+
entry:
118+
%call = call i32 @__nvvm_reflect(ptr noundef @.str)
119+
%cmp = icmp uge i32 %call, 700
120+
%conv = zext i1 %cmp to i32
121+
switch i32 %conv, label %sw.default [
122+
i32 900, label %sw.bb
123+
i32 700, label %sw.bb1
124+
i32 520, label %sw.bb2
125+
]
126+
127+
sw.bb:
128+
br label %return
129+
130+
sw.bb1:
131+
br label %return
132+
133+
sw.bb2:
134+
br label %return
135+
136+
sw.default:
137+
br label %return
138+
139+
return:
140+
%retval = phi i32 [ 4, %sw.default ], [ 3, %sw.bb2 ], [ 2, %sw.bb1 ], [ 1, %sw.bb ]
141+
ret i32 %retval
142+
}
143+
144+
; SM_52: .visible .func (.param .b32 func_retval0) phi()
145+
; SM_52: mov.f32 %[[REG:.+]], 0f00000000;
146+
; SM_52-NEXT: st.param.f32 [func_retval0+0], %[[REG]];
147+
; SM_52-NEXT: ret;
148+
; SM_70: .visible .func (.param .b32 func_retval0) phi()
149+
; SM_70: mov.f32 %[[REG:.+]], 0f00000000;
150+
; SM_70-NEXT: st.param.f32 [func_retval0+0], %[[REG]];
151+
; SM_70-NEXT: ret;
152+
; SM_90: .visible .func (.param .b32 func_retval0) phi()
153+
; SM_90: mov.f32 %[[REG:.+]], 0f00000000;
154+
; SM_90-NEXT: st.param.f32 [func_retval0+0], %[[REG]];
155+
; SM_90-NEXT: ret;
156+
define float @phi() {
157+
entry:
158+
%0 = call i32 @__nvvm_reflect(ptr @.str)
159+
%1 = icmp eq i32 %0, 0
160+
br i1 %1, label %if.then, label %if.else
161+
162+
if.then:
163+
br label %if.else
164+
165+
if.else:
166+
%.08 = phi float [ 0.000000e+00, %if.then ], [ 1.000000e+00, %entry ]
167+
%4 = fcmp ogt float %.08, 0.000000e+00
168+
br i1 %4, label %exit, label %if.exit
169+
170+
if.exit:
171+
br label %exit
172+
173+
exit:
174+
ret float 0.000000e+00
175+
}

llvm/test/CodeGen/NVPTX/nvvm-reflect-arch.ll

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,3 @@ define i32 @foo(float %a, float %b) {
1818
; SM35: ret i32 350
1919
ret i32 %reflect
2020
}
21-

0 commit comments

Comments
 (0)