Skip to content

Commit bca846d

Browse files
authored
[AArch64] Improve mull generation (#114997)
This attempts to clean up and improve where we generate smull/umull using known-bits. For v2i64 types (where no mul is present), we try to create mull more aggressively to avoid scalarization.
1 parent aeb88f6 commit bca846d

File tree

2 files changed

+51
-144
lines changed

2 files changed

+51
-144
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 21 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -5186,40 +5186,6 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
51865186
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, OpVT, Op);
51875187
}
51885188

5189-
static EVT getExtensionTo64Bits(const EVT &OrigVT) {
5190-
if (OrigVT.getSizeInBits() >= 64)
5191-
return OrigVT;
5192-
5193-
assert(OrigVT.isSimple() && "Expecting a simple value type");
5194-
5195-
MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy;
5196-
switch (OrigSimpleTy) {
5197-
default: llvm_unreachable("Unexpected Vector Type");
5198-
case MVT::v2i8:
5199-
case MVT::v2i16:
5200-
return MVT::v2i32;
5201-
case MVT::v4i8:
5202-
return MVT::v4i16;
5203-
}
5204-
}
5205-
5206-
static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG,
5207-
const EVT &OrigTy,
5208-
const EVT &ExtTy,
5209-
unsigned ExtOpcode) {
5210-
// The vector originally had a size of OrigTy. It was then extended to ExtTy.
5211-
// We expect the ExtTy to be 128-bits total. If the OrigTy is less than
5212-
// 64-bits we need to insert a new extension so that it will be 64-bits.
5213-
assert(ExtTy.is128BitVector() && "Unexpected extension size");
5214-
if (OrigTy.getSizeInBits() >= 64)
5215-
return N;
5216-
5217-
// Must extend size to at least 64 bits to be used as an operand for VMULL.
5218-
EVT NewVT = getExtensionTo64Bits(OrigTy);
5219-
5220-
return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N);
5221-
}
5222-
52235189
// Returns lane if Op extracts from a two-element vector and lane is constant
52245190
// (i.e., extractelt(<2 x Ty> %v, ConstantLane)), and std::nullopt otherwise.
52255191
static std::optional<uint64_t>
@@ -5265,31 +5231,11 @@ static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
52655231
static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
52665232
EVT VT = N.getValueType();
52675233
assert(VT.is128BitVector() && "Unexpected vector MULL size");
5268-
5269-
unsigned NumElts = VT.getVectorNumElements();
5270-
unsigned OrigEltSize = VT.getScalarSizeInBits();
5271-
unsigned EltSize = OrigEltSize / 2;
5272-
MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
5273-
5274-
APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
5275-
if (DAG.MaskedValueIsZero(N, HiBits))
5276-
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
5277-
5278-
if (ISD::isExtOpcode(N.getOpcode()))
5279-
return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
5280-
N.getOperand(0).getValueType(), VT,
5281-
N.getOpcode());
5282-
5283-
assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
5284-
SDLoc dl(N);
5285-
SmallVector<SDValue, 8> Ops;
5286-
for (unsigned i = 0; i != NumElts; ++i) {
5287-
const APInt &CInt = N.getConstantOperandAPInt(i);
5288-
// Element types smaller than 32 bits are not legal, so use i32 elements.
5289-
// The values are implicitly truncated so sext vs. zext doesn't matter.
5290-
Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
5291-
}
5292-
return DAG.getBuildVector(TruncVT, dl, Ops);
5234+
EVT HalfVT = EVT::getVectorVT(
5235+
*DAG.getContext(),
5236+
VT.getScalarType().getHalfSizedIntegerVT(*DAG.getContext()),
5237+
VT.getVectorElementCount());
5238+
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), HalfVT, N);
52935239
}
52945240

52955241
static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
@@ -5465,33 +5411,26 @@ static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
54655411
if (IsN0ZExt && IsN1ZExt)
54665412
return AArch64ISD::UMULL;
54675413

