Skip to content

[DAG][AArch64] Handle vscale addressing modes in reassociationCanBreakAddressingModePattern #89908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,44 @@ bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
// (load/store (add, (add, x, y), offset2)) ->
// (load/store (add, (add, x, offset2), y)).

if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
if (N0.getOpcode() != ISD::ADD)
return false;

// Check for vscale addressing modes.
// (load/store (add/sub (add x, y), vscale))
// (load/store (add/sub (add x, y), (lsl vscale, C)))
// (load/store (add/sub (add x, y), (mul vscale, C)))
if ((N1.getOpcode() == ISD::VSCALE ||
((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::MUL) &&
N1.getOperand(0).getOpcode() == ISD::VSCALE &&
isa<ConstantSDNode>(N1.getOperand(1)))) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use getScalarValueSizeInBits here instead to avoid the implicit TypeSize->uint64_t conversion?

N1.getValueType().getFixedSizeInBits() <= 64) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of things I've noticed here:

  1. The AddrMode struct's ScalableOffset field is actually int64_t and the offset can be negative, e.g. add (add x, y), vscale(-3), so I think we need to use int64_t here.
  2. This code potentially accepts ISD::VSCALE nodes with result types of 8 or 16 bits, where there is a chance of overflow for large values of vscale. Is it worth just restricting this to integer types that are suitable for pointer arithmetic?

int64_t ScalableOffset =
N1.getOpcode() == ISD::VSCALE
? N1.getConstantOperandVal(0)
: (N1.getOperand(0).getConstantOperandVal(0) *
(N1.getOpcode() == ISD::SHL ? (1 << N1.getConstantOperandVal(1))
: N1.getConstantOperandVal(1)));
if (Opc == ISD::SUB)
ScalableOffset = -ScalableOffset;
if (all_of(N->uses(), [&](SDNode *Node) {
if (auto *LoadStore = dyn_cast<MemSDNode>(Node);
LoadStore && LoadStore->getBasePtr().getNode() == N) {
TargetLoweringBase::AddrMode AM;
AM.HasBaseReg = true;
AM.ScalableOffset = ScalableOffset;
EVT VT = LoadStore->getMemoryVT();
unsigned AS = LoadStore->getAddressSpace();
Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy,
AS);
}
return false;
}))
return true;
}

if (Opc != ISD::ADD)
return false;

