
[AArch64] Improve operand sinking for mul instructions #116604


Merged: 4 commits, Dec 6, 2024
52 changes: 41 additions & 11 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5168,26 +5168,45 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
     return false;
   }
   case Instruction::Mul: {
+    auto ShouldSinkSplatForIndexedVariant = [](Value *V) {
+      auto *Ty = cast<VectorType>(V->getType());
+      // For SVE the lane-indexing is within 128-bits, so we can't fold splats.
+      if (Ty->isScalableTy())
+        return false;
+
+      // Indexed variants of Mul exist for i16 and i32 element types only.
+      return Ty->getScalarSizeInBits() == 16 || Ty->getScalarSizeInBits() == 32;
+    };
+
     int NumZExts = 0, NumSExts = 0;
     for (auto &Op : I->operands()) {
       // Make sure we are not already sinking this operand
       if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
         continue;

-      if (match(&Op, m_SExt(m_Value()))) {
-        NumSExts++;
-        continue;
-      } else if (match(&Op, m_ZExt(m_Value()))) {
-        NumZExts++;
+      if (match(&Op, m_ZExtOrSExt(m_Value()))) {
+        auto *Ext = cast<Instruction>(Op);
+        auto *ExtOp = Ext->getOperand(0);
+        if (isSplatShuffle(ExtOp) && ShouldSinkSplatForIndexedVariant(ExtOp))
+          Ops.push_back(&Ext->getOperandUse(0));
+        Ops.push_back(&Op);
+
+        if (isa<SExtInst>(Ext))
+          NumSExts++;
+        else
+          NumZExts++;
+
         continue;
       }

       ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
+      if (!Shuffle)
+        continue;

       // If the Shuffle is a splat and the operand is a zext/sext, sinking the
       // operand and the s/zext can help create indexed s/umull. This is
       // especially useful to prevent i64 mul being scalarized.
-      if (Shuffle && isSplatShuffle(Shuffle) &&
+      if (isSplatShuffle(Shuffle) &&
           match(Shuffle->getOperand(0), m_ZExtOrSExt(m_Value()))) {
         Ops.push_back(&Shuffle->getOperandUse(0));
         Ops.push_back(&Op);
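
For context, and purely as an illustration that is not part of this patch: sinking the splat together with the s/zext pays off because AArch64 has lane-indexed forms of the widening multiplies, so the broadcast operand can be folded into the smull/umull itself instead of being materialised with a separate dup. A minimal sketch using ACLE NEON intrinsics, assuming a standard arm_neon.h toolchain (the function names below are invented for the example):

#include <arm_neon.h>

// Widening signed multiply by a broadcast lane: SMULL (by element) takes the
// lane directly, so no separate dup of the scalar is needed.
int64x2_t widen_smul_by_lane(int32x2_t a, int32x2_t b) {
  return vmull_lane_s32(a, b, 0); // e.g. smull v0.2d, v0.2s, v1.s[0]
}

// The unsigned counterpart maps to UMULL (by element).
uint32x4_t widen_umul_by_lane(uint16x4_t a, uint16x4_t b) {
  return vmull_lane_u16(a, b, 0); // e.g. umull v0.4s, v0.4h, v1.h[0]
}
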
@@ -5198,9 +5217,6 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
         continue;
       }

-      if (!Shuffle)
-        continue;
-
       Value *ShuffleOperand = Shuffle->getOperand(0);
       InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand);
       if (!Insert)
@@ -5232,12 +5248,26 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
         NumZExts++;
       }

+      Ops.push_back(&Insert->getOperandUse(1));
       Ops.push_back(&Shuffle->getOperandUse(0));
       Ops.push_back(&Op);
     }

-    // Is it profitable to sink if we found two of the same type of extends.
-    return !Ops.empty() && (NumSExts == 2 || NumZExts == 2);
+    // It is profitable to sink if we found two of the same type of extends.
+    if (!Ops.empty() && (NumSExts == 2 || NumZExts == 2))
+      return true;
+
+    // Otherwise, see if we should sink splats for indexed variants.
+    if (!ShouldSinkSplatForIndexedVariant(I))
+      return false;
+
+    Ops.clear();
+    if (isSplatShuffle(I->getOperand(0)))
+      Ops.push_back(&I->getOperandUse(0));
+    if (isSplatShuffle(I->getOperand(1)))
+      Ops.push_back(&I->getOperandUse(1));
+
+    return !Ops.empty();
   }
   default:
     return false;
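As background for the new fallback at the end of this case (again an illustration, not code from the patch): the non-widening vector MUL only has by-element forms for 16-bit and 32-bit lanes, and the lane indexing operates within a 128-bit segment, which is why ShouldSinkSplatForIndexedVariant rejects scalable (SVE) types and element sizes other than 16 and 32. A small sketch of the instruction forms being targeted, assuming arm_neon.h (function names invented for the example):

#include <arm_neon.h>

// Indexed MUL exists for .4h/.8h and .2s/.4s arrangements, so a splatted
// multiplicand can become a lane reference instead of a dup plus a plain mul.
int16x8_t mul_i16_by_lane(int16x8_t a, int16x8_t b) {
  return vmulq_laneq_s16(a, b, 3); // e.g. mul v0.8h, v0.8h, v1.h[3]
}

int32x4_t mul_i32_by_lane(int32x4_t a, int32x4_t b) {
  return vmulq_laneq_s32(a, b, 1); // e.g. mul v0.4s, v0.4s, v1.s[1]
}

// There is no 64-bit-element MUL (indexed or otherwise): a <2 x i64> multiply
// either becomes a widening smull/umull or is scalarised, as the test below shows.
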
16 changes: 10 additions & 6 deletions llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll
@@ -10,14 +10,18 @@ target triple = "aarch64-unknown-linux-gnu"
 define dso_local i32 @dupext_crashtest(i32 %e) local_unnamed_addr {
 ; CHECK-LABEL: dupext_crashtest:
 ; CHECK:       // %bb.0: // %for.body.lr.ph
-; CHECK-NEXT:    mov w8, w0
-; CHECK-NEXT:    dup v0.2s, w8
 ; CHECK-NEXT:  .LBB0_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    ldr d1, [x8]
-; CHECK-NEXT:    smull v1.2d, v0.2s, v1.2s
-; CHECK-NEXT:    xtn v1.2s, v1.2d
-; CHECK-NEXT:    str d1, [x8]
+; CHECK-NEXT:    ldr d0, [x8]
+; CHECK-NEXT:    ushll v0.2d, v0.2s, #0
+; CHECK-NEXT:    fmov x9, d0
+; CHECK-NEXT:    mov x8, v0.d[1]
+; CHECK-NEXT:    mul w9, w0, w9
+; CHECK-NEXT:    mul w8, w0, w8
+; CHECK-NEXT:    fmov d0, x9
+; CHECK-NEXT:    mov v0.d[1], x8
+; CHECK-NEXT:    xtn v0.2s, v0.2d
+; CHECK-NEXT:    str d0, [x8]
Comment on lines -17 to +24
Contributor Author
This is a regression, but I have a patch that fixes it by teaching https://github.com/llvm/llvm-project/blob/c25c6c32494c8d1038438b6208d42ba40f25270e/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp#L18599 to handle ANY_EXTENDs, which seem to get generated via KnownBits queries when we visit the truncate nodes.

Collaborator
Sounds good. Is it possible to write a separate test for it too, with the anyext already in place?

Contributor Author
It seemed to make sense to put this into a separate, follow-up PR - see #118308.

> Is it possible to write a separate test for it too, with the anyext already in place?

I've added the test @dupzext_v2i32_v2i64_trunc in that PR, which should generate the anyext via the truncate. I'm not sure how I would do this otherwise, as unless I'm missing something there's no anyext in LLVM IR?

 ; CHECK-NEXT:    b .LBB0_1
 for.body.lr.ph:
   %conv314 = zext i32 %e to i64