Skip to content

[X86] X86FixupVectorConstants - load+sign-extend vector constants that can be stored in a truncated form #79815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions llvm/lib/Target/X86/X86FixupVectorConstants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
// replace them with smaller constant pool entries, including:
// * Converting AVX512 memory-fold instructions to their broadcast-fold form
// * Broadcasting of full width loads.
// * TODO: Sign/Zero extension of full width loads.
// * TODO: Zero extension of full width loads.
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -265,11 +265,47 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
return nullptr;
}

/// Attempt to re-encode the constant \p C using narrower elements.
///
/// \p C is viewed as \p NumElts equally sized elements. When every element
/// can be represented in \p SrcEltBitWidth bits (judged with
/// APInt::getSignificantBits() if \p IsSExt is set, otherwise with
/// APInt::getActiveBits()), each element is truncated to that width, the
/// truncated elements are packed back-to-back, and the packed bits are handed
/// to rebuildConstant().
///
/// \returns the result of rebuildConstant(), or nullptr if the constant bits
/// could not be extracted or at least one element does not fit.
static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
                               unsigned SrcEltBitWidth) {
  Type *CstTy = C->getType();
  unsigned NumBits = CstTy->getPrimitiveSizeInBits();
  unsigned DstEltBitWidth = NumBits / NumElts;
  assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
         (DstEltBitWidth % SrcEltBitWidth) == 0 &&
         (DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");

  std::optional<APInt> Bits = extractConstantBits(C);
  if (!Bits)
    return nullptr;

  assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
         (Bits->getBitWidth() % DstEltBitWidth) == 0 &&
         "Unexpected constant extension");

  // Narrow one element at a time, giving up as soon as an element would not
  // survive the truncation to the source width.
  APInt Packed = APInt::getZero(NumElts * SrcEltBitWidth);
  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
    APInt Wide = Bits->extractBits(DstEltBitWidth, Idx * DstEltBitWidth);
    unsigned BitsNeeded =
        IsSExt ? Wide.getSignificantBits() : Wide.getActiveBits();
    if (BitsNeeded > SrcEltBitWidth)
      return nullptr;
    Packed.insertBits(Wide.trunc(SrcEltBitWidth), Idx * SrcEltBitWidth);
  }

  return rebuildConstant(CstTy->getContext(), CstTy->getScalarType(), Packed,
                         SrcEltBitWidth);
}
// Sign-extension variant of rebuildExtCst: forwards C, NumElts and
// SrcEltBitWidth unchanged and fixes the IsSExt argument to true. Returns
// whatever rebuildExtCst returns (a rebuilt constant or nullptr).
static Constant *rebuildSExtCst(const Constant *C, unsigned NumElts,
unsigned SrcEltBitWidth) {
return rebuildExtCst(C, /*IsSExt=*/true, NumElts, SrcEltBitWidth);
}

bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
MachineBasicBlock &MBB,
MachineInstr &MI) {
unsigned Opc = MI.getOpcode();
MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
bool HasSSE41 = ST->hasSSE41();
bool HasAVX2 = ST->hasAVX2();
bool HasDQI = ST->hasDQI();
bool HasBWI = ST->hasBWI();
Expand Down Expand Up @@ -312,7 +348,15 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
return false;
};

