[VectorCombine] Add type shrinking and zext propagation for fixed-width vector types #104606

@@ -119,6 +119,7 @@ class VectorCombine {
  bool foldShuffleFromReductions(Instruction &I);
  bool foldCastFromReductions(Instruction &I);
  bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
  bool shrinkType(Instruction &I);

  void replaceValue(Value &Old, Value &New) {
    Old.replaceAllUsesWith(&New);

@@ -2493,6 +2494,96 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
  return true;
}

/// Check if instruction depends on ZExt and this ZExt can be moved after the
/// instruction. Move ZExt if it is profitable. For example:
///     logic(zext(x),y) -> zext(logic(x,trunc(y)))
///     lshr(zext(x),y) -> zext(lshr(x,trunc(y)))
/// Cost model calculations take into account whether zext(x) has other users
/// and whether it can be propagated through them too.
bool VectorCombine::shrinkType(llvm::Instruction &I) {
  Value *ZExted, *OtherOperand;
  if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
                                  m_Value(OtherOperand))) &&
      !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
    return false;

Review comment: What happens if both operands are zero-extended?

Reply: If both ZExts are applied to the same type, instcombine will handle it. And if they are different, instcombine will transform the expression into the form recognised by this patch.

  Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);

  auto *BigTy = cast<FixedVectorType>(I.getType());
  auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
  unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();

  // Check that the expression overall uses at most the same number of bits as
  // ZExted
  KnownBits KB = computeKnownBits(&I, *DL);
  if (KB.countMaxActiveBits() > BW)
    return false;

  // Calculate costs of leaving current IR as it is and moving ZExt operation
  // later, along with adding truncates if needed
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  InstructionCost ZExtCost = TTI.getCastInstrCost(
      Instruction::ZExt, BigTy, SmallTy,
      TargetTransformInfo::CastContextHint::None, CostKind);
  InstructionCost CurrentCost = ZExtCost;
  InstructionCost ShrinkCost = 0;

  // Calculate total cost and check that we can propagate through all ZExt
  // users
  for (User *U : ZExtOperand->users()) {
    auto *UI = cast<Instruction>(U);
    if (UI == &I) {
      CurrentCost +=
          TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
      ShrinkCost +=
          TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
      ShrinkCost += ZExtCost;
      continue;
    }

    if (!Instruction::isBinaryOp(UI->getOpcode()))
      return false;

    // Check if we can propagate ZExt through its other users
    KB = computeKnownBits(UI, *DL);
    if (KB.countMaxActiveBits() > BW)
      return false;

    CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
    ShrinkCost +=
        TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
    ShrinkCost += ZExtCost;
  }

  // If the other instruction operand is not a constant, we'll need to generate
  // a truncate instruction, so we have to adjust the cost.
  if (!isa<Constant>(OtherOperand))
    ShrinkCost += TTI.getCastInstrCost(
        Instruction::Trunc, SmallTy, BigTy,
        TargetTransformInfo::CastContextHint::None, CostKind);
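
For a single-user expression, the comparison therefore boils down to the following (a sketch in pseudo-cost notation, with C(...) standing for the corresponding TTI query):

CurrentCost = C(zext, SmallTy -> BigTy) + C(binop, BigTy)
ShrinkCost  = C(binop, SmallTy) + C(zext, SmallTy -> BigTy)
              + C(trunc, BigTy -> SmallTy)   ; only if OtherOperand is non-constant

The zext term appears on both sides and cancels, so the fold is kept when C(binop, SmallTy) + C(trunc, BigTy -> SmallTy) <= C(binop, BigTy). On AArch64 a <16 x i32> bitwise operation legalizes to several 128-bit instructions while the <16 x i8> form needs one, which is presumably what tips the balance in the tests below.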

  // If the cost of shrinking types and leaving the IR is the same, we'll lean
  // towards modifying the IR because shrinking opens opportunities for other
  // shrinking optimisations.
  if (ShrinkCost > CurrentCost)
    return false;

  Value *Op0 = ZExted;
  if (auto *OI = dyn_cast<Instruction>(OtherOperand))
    Builder.SetInsertPoint(OI->getNextNode());

Review comment: I think this is currently broken. What if OI->getNextNode() is a PHI? I see a failure with a testcase of mine, bbi-99058.ll, and I think it's because the trunc created on the next line is inserted before the second PHI in the bb. If you simply comment out the cost comparison at line 2567 above, it happens in tree as well. (I'm sure the testcase can be modified in some way so it happens even with the cost comparison at 2567 for some target, but I didn't manage right now.)

Reply: Created a fix - #108228
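
The failing shape is presumably something like the following hypothetical reconstruction (this is not the reporter's actual bbi-99058.ll):

define <16 x i32> @phi_sketch(i1 %c, <16 x i32> %a, <16 x i32> %b, <16 x i8> %x) {
entry:
  br i1 %c, label %left, label %right

left:
  br label %join

right:
  br label %join

join:
  %other = phi <16 x i32> [ %a, %left ], [ %b, %right ]  ; OtherOperand is a phi
  %next  = phi <16 x i32> [ %b, %left ], [ %a, %right ]  ; and a second phi follows
  ; OI->getNextNode() is %next here, so the trunc created for %other would be
  ; inserted between the two phis, which the verifier rejects.
  %x.ext = zext <16 x i8> %x to <16 x i32>
  %r = and <16 x i32> %x.ext, %other
  ret <16 x i32> %r
}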

  Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
  Builder.SetInsertPoint(&I);
  // Keep the order of operands the same
  if (I.getOperand(0) == OtherOperand)
    std::swap(Op0, Op1);
  Value *NewBinOp =
      Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
  cast<Instruction>(NewBinOp)->copyIRFlags(&I);
  cast<Instruction>(NewBinOp)->copyMetadata(I);
  Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
  replaceValue(I, *NewZExtr);
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {

@@ -2560,6 +2651,9 @@ bool VectorCombine::run() {
    case Instruction::BitCast:
      MadeChange |= foldBitcastShuffle(I);
      break;
    default:
      MadeChange |= shrinkType(I);
      break;
    }
  } else {
    switch (Opcode) {

@@ -0,0 +1,76 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -passes=vector-combine -S %s | FileCheck %s

target triple = "aarch64"

define i32 @test_and(<16 x i32> %a, ptr %b) {
; CHECK-LABEL: @test_and(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
; CHECK-NEXT:    ret i32 [[TMP3]]
;
entry:
  %wide.load = load <16 x i8>, ptr %b, align 1
  %0 = zext <16 x i8> %wide.load to <16 x i32>
  %1 = and <16 x i32> %0, %a
  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
  ret i32 %2
}

define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
; CHECK-LABEL: @test_mask_or(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
; CHECK-NEXT:    [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
; CHECK-NEXT:    [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
; CHECK-NEXT:    ret i32 [[TMP3]]
;
entry:
  %wide.load = load <16 x i8>, ptr %b, align 1
  %a.masked = and <16 x i32> %a, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
  %0 = zext <16 x i8> %wide.load to <16 x i32>
  %1 = or <16 x i32> %0, %a.masked
  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
  ret i32 %2
}

define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
; CHECK-LABEL: @multiuse(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
; CHECK-NEXT:    [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
; CHECK-NEXT:    [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
; CHECK-NEXT:    [[TMP2:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], <i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4>
; CHECK-NEXT:    [[TMP3:%.*]] = or <16 x i8> [[TMP2]], [[TMP1]]
; CHECK-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[TMP3]] to <16 x i32>
; CHECK-NEXT:    [[TMP5:%.*]] = and <16 x i8> [[WIDE_LOAD]], <i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15>
; CHECK-NEXT:    [[TMP6:%.*]] = or <16 x i8> [[TMP5]], [[TMP0]]
; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
; CHECK-NEXT:    [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP4]], [[TMP7]]
; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
; CHECK-NEXT:    ret i32 [[TMP9]]
;
entry:
  %u.masked = and <16 x i32> %u, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
  %v.masked = and <16 x i32> %v, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
  %wide.load = load <16 x i8>, ptr %b, align 1
  %0 = zext <16 x i8> %wide.load to <16 x i32>
  %1 = lshr <16 x i32> %0, <i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4>
  %2 = or <16 x i32> %1, %v.masked
  %3 = and <16 x i32> %0, <i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15>
  %4 = or <16 x i32> %3, %u.masked
  %5 = add nuw nsw <16 x i32> %2, %4
  %6 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
  ret i32 %6
}

declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

Review comment: Typo in the comment above shrinkType: lshr((zext(x),y) should be lshr(zext(x),y).