Skip to content

Commit 14e9fe7

Browse files
nilanjana87fhahn
authored andcommitted
[AArch64] Extending lowering of 'trunc <(8|16) x i64> %x to <(8|16) x i8>' to use tbl instructions
[AArch64] Patch for lowering trunc instructions to 'tbl' for (8|16)xi32 -> (8|16)xi8 conversions in https://reviews.llvm.org/D133495 is extended to support trunc to tbl lowering for (8|16) x i64 to (8|16) x i8. A microbenchmark for runtime for these transformations is added in https://reviews.llvm.org/D136274 Reviewed by: fhahn, t.p.northover Differential Revision: https://reviews.llvm.org/D135229 (cherry picked from commit 02d09ff)
1 parent 8dc6fe8 commit 14e9fe7

File tree

2 files changed

+319
-164
lines changed

2 files changed

+319
-164
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 112 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13743,38 +13743,116 @@ static void createTblShuffleForZExt(ZExtInst *ZExt, bool IsLittleEndian) {
1374313743
static void createTblForTrunc(TruncInst *TI, bool IsLittleEndian) {
1374413744
IRBuilder<> Builder(TI);
1374513745
SmallVector<Value *> Parts;
13746+
int NumElements = cast<FixedVectorType>(TI->getType())->getNumElements();
13747+
auto *SrcTy = cast<FixedVectorType>(TI->getOperand(0)->getType());
13748+
auto *DstTy = cast<FixedVectorType>(TI->getType());
13749+
assert(SrcTy->getElementType()->isIntegerTy() &&
13750+
"Non-integer type source vector element is not supported");
13751+
assert(DstTy->getElementType()->isIntegerTy(8) &&
13752+
"Unsupported destination vector element type");
13753+
unsigned SrcElemTySz =
13754+
cast<IntegerType>(SrcTy->getElementType())->getBitWidth();
13755+
unsigned TruncFactor =
13756+
SrcElemTySz / cast<IntegerType>(DstTy->getElementType())->getBitWidth();
13757+
assert((SrcElemTySz == 16 || SrcElemTySz == 32 || SrcElemTySz == 64) &&
13758+
"Unsupported source vector element type size");
1374613759
Type *VecTy = FixedVectorType::get(Builder.getInt8Ty(), 16);
13747-
Parts.push_back(Builder.CreateBitCast(
13748-
Builder.CreateShuffleVector(TI->getOperand(0), {0, 1, 2, 3}), VecTy));
13749-
Parts.push_back(Builder.CreateBitCast(
13750-
Builder.CreateShuffleVector(TI->getOperand(0), {4, 5, 6, 7}), VecTy));
13751-
13752-
Intrinsic::ID TblID = Intrinsic::aarch64_neon_tbl2;
13753-
unsigned NumElements = cast<FixedVectorType>(TI->getType())->getNumElements();
13754-
if (NumElements == 16) {
13755-
Parts.push_back(Builder.CreateBitCast(
13756-
Builder.CreateShuffleVector(TI->getOperand(0), {8, 9, 10, 11}), VecTy));
13760+
13761+
// Create a mask to choose every nth byte from the source vector table of
13762+
// bytes to create the truncated destination vector, where 'n' is the truncate
13763+
// ratio. For example, for a truncate from Yxi64 to Yxi8, choose
13764+
// 0,8,16,..Y*8th bytes for the little-endian format
13765+
SmallVector<Constant *, 16> MaskConst;
13766+
for (int Itr = 0; Itr < 16; Itr++) {
13767+
if (Itr < NumElements)
13768+
MaskConst.push_back(ConstantInt::get(
13769+
Builder.getInt8Ty(), IsLittleEndian
13770+
? Itr * TruncFactor
13771+
: Itr * TruncFactor + (TruncFactor - 1)));
13772+
else
13773+
MaskConst.push_back(ConstantInt::get(Builder.getInt8Ty(), 255));
13774+
}
13775+
13776+
int MaxTblSz = 128 * 4;
13777+
int MaxSrcSz = SrcElemTySz * NumElements;
13778+
int ElemsPerTbl =
13779+
(MaxTblSz > MaxSrcSz) ? NumElements : (MaxTblSz / SrcElemTySz);
13780+
assert(ElemsPerTbl <= 16 &&
13781+
"Maximum elements selected using TBL instruction cannot exceed 16!");
13782+
13783+
int ShuffleCount = 128 / SrcElemTySz;
13784+
SmallVector<int> ShuffleLanes;
13785+
for (int i = 0; i < ShuffleCount; ++i)
13786+
ShuffleLanes.push_back(i);
13787+
13788+
// Create TBL's table of bytes in 1,2,3 or 4 FP/SIMD registers using shuffles
13789+
// over the source vector. If TBL's maximum 4 FP/SIMD registers are saturated,
13790+
// call TBL & save the result in a vector of TBL results for combining later.
13791+
SmallVector<Value *> Results;
13792+
while (ShuffleLanes.back() < NumElements) {
1375713793
Parts.push_back(Builder.CreateBitCast(
13758-
Builder.CreateShuffleVector(TI->getOperand(0), {12, 13, 14, 15}),
13759-
VecTy));
13760-
TblID = Intrinsic::aarch64_neon_tbl4;
13794+
Builder.CreateShuffleVector(TI->getOperand(0), ShuffleLanes), VecTy));
13795+
13796+
if (Parts.size() >= 4) {
13797+
auto *F = Intrinsic::getDeclaration(TI->getModule(),
13798+
Intrinsic::aarch64_neon_tbl4, VecTy);
13799+
Parts.push_back(ConstantVector::get(MaskConst));
13800+
Results.push_back(Builder.CreateCall(F, Parts));
13801+
Parts.clear();
13802+
}
13803+
13804+
for (int i = 0; i < ShuffleCount; ++i)
13805+
ShuffleLanes[i] += ShuffleCount;
1376113806
}
13762-
SmallVector<Constant *, 16> MaskConst;
13763-
for (unsigned Idx = 0; Idx < NumElements * 4; Idx += 4)
13764-
MaskConst.push_back(
13765-
ConstantInt::get(Builder.getInt8Ty(), IsLittleEndian ? Idx : Idx + 3));
1376613807

13767-
for (unsigned Idx = NumElements * 4; Idx < 64; Idx += 4)
13768-
MaskConst.push_back(ConstantInt::get(Builder.getInt8Ty(), 255));
13808+
assert((Parts.empty() || Results.empty()) &&
13809+
"Lowering trunc for vectors requiring different TBL instructions is "
13810+
"not supported!");
13811+
// Call TBL for the residual table bytes present in 1,2, or 3 FP/SIMD
13812+
// registers
13813+
if (!Parts.empty()) {
13814+
Intrinsic::ID TblID;
13815+
switch (Parts.size()) {
13816+
case 1:
13817+
TblID = Intrinsic::aarch64_neon_tbl1;
13818+
break;
13819+
case 2:
13820+
TblID = Intrinsic::aarch64_neon_tbl2;
13821+
break;
13822+
case 3:
13823+
TblID = Intrinsic::aarch64_neon_tbl3;
13824+
break;
13825+
}
1376913826

13770-
Parts.push_back(ConstantVector::get(MaskConst));
13771-
auto *F =
13772-
Intrinsic::getDeclaration(TI->getModule(), TblID, Parts[0]->getType());
13773-
Value *Res = Builder.CreateCall(F, Parts);
13827+
auto *F = Intrinsic::getDeclaration(TI->getModule(), TblID, VecTy);
13828+
Parts.push_back(ConstantVector::get(MaskConst));
13829+
Results.push_back(Builder.CreateCall(F, Parts));
13830+
}
13831+
13832+
// Extract the destination vector from TBL result(s) after combining them
13833+
// where applicable. Currently, at most two TBLs are supported.
13834+
assert(Results.size() <= 2 && "Trunc lowering does not support generation of "
13835+
"more than 2 tbl instructions!");
13836+
Value *FinalResult = Results[0];
13837+
if (Results.size() == 1) {
13838+
if (ElemsPerTbl < 16) {
13839+
SmallVector<int> FinalMask(ElemsPerTbl);
13840+
std::iota(FinalMask.begin(), FinalMask.end(), 0);
13841+
FinalResult = Builder.CreateShuffleVector(Results[0], FinalMask);
13842+
}
13843+
} else {
13844+
SmallVector<int> FinalMask(ElemsPerTbl * Results.size());
13845+
if (ElemsPerTbl < 16) {
13846+
std::iota(FinalMask.begin(), FinalMask.begin() + ElemsPerTbl, 0);
13847+
std::iota(FinalMask.begin() + ElemsPerTbl, FinalMask.end(), 16);
13848+
} else {
13849+
std::iota(FinalMask.begin(), FinalMask.end(), 0);
13850+
}
13851+
FinalResult =
13852+
Builder.CreateShuffleVector(Results[0], Results[1], FinalMask);
13853+
}
1377413854

13775-
if (NumElements == 8)
13776-
Res = Builder.CreateShuffleVector(Res, {0, 1, 2, 3, 4, 5, 6, 7});
13777-
TI->replaceAllUsesWith(Res);
13855+
TI->replaceAllUsesWith(FinalResult);
1377813856
TI->eraseFromParent();
1377913857
}
1378013858

