Skip to content

Commit 507c67d

Browse files
committed
[NVPTX] Sync generation of parameter names in a function signature with the function body.
This fixes parameter names mismatch in anonymous functions. Differential Revision: https://reviews.llvm.org/D144407
1 parent 9aae408 commit 507c67d

File tree

5 files changed

+60
-42
lines changed

5 files changed

+60
-42
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ void NVPTXAsmPrinter::emitFunctionEntryLabel() {
466466

467467
CurrentFnSym->print(O, MAI);
468468

469-
emitFunctionParamList(*MF, O);
469+
emitFunctionParamList(F, O);
470470

471471
if (isKernelFunction(*F))
472472
emitKernelFunctionDirectives(*F, O);
@@ -1441,12 +1441,6 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
14411441
}
14421442
}
14431443

1444-
void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
1445-
int paramIndex, raw_ostream &O) {
1446-
getSymbol(I->getParent())->print(O, MAI);
1447-
O << "_param_" << paramIndex;
1448-
}
1449-
14501444
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
14511445
const DataLayout &DL = getDataLayout();
14521446
const AttributeList &PAL = F->getAttributes();
@@ -1485,24 +1479,21 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
14851479
O << "\t.param .u64 .ptr .surfref ";
14861480
else
14871481
O << "\t.param .surfref ";
1488-
CurrentFnSym->print(O, MAI);
1489-
O << "_param_" << paramIndex;
1482+
O << TLI->getParamName(F, paramIndex);
14901483
}
14911484
else { // Default image is read_only
14921485
if (hasImageHandles)
14931486
O << "\t.param .u64 .ptr .texref ";
14941487
else
14951488
O << "\t.param .texref ";
1496-
CurrentFnSym->print(O, MAI);
1497-
O << "_param_" << paramIndex;
1489+
O << TLI->getParamName(F, paramIndex);
14981490
}
14991491
} else {
15001492
if (hasImageHandles)
15011493
O << "\t.param .u64 .ptr .samplerref ";
15021494
else
15031495
O << "\t.param .samplerref ";
1504-
CurrentFnSym->print(O, MAI);
1505-
O << "_param_" << paramIndex;
1496+
O << TLI->getParamName(F, paramIndex);
15061497
}
15071498
continue;
15081499
}
@@ -1524,7 +1515,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
15241515
Align OptimalAlign = getOptimalAlignForParam(Ty);
15251516

15261517
O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1527-
printParamName(I, paramIndex, O);
1518+
O << TLI->getParamName(F, paramIndex);
15281519
O << "[" << DL.getTypeAllocSize(Ty) << "]";
15291520

15301521
continue;
@@ -1563,7 +1554,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
15631554
Align ParamAlign = I->getParamAlign().valueOrOne();
15641555
O << ".align " << ParamAlign.value() << " ";
15651556
}
1566-
printParamName(I, paramIndex, O);
1557+
O << TLI->getParamName(F, paramIndex);
15671558
continue;
15681559
}
15691560

@@ -1575,7 +1566,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
15751566
else
15761567
O << getPTXFundamentalTypeStr(Ty);
15771568
O << " ";
1578-
printParamName(I, paramIndex, O);
1569+
O << TLI->getParamName(F, paramIndex);
15791570
continue;
15801571
}
15811572
// Non-kernel function, just print .param .b<size> for ABI
@@ -1598,7 +1589,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
15981589
O << "\t.param .b" << sz << " ";
15991590
else
16001591
O << "\t.reg .b" << sz << " ";
1601-
printParamName(I, paramIndex, O);
1592+
O << TLI->getParamName(F, paramIndex);
16021593
continue;
16031594
}
16041595

@@ -1619,7 +1610,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
16191610

16201611
unsigned sz = DL.getTypeAllocSize(ETy);
16211612
O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1622-
printParamName(I, paramIndex, O);
1613+
O << TLI->getParamName(F, paramIndex);
16231614
O << "[" << sz << "]";
16241615
continue;
16251616
} else {
@@ -1642,7 +1633,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
16421633
if (elemtype.isInteger())
16431634
sz = promoteScalarArgumentSize(sz);
16441635
O << "\t.reg .b" << sz << " ";
1645-
printParamName(I, paramIndex, O);
1636+
O << TLI->getParamName(F, paramIndex);
16461637
if (j < je - 1)
16471638
O << ",\n";
16481639
++paramIndex;
@@ -1660,19 +1651,12 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
16601651
O << ",\n";
16611652
O << "\t.param .align " << STI.getMaxRequiredAlignment();
16621653
O << " .b8 ";
1663-
getSymbol(F)->print(O, MAI);
1664-
O << "_vararg[]";
1654+
O << TLI->getParamName(F, /* vararg */ -1) << "[]";
16651655
}
16661656

16671657
O << "\n)\n";
16681658
}
16691659