// Attempt to convert full width vector loads into broadcast/vzload loads.
// Attempt to detect a suitable vzload/broadcast/vextload from increasing
// constant bitwidths. Prefer vzload/broadcast/vextload for same bitwidth:
// - vzload shouldn't ever need a shuffle port to zero the upper elements and
// the fp/int domain versions are equally available so we don't introduce a
// domain crossing penalty.
// - broadcast sometimes needs a shuffle port (especially for 8/16-bit
// variants), AVX1 only has fp domain broadcasts but AVX2+ have good fp/int
// domain equivalents.
// - vextload always needs a shuffle port and is only ever int domain.
switch (Opc) {
/* FP Loads */
case X86::MOVAPDrm:
Expand Down Expand Up @@ -370,22 +414,34 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
/* Integer Loads */
case X86::MOVDQArm:
case X86::MOVDQUrm: {
return FixupConstant({{X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
{X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst}},
1);
FixupEntry Fixups[] = {
{HasSSE41 ? X86::PMOVSXBQrm : 0, 2, 8, rebuildSExtCst},
{X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
{HasSSE41 ? X86::PMOVSXBDrm : 0, 4, 8, rebuildSExtCst},
{HasSSE41 ? X86::PMOVSXWQrm : 0, 2, 16, rebuildSExtCst},
{X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
{HasSSE41 ? X86::PMOVSXBWrm : 0, 8, 8, rebuildSExtCst},
{HasSSE41 ? X86::PMOVSXWDrm : 0, 4, 16, rebuildSExtCst},
{HasSSE41 ? X86::PMOVSXDQrm : 0, 2, 32, rebuildSExtCst}};
return FixupConstant(Fixups, 1);
}
case X86::VMOVDQArm:
case X86::VMOVDQUrm: {
FixupEntry Fixups[] = {
{HasAVX2 ? X86::VPBROADCASTBrm : 0, 1, 8, rebuildSplatCst},
{HasAVX2 ? X86::VPBROADCASTWrm : 0, 1, 16, rebuildSplatCst},
{X86::VPMOVSXBQrm, 2, 8, rebuildSExtCst},
{X86::VMOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
{HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1, 32,
rebuildSplatCst},
{X86::VPMOVSXBDrm, 4, 8, rebuildSExtCst},
{X86::VPMOVSXWQrm, 2, 16, rebuildSExtCst},
{X86::VMOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
{HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1, 64,
rebuildSplatCst},
};
{X86::VPMOVSXBWrm, 8, 8, rebuildSExtCst},
{X86::VPMOVSXWDrm, 4, 16, rebuildSExtCst},
{X86::VPMOVSXDQrm, 2, 32, rebuildSExtCst}};
return FixupConstant(Fixups, 1);
}
case X86::VMOVDQAYrm:
Expand All @@ -395,10 +451,16 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
{HasAVX2 ? X86::VPBROADCASTWYrm : 0, 1, 16, rebuildSplatCst},
{HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1, 32,
rebuildSplatCst},
{HasAVX2 ? X86::VPMOVSXBQYrm : 0, 4, 8, rebuildSExtCst},
{HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1, 64,
rebuildSplatCst},
{HasAVX2 ? X86::VPMOVSXBDYrm : 0, 8, 8, rebuildSExtCst},
{HasAVX2 ? X86::VPMOVSXWQYrm : 0, 4, 16, rebuildSExtCst},
{HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1, 128,
rebuildSplatCst}};
rebuildSplatCst},
{HasAVX2 ? X86::VPMOVSXBWYrm : 0, 16, 8, rebuildSExtCst},
{HasAVX2 ? X86::VPMOVSXWDYrm : 0, 8, 16, rebuildSExtCst},
{HasAVX2 ? X86::VPMOVSXDQYrm : 0, 4, 32, rebuildSExtCst}};
return FixupConstant(Fixups, 1);
}
case X86::VMOVDQA32Z128rm:
Expand All @@ -408,10 +470,16 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
FixupEntry Fixups[] = {
{HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1, 8, rebuildSplatCst},
{HasBWI ? X86::VPBROADCASTWZ128rm : 0, 1, 16, rebuildSplatCst},
{X86::VPMOVSXBQZ128rm, 2, 8, rebuildSExtCst},
{X86::VMOVDI2PDIZrm, 1, 32, rebuildZeroUpperCst},
{X86::VPBROADCASTDZ128rm, 1, 32, rebuildSplatCst},
{X86::VPMOVSXBDZ128rm, 4, 8, rebuildSExtCst},
{X86::VPMOVSXWQZ128rm, 2, 16, rebuildSExtCst},
{X86::VMOVQI2PQIZrm, 1, 64, rebuildZeroUpperCst},
{X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst}};
{X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst},
{HasBWI ? X86::VPMOVSXBWZ128rm : 0, 8, 8, rebuildSExtCst},
{X86::VPMOVSXWDZ128rm, 4, 16, rebuildSExtCst},
{X86::VPMOVSXDQZ128rm, 2, 32, rebuildSExtCst}};
return FixupConstant(Fixups, 1);
}
case X86::VMOVDQA32Z256rm:
Expand All @@ -422,8 +490,14 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
{HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1, 8, rebuildSplatCst},
{HasBWI ? X86::VPBROADCASTWZ256rm : 0, 1, 16, rebuildSplatCst},
{X86::VPBROADCASTDZ256rm, 1, 32, rebuildSplatCst},
{X86::VPMOVSXBQZ256rm, 4, 8, rebuildSExtCst},
{X86::VPBROADCASTQZ256rm, 1, 64, rebuildSplatCst},
{X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst}};
{X86::VPMOVSXBDZ256rm, 8, 8, rebuildSExtCst},
{X86::VPMOVSXWQZ256rm, 4, 16, rebuildSExtCst},
{X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst},
{HasBWI ? X86::VPMOVSXBWZ256rm : 0, 16, 8, rebuildSExtCst},
{X86::VPMOVSXWDZ256rm, 8, 16, rebuildSExtCst},
{X86::VPMOVSXDQZ256rm, 4, 32, rebuildSExtCst}};
return FixupConstant(Fixups, 1);
}
case X86::VMOVDQA32Zrm:
Expand All @@ -435,8 +509,14 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
{HasBWI ? X86::VPBROADCASTWZrm : 0, 1, 16, rebuildSplatCst},
{X86::VPBROADCASTDZrm, 1, 32, rebuildSplatCst},
{X86::VPBROADCASTQZrm, 1, 64, rebuildSplatCst},
{X86::VPMOVSXBQZrm, 8, 8, rebuildSExtCst},
{X86::VBROADCASTI32X4rm, 1, 128, rebuildSplatCst},
{X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst}};
{X86::VPMOVSXBDZrm, 16, 8, rebuildSExtCst},
{X86::VPMOVSXWQZrm, 8, 16, rebuildSExtCst},
{X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst},
{HasBWI ? X86::VPMOVSXBWZrm : 0, 32, 8, rebuildSExtCst},
{X86::VPMOVSXWDZrm, 16, 16, rebuildSExtCst},
{X86::VPMOVSXDQZrm, 8, 32, rebuildSExtCst}};
return FixupConstant(Fixups, 1);
}
}
Expand Down
62 changes: 61 additions & 1 deletion llvm/lib/Target/X86/X86MCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,36 @@ static void printBroadcast(const MachineInstr *MI, MCStreamer &OutStreamer,
}
}

