Skip to content

Commit 30be471

Browse files
committed
[InferAS] Support getAssumedAddrSpace for Arguments for NVPTX
1 parent adba14a commit 30be471

File tree

4 files changed: +85 -20 lines changed

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -678,11 +678,8 @@ static bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F) {
678678

679679
LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
680680
for (Argument &Arg : F.args()) {
681-
if (Arg.getType()->isPointerTy()) {
682-
if (Arg.hasByValAttr())
683-
handleByValParam(TM, &Arg);
684-
else if (TM.getDrvInterface() == NVPTX::CUDA)
685-
markPointerAsGlobal(&Arg);
681+
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
682+
handleByValParam(TM, &Arg);
686683
} else if (Arg.getType()->isIntegerTy() &&
687684
TM.getDrvInterface() == NVPTX::CUDA) {
688685
HandleIntToPtr(Arg);
@@ -699,10 +696,9 @@ static bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F) {
699696
cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
700697

701698
for (Argument &Arg : F.args())
702-
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
703-
markPointerAsAS(&Arg, ADDRESS_SPACE_LOCAL);
699+
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
704700
adjustByValArgAlignment(&Arg, &Arg, TLI);
705-
}
701+
706702
return true;
707703
}
708704

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,21 @@ unsigned NVPTXTTIImpl::getAssumedAddrSpace(const Value *V) const {
599599
if (isa<AllocaInst>(V))
600600
return ADDRESS_SPACE_LOCAL;
601601

602+
if (const Argument *Arg = dyn_cast<Argument>(V)) {
603+
if (isKernelFunction(*Arg->getParent())) {
604+
const NVPTXTargetMachine &TM =
605+
static_cast<const NVPTXTargetMachine &>(getTLI()->getTargetMachine());
606+
if (TM.getDrvInterface() == NVPTX::CUDA && !Arg->hasByValAttr())
607+
return ADDRESS_SPACE_GLOBAL;
608+
} else {
609+
// We assume that all device parameters that are passed byval will be
610+
// placed in the local AS. Very simple cases will be updated after ISel to
611+
// use the device param space where possible.
612+
if (Arg->hasByValAttr())
613+
return ADDRESS_SPACE_LOCAL;
614+
}
615+
}
616+
602617
return -1;
603618
}
604619

llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,15 @@ static bool isNoopPtrIntCastPair(const Operator *I2P, const DataLayout &DL,
305305
}
306306

307307
// Returns true if V is an address expression.
308-
// TODO: Currently, we consider only phi, bitcast, addrspacecast, and
309-
// getelementptr operators.
308+
// TODO: Currently, we consider only arguments and phi, bitcast, addrspacecast,
309+
// and getelementptr operators.
310310
static bool isAddressExpression(const Value &V, const DataLayout &DL,
311311
const TargetTransformInfo *TTI) {
312+
313+
if (const Argument *Arg = dyn_cast<Argument>(&V))
314+
return Arg->getType()->isPointerTy() &&
315+
TTI->getAssumedAddrSpace(&V) != UninitializedAddressSpace;
316+
312317
const Operator *Op = dyn_cast<Operator>(&V);
313318
if (!Op)
314319
return false;
@@ -341,6 +346,9 @@ static bool isAddressExpression(const Value &V, const DataLayout &DL,
341346
static SmallVector<Value *, 2>
342347
getPointerOperands(const Value &V, const DataLayout &DL,
343348
const TargetTransformInfo *TTI) {
349+
if (isa<Argument>(&V))
350+
return {};
351+
344352
const Operator &Op = cast<Operator>(V);
345353
switch (Op.getOpcode()) {
346354
case Instruction::PHI: {
@@ -505,13 +513,11 @@ void InferAddressSpacesImpl::appendsFlatAddressExpressionToPostorderStack(
505513
if (Visited.insert(V).second) {
506514
PostorderStack.emplace_back(V, false);
507515

508-
Operator *Op = cast<Operator>(V);
509-
for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) {
510-
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op->getOperand(I))) {
511-
if (isAddressExpression(*CE, *DL, TTI) && Visited.insert(CE).second)
512-
PostorderStack.emplace_back(CE, false);
513-
}
514-
}
516+
if (auto *Op = dyn_cast<Operator>(V))
517+
for (auto &O : Op->operands())
518+
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(O))
519+
if (isAddressExpression(*CE, *DL, TTI) && Visited.insert(CE).second)
520+
PostorderStack.emplace_back(CE, false);
515521
}
516522
}
517523
}
@@ -828,6 +834,18 @@ Value *InferAddressSpacesImpl::cloneValueWithNewAddressSpace(
828834
assert(V->getType()->getPointerAddressSpace() == FlatAddrSpace &&
829835
isAddressExpression(*V, *DL, TTI));
830836

837+
if (auto *Arg = dyn_cast<Argument>(V)) {
838+
// Arguments are address space casted in the function body, as we do not
839+
// want to change the function signature.
840+
Function *F = Arg->getParent();
841+
BasicBlock::iterator Insert = F->getEntryBlock().getFirstNonPHIIt();
842+
843+
Type *NewPtrTy = PointerType::get(Arg->getContext(), NewAddrSpace);
844+
auto *NewI = new AddrSpaceCastInst(Arg, NewPtrTy);
845+
NewI->insertBefore(Insert);
846+
return NewI;
847+
}
848+
831849
if (Instruction *I = dyn_cast<Instruction>(V)) {
832850
Value *NewV = cloneInstructionWithNewAddressSpace(
833851
I, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, PoisonUsesToFix);
@@ -966,8 +984,9 @@ bool InferAddressSpacesImpl::updateAddressSpace(
966984
// of all its pointer operands.
967985
unsigned NewAS = UninitializedAddressSpace;
968986

969-
const Operator &Op = cast<Operator>(V);
970-
if (Op.getOpcode() == Instruction::Select) {
987+
if (isa<Operator>(V) &&
988+
cast<Operator>(V).getOpcode() == Instruction::Select) {
989+
const Operator &Op = cast<Operator>(V);
971990
Value *Src0 = Op.getOperand(1);
972991
Value *Src1 = Op.getOperand(2);
973992

@@ -1275,7 +1294,7 @@ void InferAddressSpacesImpl::performPointerReplacement(
12751294
// This instruction may contain multiple uses of V, update them all.
12761295
CurUser->replaceUsesOfWith(
12771296
V, new AddrSpaceCastInst(NewV, V->getType(), "", InsertPos));
1278-
} else {
1297+
} else if (isa<Constant>(V)) {
12791298
CurUserI->replaceUsesOfWith(
12801299
V, ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), V->getType()));
12811300
}
Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,35 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -passes=infer-address-spaces %s | FileCheck %s
3+
4+
target triple = "nvptx64-nvidia-cuda"
5+
6+
7+
define ptx_kernel i32 @test_kernel(ptr %a, ptr byval(i32) %b) {
8+
; CHECK-LABEL: define ptx_kernel i32 @test_kernel(
9+
; CHECK-SAME: ptr [[A:%.*]], ptr byval(i32) [[B:%.*]]) {
10+
; CHECK-NEXT: [[TMP1:%.*]] = addrspacecast ptr [[A]] to ptr addrspace(1)
11+
; CHECK-NEXT: [[V1:%.*]] = load i32, ptr addrspace(1) [[TMP1]], align 4
12+
; CHECK-NEXT: [[V2:%.*]] = load i32, ptr [[B]], align 4
13+
; CHECK-NEXT: [[SUM:%.*]] = add i32 [[V1]], [[V2]]
14+
; CHECK-NEXT: ret i32 [[SUM]]
15+
;
16+
%v1 = load i32, ptr %a
17+
%v2 = load i32, ptr %b
18+
%sum = add i32 %v1, %v2
19+
ret i32 %sum
20+
}
21+
22+
define i32 @test_device(ptr %a, ptr byval(i32) %b) {
23+
; CHECK-LABEL: define i32 @test_device(
24+
; CHECK-SAME: ptr [[A:%.*]], ptr byval(i32) [[B:%.*]]) {
25+
; CHECK-NEXT: [[TMP1:%.*]] = addrspacecast ptr [[B]] to ptr addrspace(5)
26+
; CHECK-NEXT: [[V1:%.*]] = load i32, ptr [[A]], align 4
27+
; CHECK-NEXT: [[V2:%.*]] = load i32, ptr addrspace(5) [[TMP1]], align 4
28+
; CHECK-NEXT: [[SUM:%.*]] = add i32 [[V1]], [[V2]]
29+
; CHECK-NEXT: ret i32 [[SUM]]
30+
;
31+
%v1 = load i32, ptr %a
32+
%v2 = load i32, ptr %b
33+
%sum = add i32 %v1, %v2
34+
ret i32 %sum
35+
}

0 commit comments

Comments
 (0)