@@ -13836,13 +13914,15 @@ bool AArch64TargetLowering::optimizeExtendOrTruncateConversion(Instruction *I,
1383613914
return true;
1383713915
}
1383813916

13839-
// Convert 'trunc <(8|16) x i32> %x to <(8|16) x i8>' to a single tbl.4
13840-
// instruction selecting the lowest 8 bits per lane of the input interpreted
13841-
// as 2 or 4 <4 x i32> vectors.
13917+
// Convert 'trunc <(8|16) x (i32|i64)> %x to <(8|16) x i8>' to an appropriate
13918+
// tbl instruction selecting the lowest/highest (little/big endian) 8 bits
13919+
// per lane of the input that is represented using 1,2,3 or 4 128-bit table
13920+
// registers
1384213921
auto *TI = dyn_cast<TruncInst>(I);
13843-
if (TI && (SrcTy->getNumElements() == 8 || SrcTy->getNumElements() == 16) &&
13844-
SrcTy->getElementType()->isIntegerTy(32) &&
13845-
DstTy->getElementType()->isIntegerTy(8)) {
13922+
if (TI && DstTy->getElementType()->isIntegerTy(8) &&
13923+
((SrcTy->getElementType()->isIntegerTy(32) ||
13924+
SrcTy->getElementType()->isIntegerTy(64)) &&
13925+
(SrcTy->getNumElements() == 16 || SrcTy->getNumElements() == 8))) {
1384613926
createTblForTrunc(TI, Subtarget->isLittleEndian());
1384713927
return true;
1384813928
}

0 commit comments

Comments
 (0)