/// Emit an assembly comment showing the sign-extended elements of a constant
/// loaded by \p MI, in the form "<dst register> = [e0,e1,...]".
///
/// The constant is obtained with X86::getConstantFromPool(*MI, 1). A comment
/// is only produced when that constant exists, its scalar element width
/// equals \p SrcEltBits, and it is a ConstantDataSequential. Integer elements
/// are sign-extended to \p DstEltBits before printing; elements of any other
/// type are printed as "?".
///
/// \param MI          Instruction whose operand 0 is the destination register.
/// \param OutStreamer Streamer that receives the comment via AddComment().
/// \param SrcEltBits  Expected element width of the stored constant.
/// \param DstEltBits  Element width after sign extension.
/// \returns true if a comment was added, false otherwise.
static bool printSignExtend(const MachineInstr *MI, MCStreamer &OutStreamer,
                            int SrcEltBits, int DstEltBits) {
  auto *C = X86::getConstantFromPool(*MI, 1);
  // getScalarSizeInBits() yields an unsigned value; cast the int parameter so
  // the comparison does not mix signedness (-Wsign-compare).
  if (!C || C->getType()->getScalarSizeInBits() !=
                static_cast<unsigned>(SrcEltBits))
    return false;

  auto *CDS = dyn_cast<ConstantDataSequential>(C);
  if (!CDS)
    return false;

  std::string Comment;
  raw_string_ostream CS(Comment);

  const MachineOperand &DstOp = MI->getOperand(0);
  CS << X86ATTInstPrinter::getRegisterName(DstOp.getReg()) << " = ";
  CS << "[";

  // The element type is a property of the whole sequence, so query it once
  // rather than on every iteration.
  const bool HasIntegerElts = CDS->getElementType()->isIntegerTy();
  for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
    if (I != 0)
      CS << ",";
    if (HasIntegerElts) {
      APInt Elt = CDS->getElementAsAPInt(I).sext(DstEltBits);
      printConstant(Elt, CS);
    } else {
      CS << "?";
    }
  }
  CS << "]";
  OutStreamer.AddComment(CS.str());
  return true;
}

void X86AsmPrinter::EmitSEHInstruction(const MachineInstr *MI) {
assert(MF->hasWinCFI() && "SEH_ instruction in function without WinCFI?");
assert((getSubtarget().isOSWindows() || TM.getTargetTriple().isUEFI()) &&
Expand Down Expand Up @@ -1844,7 +1874,7 @@ static void addConstantComments(const MachineInstr *MI,
case X86::VMOVQI2PQIrm:
case X86::VMOVQI2PQIZrm:
printZeroUpperMove(MI, OutStreamer, 64, 128, "mem[0],zero");
break;
break;

case X86::MOVSSrm:
case X86::VMOVSSrm:
Expand Down Expand Up @@ -1979,6 +2009,36 @@ static void addConstantComments(const MachineInstr *MI,
case X86::VPBROADCASTBZrm:
printBroadcast(MI, OutStreamer, 64, 8);
break;

#define MOVX_CASE(Prefix, Ext, Type, Suffix) \
case X86::Prefix##PMOV##Ext##Type##Suffix##rm:

#define CASE_MOVX_RM(Ext, Type) \
MOVX_CASE(, Ext, Type, ) \
MOVX_CASE(V, Ext, Type, ) \
MOVX_CASE(V, Ext, Type, Y) \
MOVX_CASE(V, Ext, Type, Z128) \
MOVX_CASE(V, Ext, Type, Z256) \
MOVX_CASE(V, Ext, Type, Z)

CASE_MOVX_RM(SX, BD)
printSignExtend(MI, OutStreamer, 8, 32);
break;
CASE_MOVX_RM(SX, BQ)
printSignExtend(MI, OutStreamer, 8, 64);
break;
CASE_MOVX_RM(SX, BW)
printSignExtend(MI, OutStreamer, 8, 16);
break;
CASE_MOVX_RM(SX, DQ)
printSignExtend(MI, OutStreamer, 32, 64);
break;
CASE_MOVX_RM(SX, WD)
printSignExtend(MI, OutStreamer, 16, 32);
break;
CASE_MOVX_RM(SX, WQ)
printSignExtend(MI, OutStreamer, 16, 64);
break;
}
}

Expand Down
Loading