Skip to content

Commit 3d2a976

Browse files
[NVPTX] Handle addrspacecast constant expressions in aggregate initializers
We need to track if an AddrSpaceCast expression was seen when generating an MCExpr for a ConstantExpr. This change introduces a custom lowerConstant method to the NVPTX asm printer that will create NVPTXGenericMCSymbolRefExpr nodes at the appropriate places to encode the information that a given symbol needs to be casted to a generic address. llvm-svn: 236000
1 parent 5e1441b commit 3d2a976

File tree

5 files changed

+270
-2
lines changed

5 files changed

+270
-2
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,6 +1983,212 @@ bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI) {
19831983
return false;
19841984
}
19851985

1986+
/// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
1987+
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
1988+
/// expressions that are representable in PTX and create
1989+
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1990+
const MCExpr *
1991+
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1992+
MCContext &Ctx = OutContext;
1993+
1994+
if (CV->isNullValue() || isa<UndefValue>(CV))
1995+
return MCConstantExpr::Create(0, Ctx);
1996+
1997+
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1998+
return MCConstantExpr::Create(CI->getZExtValue(), Ctx);
1999+
2000+
if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
2001+
const MCSymbolRefExpr *Expr =
2002+
MCSymbolRefExpr::Create(getSymbol(GV), Ctx);
2003+
if (ProcessingGeneric) {
2004+
return NVPTXGenericMCSymbolRefExpr::Create(Expr, Ctx);
2005+
} else {
2006+
return Expr;
2007+
}
2008+
}
2009+
2010+
const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
2011+
if (!CE) {
2012+
llvm_unreachable("Unknown constant value to lower!");
2013+
}
2014+
2015+
switch (CE->getOpcode()) {
2016+
default:
2017+
// If the code isn't optimized, there may be outstanding folding
2018+
// opportunities. Attempt to fold the expression using DataLayout as a
2019+
// last resort before giving up.
2020+
if (Constant *C = ConstantFoldConstantExpression(CE, *TM.getDataLayout()))
2021+
if (C != CE)
2022+
return lowerConstantForGV(C, ProcessingGeneric);
2023+
2024+
// Otherwise report the problem to the user.
2025+
{
2026+
std::string S;
2027+
raw_string_ostream OS(S);
2028+
OS << "Unsupported expression in static initializer: ";
2029+
CE->printAsOperand(OS, /*PrintType=*/false,
2030+
!MF ? nullptr : MF->getFunction()->getParent());
2031+
report_fatal_error(OS.str());
2032+
}
2033+
2034+
case Instruction::AddrSpaceCast: {
2035+
// Strip the addrspacecast and pass along the operand
2036+
PointerType *DstTy = cast<PointerType>(CE->getType());
2037+
if (DstTy->getAddressSpace() == 0) {
2038+
return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
2039+
}
2040+
std::string S;
2041+
raw_string_ostream OS(S);
2042+
OS << "Unsupported expression in static initializer: ";
2043+
CE->printAsOperand(OS, /*PrintType=*/ false,
2044+
!MF ? 0 : MF->getFunction()->getParent());
2045+
report_fatal_error(OS.str());
2046+
}
2047+
2048+
case Instruction::GetElementPtr: {
2049+
const DataLayout &DL = *TM.getDataLayout();
2050+
2051+
// Generate a symbolic expression for the byte address
2052+
APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2053+
cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
2054+
2055+
const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2056+
ProcessingGeneric);
2057+
if (!OffsetAI)
2058+
return Base;
2059+
2060+
int64_t Offset = OffsetAI.getSExtValue();
2061+
return MCBinaryExpr::CreateAdd(Base, MCConstantExpr::Create(Offset, Ctx),
2062+
Ctx);
2063+
}
2064+
2065+
case Instruction::Trunc:
2066+
// We emit the value and depend on the assembler to truncate the generated
2067+
// expression properly. This is important for differences between
2068+
// blockaddress labels. Since the two labels are in the same function, it
2069+
// is reasonable to treat their delta as a 32-bit value.
2070+
// FALL THROUGH.
2071+
case Instruction::BitCast:
2072+
return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2073+
2074+
case Instruction::IntToPtr: {
2075+
const DataLayout &DL = *TM.getDataLayout();
2076+
2077+
// Handle casts to pointers by changing them into casts to the appropriate
2078+
// integer type. This promotes constant folding and simplifies this code.
2079+
Constant *Op = CE->getOperand(0);
2080+
Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2081+
false/*ZExt*/);
2082+
return lowerConstantForGV(Op, ProcessingGeneric);
2083+
}
2084+
2085+
case Instruction::PtrToInt: {
2086+
const DataLayout &DL = *TM.getDataLayout();
2087+
2088+
// Support only foldable casts to/from pointers that can be eliminated by
2089+
// changing the pointer to the appropriately sized integer type.
2090+
Constant *Op = CE->getOperand(0);
2091+
Type *Ty = CE->getType();
2092+
2093+
const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2094+
2095+
// We can emit the pointer value into this slot if the slot is an
2096+
// integer slot equal to the size of the pointer.
2097+
if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2098+
return OpExpr;
2099+
2100+
// Otherwise the pointer is smaller than the resultant integer, mask off
2101+
// the high bits so we are sure to get a proper truncation if the input is
2102+
// a constant expr.
2103+
unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2104+
const MCExpr *MaskExpr = MCConstantExpr::Create(~0ULL >> (64-InBits), Ctx);
2105+
return MCBinaryExpr::CreateAnd(OpExpr, MaskExpr, Ctx);
2106+
}
2107+
2108+
// The MC library also has a right-shift operator, but it isn't consistently
2109+
// signed or unsigned between different targets.
2110+
case Instruction::Add: {
2111+
const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2112+
const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2113+
switch (CE->getOpcode()) {
2114+
default: llvm_unreachable("Unknown binary operator constant cast expr");
2115+
case Instruction::Add: return MCBinaryExpr::CreateAdd(LHS, RHS, Ctx);
2116+
}
2117+
}
2118+
}
2119+
}
2120+
2121+
// Copy of MCExpr::print customized for NVPTX
2122+
void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2123+
switch (Expr.getKind()) {
2124+
case MCExpr::Target:
2125+
return cast<MCTargetExpr>(&Expr)->PrintImpl(OS);
2126+
case MCExpr::Constant:
2127+
OS << cast<MCConstantExpr>(Expr).getValue();
2128+
return;
2129+
2130+
case MCExpr::SymbolRef: {
2131+
const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2132+
const MCSymbol &Sym = SRE.getSymbol();
2133+
OS << Sym;
2134+
return;
2135+
}
2136+
2137+
case MCExpr::Unary: {
2138+
const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2139+
switch (UE.getOpcode()) {
2140+
case MCUnaryExpr::LNot: OS << '!'; break;
2141+
case MCUnaryExpr::Minus: OS << '-'; break;
2142+
case MCUnaryExpr::Not: OS << '~'; break;
2143+
case MCUnaryExpr::Plus: OS << '+'; break;
2144+
}
2145+
printMCExpr(*UE.getSubExpr(), OS);
2146+
return;
2147+
}
2148+
2149+
case MCExpr::Binary: {
2150+
const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2151+
2152+
// Only print parens around the LHS if it is non-trivial.
2153+
if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2154+
isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2155+
printMCExpr(*BE.getLHS(), OS);
2156+
} else {
2157+
OS << '(';
2158+
printMCExpr(*BE.getLHS(), OS);
2159+
OS<< ')';
2160+
}
2161+
2162+
switch (BE.getOpcode()) {
2163+
case MCBinaryExpr::Add:
2164+
// Print "X-42" instead of "X+-42".
2165+
if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2166+
if (RHSC->getValue() < 0) {
2167+
OS << RHSC->getValue();
2168+
return;
2169+
}
2170+
}
2171+
2172+
OS << '+';
2173+
break;
2174+
default: llvm_unreachable("Unhandled binary operator");
2175+
}
2176+
2177+
// Only print parens around the LHS if it is non-trivial.
2178+
if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2179+
printMCExpr(*BE.getRHS(), OS);
2180+
} else {
2181+
OS << '(';
2182+
printMCExpr(*BE.getRHS(), OS);
2183+
OS << ')';
2184+
}
2185+
return;
2186+
}
2187+
}
2188+
2189+
llvm_unreachable("Invalid expression kind!");
2190+
}
2191+
19862192
/// PrintAsmOperand - Print out an operand for an inline asm expression.
19872193
///
19882194
bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,10 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
169169
} else {
170170
O << *Name;
171171
}
172-
} else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(v)) {
173-
O << *AP.lowerConstant(Cexpr);
172+
} else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
173+
const MCExpr *Expr =
174+
AP.lowerConstantForGV(cast<Constant>(CExpr), false);
175+
AP.printMCExpr(*Expr, O);
174176
} else
175177
llvm_unreachable("symbol type unknown");
176178
nSym++;
@@ -241,6 +243,10 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
241243
bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNo,
242244
unsigned AsmVariant, const char *ExtraCode,
243245
raw_ostream &) override;
246+
247+
const MCExpr *lowerConstantForGV(const Constant *CV, bool ProcessingGeneric);
248+
void printMCExpr(const MCExpr &Expr, raw_ostream &OS);
249+
244250
protected:
245251
bool doInitialization(Module &M) override;
246252
bool doFinalization(Module &M) override;

llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,13 @@ void NVPTXFloatMCExpr::PrintImpl(raw_ostream &OS) const {
4545
OS << std::string(NumHex - HexStr.length(), '0');
4646
OS << utohexstr(API.getZExtValue());
4747
}
48+
49+
const NVPTXGenericMCSymbolRefExpr*
50+
NVPTXGenericMCSymbolRefExpr::Create(const MCSymbolRefExpr *SymExpr,
51+
MCContext &Ctx) {
52+
return new (Ctx) NVPTXGenericMCSymbolRefExpr(SymExpr);
53+
}
54+
55+
void NVPTXGenericMCSymbolRefExpr::PrintImpl(raw_ostream &OS) const {
56+
OS << "generic(" << *SymExpr << ")";
57+
}

llvm/lib/Target/NVPTX/NVPTXMCExpr.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,50 @@ class NVPTXFloatMCExpr : public MCTargetExpr {
7979
return E->getKind() == MCExpr::Target;
8080
}
8181
};
82+
83+
/// A wrapper for MCSymbolRefExpr that tells the assembly printer that the
84+
/// symbol should be enclosed by generic().
85+
class NVPTXGenericMCSymbolRefExpr : public MCTargetExpr {
86+
private:
87+
const MCSymbolRefExpr *SymExpr;
88+
89+
explicit NVPTXGenericMCSymbolRefExpr(const MCSymbolRefExpr *_SymExpr)
90+
: SymExpr(_SymExpr) {}
91+
92+
public:
93+
/// @name Construction
94+
/// @{
95+
96+
static const NVPTXGenericMCSymbolRefExpr
97+
*Create(const MCSymbolRefExpr *SymExpr, MCContext &Ctx);
98+
99+
/// @}
100+
/// @name Accessors
101+
/// @{
102+
103+
/// getOpcode - Get the kind of this expression.
104+
const MCSymbolRefExpr *getSymbolExpr() const { return SymExpr; }
105+
106+
/// @}
107+
108+
void PrintImpl(raw_ostream &OS) const;
109+
bool EvaluateAsRelocatableImpl(MCValue &Res,
110+
const MCAsmLayout *Layout,
111+
const MCFixup *Fixup) const override {
112+
return false;
113+
}
114+
void visitUsedExpr(MCStreamer &Streamer) const override {};
115+
const MCSection *FindAssociatedSection() const override {
116+
return nullptr;
117+
}
118+
119+
// There are no TLS NVPTXMCExprs at the moment.
120+
void fixELFSymbolsInTLSFixups(MCAssembler &Asm) const override {}
121+
122+
static bool classof(const MCExpr *E) {
123+
return E->getKind() == MCExpr::Target;
124+
}
125+
};
82126
} // end namespace llvm
83127

84128
#endif

llvm/test/CodeGen/NVPTX/addrspacecast-gvar.ll

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
; CHECK: .visible .global .align 4 .u32 g2 = generic(g);
55
; CHECK: .visible .global .align 4 .u32 g3 = g;
66
; CHECK: .visible .global .align 8 .u32 g4[2] = {0, generic(g)};
7+
; CHECK: .visible .global .align 8 .u32 g5[2] = {0, generic(g)+8};
78

89
@g = addrspace(1) global i32 42
910
@g2 = addrspace(1) global i32* addrspacecast (i32 addrspace(1)* @g to i32*)
1011
@g3 = addrspace(1) global i32 addrspace(1)* @g
1112
@g4 = constant {i32*, i32*} {i32* null, i32* addrspacecast (i32 addrspace(1)* @g to i32*)}
13+
@g5 = constant {i32*, i32*} {i32* null, i32* addrspacecast (i32 addrspace(1)* getelementptr (i32, i32 addrspace(1)* @g, i32 2) to i32*)}

0 commit comments

Comments
 (0)