5468-
// Select SMULL if we can replace zext with sext.
5469-
if (((IsN0SExt && IsN1ZExt) || (IsN0ZExt && IsN1SExt)) &&
5470-
!isExtendedBUILD_VECTOR(N0, DAG, false) &&
5471-
!isExtendedBUILD_VECTOR(N1, DAG, false)) {
5472-
SDValue ZextOperand;
5473-
if (IsN0ZExt)
5474-
ZextOperand = N0.getOperand(0);
5475-
else
5476-
ZextOperand = N1.getOperand(0);
5477-
if (DAG.SignBitIsZero(ZextOperand)) {
5478-
SDValue NewSext =
5479-
DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType());
5480-
if (IsN0ZExt)
5481-
N0 = NewSext;
5482-
else
5483-
N1 = NewSext;
5484-
return AArch64ISD::SMULL;
5485-
}
5486-
}
5487-
54885414
// Select UMULL if we can replace the other operand with an extend.
5415+
EVT VT = N0.getValueType();
5416+
unsigned EltSize = VT.getScalarSizeInBits();
5417+
APInt Mask = APInt::getHighBitsSet(EltSize, EltSize / 2);
54895418
if (IsN0ZExt || IsN1ZExt) {
5490-
EVT VT = N0.getValueType();
5491-
APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
5492-
VT.getScalarSizeInBits() / 2);
54935419
if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
54945420
return AArch64ISD::UMULL;
5421+
} else if (VT == MVT::v2i64 && DAG.MaskedValueIsZero(N0, Mask) &&
5422+
DAG.MaskedValueIsZero(N1, Mask)) {
5423+
// For v2i64 we look more aggresively at both operands being zero, to avoid
5424+
// scalarization.
5425+
return AArch64ISD::UMULL;
5426+
}
5427+
5428+
if (IsN0SExt || IsN1SExt) {
5429+
if (DAG.ComputeNumSignBits(IsN0SExt ? N1 : N0) > EltSize / 2)
5430+
return AArch64ISD::SMULL;
5431+
} else if (VT == MVT::v2i64 && DAG.ComputeNumSignBits(N0) > EltSize / 2 &&
5432+
DAG.ComputeNumSignBits(N1) > EltSize / 2) {
5433+
return AArch64ISD::SMULL;
54955434
}
54965435

54975436
if (!IsN1SExt && !IsN1ZExt)

llvm/test/CodeGen/AArch64/aarch64-smull.ll