1670-
void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF,
1671-
raw_ostream &O) {
1672-
const Function &F = MF.getFunction();
1673-
emitFunctionParamList(&F, O);
1674-
}
1675-
16761660
void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
16771661
const MachineFunction &MF) {
16781662
SmallString<128> Str;

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,11 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
173173
const char *Modifier = nullptr);
174174
void printModuleLevelGV(const GlobalVariable *GVar, raw_ostream &O,
175175
bool processDemoted, const NVPTXSubtarget &STI);
176-
void printParamName(Function::const_arg_iterator I, int paramIndex,
177-
raw_ostream &O);
178176
void emitGlobals(const Module &M);
179177
void emitHeader(Module &M, raw_ostream &O, const NVPTXSubtarget &STI);
180178
void emitKernelFunctionDirectives(const Function &F, raw_ostream &O) const;
181179
void emitVirtualRegister(unsigned int vr, raw_ostream &);
182180
void emitFunctionParamList(const Function *, raw_ostream &O);
183-
void emitFunctionParamList(const MachineFunction &MF, raw_ostream &O);
184181
void setAndEmitFunctionVirtualRegisters(const MachineFunction &MF);
185182
void printReturnValStr(const Function *, raw_ostream &O);
186183
void printReturnValStr(const MachineFunction &MF, raw_ostream &O);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2614,18 +2614,8 @@ SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
26142614
// passing variable arguments.
26152615
SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
26162616
EVT v) const {
2617-
std::string ParamSym;
2618-
raw_string_ostream ParamStr(ParamSym);
2619-
2620-
ParamStr << DAG.getMachineFunction().getName();
2621-
2622-
if (idx < 0)
2623-
ParamStr << "_vararg";
2624-
else
2625-
ParamStr << "_param_" << idx;
2626-
2627-
StringRef SavedStr =
2628-
nvTM->getStrPool().save(ParamSym);
2617+
StringRef SavedStr = nvTM->getStrPool().save(
2618+
getParamName(&DAG.getMachineFunction().getFunction(), idx));
26292619
return DAG.getTargetExternalSymbol(SavedStr.data(), v);
26302620
}
26312621

@@ -4522,6 +4512,23 @@ Align NVPTXTargetLowering::getFunctionByValParamAlign(
45224512
return ArgAlign;
45234513
}
45244514

4515+
// Helper for getting a function parameter name. Name is composed from
4516+
// its index and the function name. Negative index corresponds to special
4517+
// parameter (unsized array) used for passing variable arguments.
4518+
std::string NVPTXTargetLowering::getParamName(const Function *F,
4519+
int Idx) const {
4520+
std::string ParamName;
4521+
raw_string_ostream ParamStr(ParamName);
4522+
4523+
ParamStr << getTargetMachine().getSymbol(F)->getName();
4524+
if (Idx < 0)
4525+
ParamStr << "_vararg";
4526+
else
4527+
ParamStr << "_param_" << Idx;
4528+
4529+
return ParamName;
4530+
}
4531+
45254532
/// isLegalAddressingMode - Return true if the addressing mode represented
45264533
/// by AM is legal for this target, for a load/store of the specified type.
45274534
/// Used to guide target specific optimizations, like loop strength reduction

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,11 @@ class NVPTXTargetLowering : public TargetLowering {
466466
Align InitialAlign,
467467
const DataLayout &DL) const;
468468

469+
// Helper for getting a function parameter name. Name is composed from
470+
// its index and the function name. Negative index corresponds to special
471+
// parameter (unsized array) used for passing variable arguments.
472+
std::string getParamName(const Function *F, int Idx) const;
473+
469474
/// isLegalAddressingMode - Return true if the addressing mode represented
470475
/// by AM is legal for this target, for a load/store of the specified type
471476
/// Used to guide target specific optimizations, like loop strength
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 | FileCheck %s
2+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 | %ptxas-verify %}
3+
4+
; Check that parameter names we generate in the function signature and the name
5+
; we use when we refer to the parameter in the function body do match.
6+
7+
; CHECK: .func (.param .b32 func_retval0) __unnamed_1(
8+
; CHECK-NEXT: .param .b32 __unnamed_1_param_0
9+
; CHECK: ld.param.u32 {{%r[0-9]+}}, [__unnamed_1_param_0];
10+
11+
define internal i32 @0(i32 %a) {
12+
entry:
13+
%r = add i32 %a, 1
14+
ret i32 %r
15+
}
16+
17+
; CHECK: .func (.param .b32 func_retval0) __unnamed_2(
18+
; CHECK-NEXT: .param .b32 __unnamed_2_param_0
19+
; CHECK: ld.param.u32 {{%r[0-9]+}}, [__unnamed_2_param_0];
20+
21+
define internal i32 @1(i32 %a) {
22+
entry:
23+
%r = add i32 %a, 1
24+
ret i32 %r
25+
}

0 commit comments

Comments
 (0)