auto *C2 = dyn_cast<ConstantSDNode>(N1);
Expand Down Expand Up @@ -3911,7 +3948,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {

// Hoist one-use addition by non-opaque constant:
// (x + C) - y -> (x - y) + C
if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
if (!reassociationCanBreakAddressingModePattern(ISD::SUB, DL, N, N0, N1) &&
N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
Expand Down
54 changes: 18 additions & 36 deletions llvm/test/CodeGen/AArch64/sve-reassocadd.ll
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@ entry:
define <vscale x 16 x i8> @i8_4s_1v(ptr %b) {
; CHECK-LABEL: i8_4s_1v:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: mov w9, #4 // =0x4
; CHECK-NEXT: add x8, x0, x8
; CHECK-NEXT: ld1b { z0.b }, p0/z, [x8, x9]
; CHECK-NEXT: add x8, x0, #4
; CHECK-NEXT: ld1b { z0.b }, p0/z, [x8, #1, mul vl]
; CHECK-NEXT: ret
entry:
%add.ptr = getelementptr inbounds i8, ptr %b, i64 4
Expand Down Expand Up @@ -58,11 +56,9 @@ entry:
define <vscale x 8 x i16> @i16_8s_1v(ptr %b) {
; CHECK-LABEL: i16_8s_1v:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: mov x9, #4 // =0x4
; CHECK-NEXT: add x8, x0, x8
; CHECK-NEXT: ld1h { z0.h }, p0/z, [x8, x9, lsl #1]
; CHECK-NEXT: add x8, x0, #8
; CHECK-NEXT: ld1h { z0.h }, p0/z, [x8, #1, mul vl]
; CHECK-NEXT: ret
entry:
%add.ptr = getelementptr inbounds i8, ptr %b, i64 8
Expand Down Expand Up @@ -94,11 +90,9 @@ entry:
define <vscale x 8 x i16> @i16_8s_2v(ptr %b) {
; CHECK-LABEL: i16_8s_2v:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: rdvl x8, #2
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: mov x9, #4 // =0x4
; CHECK-NEXT: add x8, x0, x8
; CHECK-NEXT: ld1h { z0.h }, p0/z, [x8, x9, lsl #1]
; CHECK-NEXT: add x8, x0, #8
; CHECK-NEXT: ld1h { z0.h }, p0/z, [x8, #2, mul vl]
; CHECK-NEXT: ret
entry:
%add.ptr = getelementptr inbounds i8, ptr %b, i64 8
Expand Down Expand Up @@ -130,11 +124,9 @@ entry:
define <vscale x 4 x i32> @i32_16s_2v(ptr %b) {
; CHECK-LABEL: i32_16s_2v:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: mov x9, #4 // =0x4
; CHECK-NEXT: add x8, x0, x8
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x8, x9, lsl #2]
; CHECK-NEXT: add x8, x0, #16
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x8, #1, mul vl]
; CHECK-NEXT: ret
entry:
%add.ptr = getelementptr inbounds i8, ptr %b, i64 16
Expand Down Expand Up @@ -166,11 +158,9 @@ entry:
define <vscale x 2 x i64> @i64_32s_2v(ptr %b) {
; CHECK-LABEL: i64_32s_2v:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: mov x9, #4 // =0x4
; CHECK-NEXT: add x8, x0, x8
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x8, x9, lsl #3]
; CHECK-NEXT: add x8, x0, #32
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x8, #1, mul vl]
; CHECK-NEXT: ret
entry:
%add.ptr = getelementptr inbounds i8, ptr %b, i64 32
Expand Down Expand Up @@ -203,11 +193,9 @@ entry:
define <vscale x 16 x i8> @i8_4s_m2v(ptr %b) {
; CHECK-LABEL: i8_4s_m2v:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cnth x8, all, mul #4
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: mov w9, #4 // =0x4
; CHECK-NEXT: sub x8, x0, x8
; CHECK-NEXT: ld1b { z0.b }, p0/z, [x8, x9]
; CHECK-NEXT: add x8, x0, #4
; CHECK-NEXT: ld1b { z0.b }, p0/z, [x8, #-2, mul vl]
; CHECK-NEXT: ret
entry:
%add.ptr = getelementptr inbounds i8, ptr %b, i64 4
Expand Down Expand Up @@ -239,11 +227,9 @@ entry:
define <vscale x 8 x i16> @i16_8s_m2v(ptr %b) {
; CHECK-LABEL: i16_8s_m2v:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cnth x8, all, mul #4
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: mov x9, #4 // =0x4
; CHECK-NEXT: sub x8, x0, x8
; CHECK-NEXT: ld1h { z0.h }, p0/z, [x8, x9, lsl #1]
; CHECK-NEXT: add x8, x0, #8
; CHECK-NEXT: ld1h { z0.h }, p0/z, [x8, #-2, mul vl]
; CHECK-NEXT: ret
entry:
%add.ptr = getelementptr inbounds i8, ptr %b, i64 8
Expand Down Expand Up @@ -275,11 +261,9 @@ entry:
define <vscale x 4 x i32> @i32_16s_m2v(ptr %b) {
; CHECK-LABEL: i32_16s_m2v:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cnth x8, all, mul #4
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: mov x9, #4 // =0x4
; CHECK-NEXT: sub x8, x0, x8
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x8, x9, lsl #2]
; CHECK-NEXT: add x8, x0, #16
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x8, #-2, mul vl]
; CHECK-NEXT: ret
entry:
%add.ptr = getelementptr inbounds i8, ptr %b, i64 16
Expand Down Expand Up @@ -311,11 +295,9 @@ entry:
define <vscale x 2 x i64> @i64_32s_m2v(ptr %b) {
; CHECK-LABEL: i64_32s_m2v:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cnth x8, all, mul #4
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: mov x9, #4 // =0x4
; CHECK-NEXT: sub x8, x0, x8
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x8, x9, lsl #3]
; CHECK-NEXT: add x8, x0, #32
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x8, #-2, mul vl]
; CHECK-NEXT: ret
entry:
%add.ptr = getelementptr inbounds i8, ptr %b, i64 32
Expand Down