|
20 | 20 |
|
21 | 21 | #include "NVPTX.h"
|
22 | 22 | #include "llvm/ADT/SmallVector.h"
|
| 23 | +#include "llvm/Analysis/ConstantFolding.h" |
23 | 24 | #include "llvm/IR/Constants.h"
|
24 | 25 | #include "llvm/IR/DerivedTypes.h"
|
25 | 26 | #include "llvm/IR/Function.h"
|
|
36 | 37 | #include "llvm/Support/raw_os_ostream.h"
|
37 | 38 | #include "llvm/Support/raw_ostream.h"
|
38 | 39 | #include "llvm/Transforms/Scalar.h"
|
| 40 | +#include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| 41 | +#include "llvm/Transforms/Utils/Local.h" |
39 | 42 | #include <sstream>
|
40 | 43 | #include <string>
|
41 | 44 | #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
|
@@ -87,6 +90,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
|
87 | 90 | }
|
88 | 91 |
|
89 | 92 | SmallVector<Instruction *, 4> ToRemove;
|
| 93 | + SmallVector<ICmpInst *, 4> ToSimplify; |
90 | 94 |
|
91 | 95 | // Go through the calls in this function. Each call to __nvvm_reflect or
|
92 | 96 | // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
|
@@ -171,13 +175,70 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
|
171 | 175 | } else if (ReflectArg == "__CUDA_ARCH") {
|
172 | 176 | ReflectVal = SmVersion * 10;
|
173 | 177 | }
|
| 178 | + |
| 179 | + // If the immediate user is a simple comparison we want to simplify it. |
| 180 | + for (User *U : Call->users()) |
| 181 | + if (ICmpInst *I = dyn_cast<ICmpInst>(U)) |
| 182 | + ToSimplify.push_back(I); |
| 183 | + |
174 | 184 | Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
|
175 | 185 | ToRemove.push_back(Call);
|
176 | 186 | }
|
177 | 187 |
|
178 | 188 | for (Instruction *I : ToRemove)
|
179 | 189 | I->eraseFromParent();
|
180 | 190 |
|
| 191 | + // The code guarded by __nvvm_reflect may be invalid for the target machine. |
| 192 | + // We need to do some basic dead code elimination to trim invalid code before |
| 193 | + // it reaches the backend at all optimization levels. |
| 194 | + SmallVector<BranchInst *> Simplified; |
| 195 | + for (ICmpInst *Cmp : ToSimplify) { |
| 196 | + Constant *LHS = dyn_cast<Constant>(Cmp->getOperand(0)); |
| 197 | + Constant *RHS = dyn_cast<Constant>(Cmp->getOperand(1)); |
| 198 | + |
| 199 | + if (!LHS || !RHS) |
| 200 | + continue; |
| 201 | + |
| 202 | + // If the comparison is a compile time constat we sipmly propagate it. |
| 203 | + Constant *C = ConstantFoldCompareInstOperands( |
| 204 | + Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout()); |
| 205 | + |
| 206 | + if (!C) |
| 207 | + continue; |
| 208 | + |
| 209 | + for (User *U : Cmp->users()) |
| 210 | + if (BranchInst *I = dyn_cast<BranchInst>(U)) |
| 211 | + Simplified.push_back(I); |
| 212 | + |
| 213 | + Cmp->replaceAllUsesWith(C); |
| 214 | + Cmp->eraseFromParent(); |
| 215 | + } |
| 216 | + |
| 217 | + // Each instruction here is a conditional branch off of a constant true or |
| 218 | + // false value. Simply replace it with an unconditional branch to the |
| 219 | + // appropriate basic block and delete the rest if it is trivally dead. |
| 220 | + DenseSet<Instruction *> Removed; |
| 221 | + for (BranchInst *Branch : Simplified) { |
| 222 | + if (Removed.contains(Branch)) |
| 223 | + continue; |
| 224 | + |
| 225 | + ConstantInt *C = dyn_cast<ConstantInt>(Branch->getCondition()); |
| 226 | + if (!C || (!C->isOne() && !C->isZero())) |
| 227 | + continue; |
| 228 | + |
| 229 | + BasicBlock *TrueBB = |
| 230 | + C->isOne() ? Branch->getSuccessor(0) : Branch->getSuccessor(1); |
| 231 | + BasicBlock *FalseBB = |
| 232 | + C->isOne() ? Branch->getSuccessor(1) : Branch->getSuccessor(0); |
| 233 | + |
| 234 | + ReplaceInstWithInst(Branch, BranchInst::Create(TrueBB)); |
| 235 | + if (FalseBB->use_empty() && FalseBB->hasNPredecessors(0) && |
| 236 | + FalseBB->getFirstNonPHIOrDbg()) { |
| 237 | + Removed.insert(FalseBB->getFirstNonPHIOrDbg()); |
| 238 | + changeToUnreachable(FalseBB->getFirstNonPHIOrDbg()); |
| 239 | + } |
| 240 | + } |
| 241 | + |
181 | 242 | return ToRemove.size() > 0;
|
182 | 243 | }
|
183 | 244 |
|
|
0 commit comments