Lines changed: 30 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -231,29 +231,24 @@ define <4 x i32> @smull_zext_v4i16_v4i32(ptr %A, ptr %B) nounwind {
231231
define <2 x i64> @smull_zext_v2i32_v2i64(ptr %A, ptr %B) nounwind {
232232
; CHECK-NEON-LABEL: smull_zext_v2i32_v2i64:
233233
; CHECK-NEON: // %bb.0:
234-
; CHECK-NEON-NEXT: ldr d0, [x1]
235-
; CHECK-NEON-NEXT: ldrh w9, [x0]
236-
; CHECK-NEON-NEXT: ldrh w10, [x0, #2]
237-
; CHECK-NEON-NEXT: sshll v0.2d, v0.2s, #0
238-
; CHECK-NEON-NEXT: fmov x11, d0
239-
; CHECK-NEON-NEXT: mov x8, v0.d[1]
240-
; CHECK-NEON-NEXT: smull x9, w9, w11
241-
; CHECK-NEON-NEXT: smull x8, w10, w8
242-
; CHECK-NEON-NEXT: fmov d0, x9
243-
; CHECK-NEON-NEXT: mov v0.d[1], x8
234+
; CHECK-NEON-NEXT: ldrh w8, [x0]
235+
; CHECK-NEON-NEXT: ldrh w9, [x0, #2]
236+
; CHECK-NEON-NEXT: ldr d1, [x1]
237+
; CHECK-NEON-NEXT: fmov d0, x8
238+
; CHECK-NEON-NEXT: mov v0.d[1], x9
239+
; CHECK-NEON-NEXT: xtn v0.2s, v0.2d
240+
; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
244241
; CHECK-NEON-NEXT: ret
245242
;
246243
; CHECK-SVE-LABEL: smull_zext_v2i32_v2i64:
247244
; CHECK-SVE: // %bb.0:
248245
; CHECK-SVE-NEXT: ldrh w8, [x0]
249246
; CHECK-SVE-NEXT: ldrh w9, [x0, #2]
250-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
251-
; CHECK-SVE-NEXT: ldr d0, [x1]
252-
; CHECK-SVE-NEXT: fmov d1, x8
253-
; CHECK-SVE-NEXT: sshll v0.2d, v0.2s, #0
254-
; CHECK-SVE-NEXT: mov v1.d[1], x9
255-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
256-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
247+
; CHECK-SVE-NEXT: ldr d1, [x1]
248+
; CHECK-SVE-NEXT: fmov d0, x8
249+
; CHECK-SVE-NEXT: mov v0.d[1], x9
250+
; CHECK-SVE-NEXT: xtn v0.2s, v0.2d
251+
; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
257252
; CHECK-SVE-NEXT: ret
258253
;
259254
; CHECK-GI-LABEL: smull_zext_v2i32_v2i64:
@@ -2404,25 +2399,16 @@ define <2 x i32> @do_stuff(<2 x i64> %0, <2 x i64> %1) {
24042399
define <2 x i64> @lsr(<2 x i64> %a, <2 x i64> %b) {
24052400
; CHECK-NEON-LABEL: lsr:
24062401
; CHECK-NEON: // %bb.0:
2407-
; CHECK-NEON-NEXT: ushr v0.2d, v0.2d, #32
2408-
; CHECK-NEON-NEXT: ushr v1.2d, v1.2d, #32
2409-
; CHECK-NEON-NEXT: fmov x10, d1
2410-
; CHECK-NEON-NEXT: fmov x11, d0
2411-
; CHECK-NEON-NEXT: mov x8, v1.d[1]
2412-
; CHECK-NEON-NEXT: mov x9, v0.d[1]
2413-
; CHECK-NEON-NEXT: umull x10, w11, w10
2414-
; CHECK-NEON-NEXT: umull x8, w9, w8
2415-
; CHECK-NEON-NEXT: fmov d0, x10
2416-
; CHECK-NEON-NEXT: mov v0.d[1], x8
2402+
; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
2403+
; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32
2404+
; CHECK-NEON-NEXT: umull v0.2d, v0.2s, v1.2s
24172405
; CHECK-NEON-NEXT: ret
24182406
;
24192407
; CHECK-SVE-LABEL: lsr:
24202408
; CHECK-SVE: // %bb.0:
2421-
; CHECK-SVE-NEXT: ushr v0.2d, v0.2d, #32
2422-
; CHECK-SVE-NEXT: ushr v1.2d, v1.2d, #32
2423-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
2424-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
2425-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
2409+
; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
2410+
; CHECK-SVE-NEXT: shrn v1.2s, v1.2d, #32
2411+
; CHECK-SVE-NEXT: umull v0.2d, v0.2s, v1.2s
24262412
; CHECK-SVE-NEXT: ret
24272413
;
24282414
; CHECK-GI-LABEL: lsr:
@@ -2481,25 +2467,16 @@ define <2 x i64> @lsr_const(<2 x i64> %a, <2 x i64> %b) {
24812467
define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
24822468
; CHECK-NEON-LABEL: asr:
24832469
; CHECK-NEON: // %bb.0:
2484-
; CHECK-NEON-NEXT: sshr v0.2d, v0.2d, #32
2485-
; CHECK-NEON-NEXT: sshr v1.2d, v1.2d, #32
2486-
; CHECK-NEON-NEXT: fmov x10, d1
2487-
; CHECK-NEON-NEXT: fmov x11, d0
2488-
; CHECK-NEON-NEXT: mov x8, v1.d[1]
2489-
; CHECK-NEON-NEXT: mov x9, v0.d[1]
2490-
; CHECK-NEON-NEXT: smull x10, w11, w10
2491-
; CHECK-NEON-NEXT: smull x8, w9, w8
2492-
; CHECK-NEON-NEXT: fmov d0, x10
2493-
; CHECK-NEON-NEXT: mov v0.d[1], x8
2470+
; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
2471+
; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32
2472+
; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
24942473
; CHECK-NEON-NEXT: ret
24952474
;
24962475
; CHECK-SVE-LABEL: asr:
24972476
; CHECK-SVE: // %bb.0:
2498-
; CHECK-SVE-NEXT: sshr v0.2d, v0.2d, #32
2499-
; CHECK-SVE-NEXT: sshr v1.2d, v1.2d, #32
2500-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
2501-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
2502-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
2477+
; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
2478+
; CHECK-SVE-NEXT: shrn v1.2s, v1.2d, #32
2479+
; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
25032480
; CHECK-SVE-NEXT: ret
25042481
;
25052482
; CHECK-GI-LABEL: asr:
@@ -2524,25 +2501,16 @@ define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
25242501
define <2 x i64> @asr_const(<2 x i64> %a, <2 x i64> %b) {
25252502
; CHECK-NEON-LABEL: asr_const:
25262503
; CHECK-NEON: // %bb.0:
2527-
; CHECK-NEON-NEXT: sshr v0.2d, v0.2d, #32
2528-
; CHECK-NEON-NEXT: fmov x9, d0
2529-
; CHECK-NEON-NEXT: mov x8, v0.d[1]
2530-
; CHECK-NEON-NEXT: lsl x10, x9, #5
2531-
; CHECK-NEON-NEXT: lsl x11, x8, #5
2532-
; CHECK-NEON-NEXT: sub x9, x10, x9
2533-
; CHECK-NEON-NEXT: fmov d0, x9
2534-
; CHECK-NEON-NEXT: sub x8, x11, x8
2535-
; CHECK-NEON-NEXT: mov v0.d[1], x8
2504+
; CHECK-NEON-NEXT: movi v1.2s, #31
2505+
; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
2506+
; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
25362507
; CHECK-NEON-NEXT: ret
25372508
;
25382509
; CHECK-SVE-LABEL: asr_const:
25392510
; CHECK-SVE: // %bb.0:
2540-
; CHECK-SVE-NEXT: mov w8, #31 // =0x1f
2541-
; CHECK-SVE-NEXT: sshr v0.2d, v0.2d, #32
2542-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
2543-
; CHECK-SVE-NEXT: dup v1.2d, x8
2544-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
2545-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
2511+
; CHECK-SVE-NEXT: movi v1.2s, #31
2512+
; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
2513+
; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
25462514
; CHECK-SVE-NEXT: ret
25472515
;
25482516
; CHECK-GI-LABEL: asr_const:

0 commit comments

Comments
 (0)