[VectorCombine] Add type shrinking and zext propagation for fixed-width vector types #104606

Merged: 4 commits, Sep 10, 2024

94 changes: 94 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -119,6 +119,7 @@ class VectorCombine {
  bool foldShuffleFromReductions(Instruction &I);
  bool foldCastFromReductions(Instruction &I);
  bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
  bool shrinkType(Instruction &I);

  void replaceValue(Value &Old, Value &New) {
    Old.replaceAllUsesWith(&New);
@@ -2493,6 +2494,96 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
  return true;
}

/// Check if instruction depends on ZExt and this ZExt can be moved after the
/// instruction. Move ZExt if it is profitable. For example:
/// logic(zext(x),y) -> zext(logic(x,trunc(y)))
/// lshr((zext(x),y) -> zext(lshr(x,trunc(y)))
Collaborator:
lshr((zext(x),y) -> lshr(zext(x),y)  (i.e. drop the stray parenthesis in the comment above)

/// Cost model calculations take into account if zext(x) has other users and
/// whether it can be propagated through them too.
bool VectorCombine::shrinkType(llvm::Instruction &I) {
  Value *ZExted, *OtherOperand;
  if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
                                  m_Value(OtherOperand))) &&
Collaborator:
What happens if both are zero-extended?

Contributor Author:
If both ZExts are applied to the same type, instcombine will handle it. And if they are different, instcombine will transform the expression into the form recognised by this patch:
https://godbolt.org/z/41j8xYThP
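An illustrative hand-written IR sketch (not taken from the patch or the link above) of the same-source-type case mentioned here: instcombine already narrows the logic op itself, so there is nothing left for this fold to do.

before instcombine:
  %x = zext <16 x i8> %a to <16 x i32>
  %y = zext <16 x i8> %b to <16 x i32>
  %r = and <16 x i32> %x, %y

after instcombine:
  %n = and <16 x i8> %a, %b
  %r = zext <16 x i8> %n to <16 x i32>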

      !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
    return false;

  Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);

  auto *BigTy = cast<FixedVectorType>(I.getType());
  auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
  unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();

  // Check that the expression overall uses at most the same number of bits as
  // ZExted
  KnownBits KB = computeKnownBits(&I, *DL);
  if (KB.countMaxActiveBits() > BW)
    return false;

  // Calculate costs of leaving current IR as it is and moving ZExt operation
  // later, along with adding truncates if needed
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  InstructionCost ZExtCost = TTI.getCastInstrCost(
      Instruction::ZExt, BigTy, SmallTy,
      TargetTransformInfo::CastContextHint::None, CostKind);
  InstructionCost CurrentCost = ZExtCost;
  InstructionCost ShrinkCost = 0;

  // Calculate total cost and check that we can propagate through all ZExt users
  for (User *U : ZExtOperand->users()) {
    auto *UI = cast<Instruction>(U);
    if (UI == &I) {
      CurrentCost +=
          TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
      ShrinkCost +=
          TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
      ShrinkCost += ZExtCost;
      continue;
    }

    if (!Instruction::isBinaryOp(UI->getOpcode()))
      return false;

    // Check if we can propagate ZExt through its other users
    KB = computeKnownBits(UI, *DL);
    if (KB.countMaxActiveBits() > BW)
      return false;

    CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
    ShrinkCost +=
        TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
    ShrinkCost += ZExtCost;
  }

  // If the other instruction operand is not a constant, we'll need to
  // generate a truncate instruction. So we have to adjust cost
  if (!isa<Constant>(OtherOperand))
    ShrinkCost += TTI.getCastInstrCost(
        Instruction::Trunc, SmallTy, BigTy,
        TargetTransformInfo::CastContextHint::None, CostKind);

  // If the cost of shrinking types and leaving the IR is the same, we'll lean
  // towards modifying the IR because shrinking opens opportunities for other
  // shrinking optimisations.
  if (ShrinkCost > CurrentCost)
    return false;

  Value *Op0 = ZExted;
  if (auto *OI = dyn_cast<Instruction>(OtherOperand))
    Builder.SetInsertPoint(OI->getNextNode());
Collaborator:
I think this is currently broken.

What if OI->getNextNode() is a PHI?
For my out-of-tree target the following fails

opt -passes="vector-combine" bbi-99058.ll -o /dev/null

with

PHI nodes not grouped at top of basic block!
  %vec.ind = phi <4 x i16> [ zeroinitializer, %entry ], [ zeroinitializer, %vector.body ]
label %vector.body
LLVM ERROR: Broken module found, compilation aborted!

and I think it's because the trunc created on the next line is inserted before the second PHI in the bb.

If you simply comment out the

  if (ShrinkCost > CurrentCost)
    return false;

code at line 2567 above, it happens in tree as well. (I'm sure the testcase can be modified in some way so it happens even with the cost comparison at 2567 for some target, but I didn't manage to just now.)

bbi-99058.ll in my example is

define i64 @func_1() {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %vec.phi = phi <4 x i32> [ zeroinitializer, %entry ], [ %1, %vector.body ]
  %vec.ind = phi <4 x i16> [ zeroinitializer, %entry ], [ zeroinitializer, %vector.body ]
  %0 = zext <4 x i16> zeroinitializer to <4 x i32>
  %1 = and <4 x i32> %vec.phi, %0
  br label %vector.body
}

Contributor Author:
Created a fix - #108228
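Whatever insertion point the fix chooses, a verifier-clean result for the test case above must keep the PHIs grouped at the top of vector.body. A hand-written sketch (values renamed for readability; the exact placement picked by #108228 may differ):

vector.body:                                      ; preds = %vector.body, %entry
  %vec.phi = phi <4 x i32> [ zeroinitializer, %entry ], [ %z, %vector.body ]
  %vec.ind = phi <4 x i16> [ zeroinitializer, %entry ], [ zeroinitializer, %vector.body ]
  %t = trunc <4 x i32> %vec.phi to <4 x i16>
  %a = and <4 x i16> %t, zeroinitializer
  %z = zext <4 x i16> %a to <4 x i32>
  br label %vector.body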

  Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
  Builder.SetInsertPoint(&I);
  // Keep the order of operands the same
  if (I.getOperand(0) == OtherOperand)
    std::swap(Op0, Op1);
  Value *NewBinOp =
      Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
  cast<Instruction>(NewBinOp)->copyIRFlags(&I);
  cast<Instruction>(NewBinOp)->copyMetadata(I);
  Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
  replaceValue(I, *NewZExtr);
  return true;
}
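To make the transformation described at the top of shrinkType concrete, here is a minimal hand-written IR sketch of the and/zext case (assuming the cost model approves the shrink; the AArch64 tests below exercise the same pattern end to end):

before:
  %zx = zext <16 x i8> %x to <16 x i32>
  %r = and <16 x i32> %zx, %y

after:
  %ty = trunc <16 x i32> %y to <16 x i8>
  %n = and <16 x i8> %x, %ty
  %r = zext <16 x i8> %n to <16 x i32>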

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
Expand Down Expand Up @@ -2560,6 +2651,9 @@ bool VectorCombine::run() {
case Instruction::BitCast:
MadeChange |= foldBitcastShuffle(I);
break;
default:
MadeChange |= shrinkType(I);
break;
}
} else {
switch (Opcode) {
76 changes: 76 additions & 0 deletions llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -0,0 +1,76 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -passes=vector-combine -S %s | FileCheck %s

target triple = "aarch64"

define i32 @test_and(<16 x i32> %a, ptr %b) {
; CHECK-LABEL: @test_and(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
; CHECK-NEXT: [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
; CHECK-NEXT: ret i32 [[TMP3]]
;
entry:
%wide.load = load <16 x i8>, ptr %b, align 1
%0 = zext <16 x i8> %wide.load to <16 x i32>
%1 = and <16 x i32> %0, %a
%2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
ret i32 %2
}

define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
; CHECK-LABEL: @test_mask_or(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
; CHECK-NEXT: [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
; CHECK-NEXT: [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
; CHECK-NEXT: ret i32 [[TMP3]]
;
entry:
%wide.load = load <16 x i8>, ptr %b, align 1
%a.masked = and <16 x i32> %a, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
%0 = zext <16 x i8> %wide.load to <16 x i32>
%1 = or <16 x i32> %0, %a.masked
%2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
ret i32 %2
}

define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
; CHECK-LABEL: @multiuse(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
; CHECK-NEXT: [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
; CHECK-NEXT: [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
; CHECK-NEXT: [[TMP2:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], <i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4>
; CHECK-NEXT: [[TMP3:%.*]] = or <16 x i8> [[TMP2]], [[TMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[TMP3]] to <16 x i32>
; CHECK-NEXT: [[TMP5:%.*]] = and <16 x i8> [[WIDE_LOAD]], <i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15>
; CHECK-NEXT: [[TMP6:%.*]] = or <16 x i8> [[TMP5]], [[TMP0]]
; CHECK-NEXT: [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
; CHECK-NEXT: [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP4]], [[TMP7]]
; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
; CHECK-NEXT: ret i32 [[TMP9]]
;
entry:
%u.masked = and <16 x i32> %u, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
%v.masked = and <16 x i32> %v, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
%wide.load = load <16 x i8>, ptr %b, align 1
%0 = zext <16 x i8> %wide.load to <16 x i32>
%1 = lshr <16 x i32> %0, <i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4>
%2 = or <16 x i32> %1, %v.masked
%3 = and <16 x i32> %0, <i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15>
%4 = or <16 x i32> %3, %u.masked
%5 = add nuw nsw <16 x i32> %2, %4
%6 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
ret i32 %6
}

declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)