Skip to content

Commit 7a0b897

Browse files
[DAGCombiner][SVE] Ensure MGATHER/MSCATTER addressing mode combines preserve index scaling
refineUniformBase and selectGatherScatterAddrMode both attempt the transformation: base(0) + index(A+splat(B)) => base(B) + index(A). However, this is only safe when the index is not implicitly scaled. Differential Revision: https://reviews.llvm.org/D123222
1 parent cacaa44 commit 7a0b897

File tree

3 files changed

+27
-22
lines changed

3 files changed

+27
-22
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10426,14 +10426,19 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
1042610426
TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
1042710427
}
1042810428

10429-
bool refineUniformBase(SDValue &BasePtr, SDValue &Index, SelectionDAG &DAG) {
10429+
bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
10430+
SelectionDAG &DAG) {
1043010431
if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
1043110432
return false;
1043210433

10434+
// Only perform the transformation when existing operands can be reused.
10435+
if (IndexIsScaled)
10436+
return false;
10437+
1043310438
// For now we check only the LHS of the add.
1043410439
SDValue LHS = Index.getOperand(0);
1043510440
SDValue SplatVal = DAG.getSplatValue(LHS);
10436-
if (!SplatVal)
10441+
if (!SplatVal || SplatVal.getValueType() != BasePtr.getValueType())
1043710442
return false;
1043810443

1043910444
BasePtr = SplatVal;
@@ -10481,7 +10486,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
1048110486
if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
1048210487
return Chain;
1048310488

10484-
if (refineUniformBase(BasePtr, Index, DAG)) {
10489+
if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
1048510490
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
1048610491
return DAG.getMaskedScatter(
1048710492
DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
@@ -10576,7 +10581,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
1057610581
if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
1057710582
return CombineTo(N, PassThru, MGT->getChain());
1057810583

10579-
if (refineUniformBase(BasePtr, Index, DAG)) {
10584+
if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
1058010585
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
1058110586
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
1058210587
MGT->getMemoryVT(), DL, Ops,

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4656,10 +4656,10 @@ bool getGatherScatterIndexIsExtended(SDValue Index) {
46564656
// VECTOR + IMMEDIATE:
46574657
// getelementptr nullptr, <vscale x N x T> (splat(#x)) + %indices)
46584658
// -> getelementptr #x, <vscale x N x T> %indices
4659-
void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT,
4660-
unsigned &Opcode, bool IsGather,
4661-
SelectionDAG &DAG) {
4662-
if (!isNullConstant(BasePtr))
4659+
void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index,
4660+
bool IsScaled, EVT MemVT, unsigned &Opcode,
4661+
bool IsGather, SelectionDAG &DAG) {
4662+
if (!isNullConstant(BasePtr) || IsScaled)
46634663
return;
46644664

46654665
// FIXME: This will not match for fixed vector type codegen as the nodes in
@@ -4789,7 +4789,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
47894789
Index = Index.getOperand(0);
47904790

47914791
unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
4792-
selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
4792+
selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
47934793
/*isGather=*/true, DAG);
47944794

47954795
if (ExtType == ISD::SEXTLOAD)
@@ -4898,7 +4898,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
48984898
Index = Index.getOperand(0);
48994899

49004900
unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend);
4901-
selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
4901+
selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
49024902
/*isGather=*/false, DAG);
49034903

49044904
if (IsFixedLength) {

llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -343,12 +343,13 @@ define <vscale x 2 x i64> @masked_gather_nxv2i64_const_with_vec_offsets(<vscale
343343
ret <vscale x 2 x i64> %data
344344
}
345345

346-
; TODO: The generated code is wrong because we've lost the scaling applied to
347-
; %scalar_offset when it's used to calculate %ptrs.
348346
define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg) #0 {
349347
; CHECK-LABEL: masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets:
350348
; CHECK: // %bb.0:
351-
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x0, z0.d, lsl #3]
349+
; CHECK-NEXT: mov x8, xzr
350+
; CHECK-NEXT: mov z1.d, x0
351+
; CHECK-NEXT: add z0.d, z0.d, z1.d
352+
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
352353
; CHECK-NEXT: ret
353354
%scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
354355
%scalar_offset.splat = shufflevector <vscale x 2 x i64> %scalar_offset.ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
@@ -358,12 +359,11 @@ define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with_vec_plus_scalar_offse
358359
ret <vscale x 2 x i64> %data
359360
}
360361

361-
; TODO: The generated code is wrong because we've lost the scaling applied to
362-
; constant scalar offset (i.e. i64 1) when it's used to calculate %ptrs.
363362
define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg) #0 {
364363
; CHECK-LABEL: masked_gather_nxv2i64_null_with__vec_plus_imm_offsets:
365364
; CHECK: // %bb.0:
366-
; CHECK-NEXT: mov w8, #1
365+
; CHECK-NEXT: mov x8, xzr
366+
; CHECK-NEXT: add z0.d, z0.d, #1 // =0x1
367367
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
368368
; CHECK-NEXT: ret
369369
%scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0
@@ -425,12 +425,13 @@ define void @masked_scatter_nxv2i64_const_with_vec_offsets(<vscale x 2 x i64> %v
425425
ret void
426426
}
427427

428-
; TODO: The generated code is wrong because we've lost the scaling applied to
429-
; %scalar_offset when it's used to calculate %ptrs.
430428
define void @masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
431429
; CHECK-LABEL: masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets:
432430
; CHECK: // %bb.0:
433-
; CHECK-NEXT: st1d { z1.d }, p0, [x0, z0.d, lsl #3]
431+
; CHECK-NEXT: mov x8, xzr
432+
; CHECK-NEXT: mov z2.d, x0
433+
; CHECK-NEXT: add z0.d, z0.d, z2.d
434+
; CHECK-NEXT: st1d { z1.d }, p0, [x8, z0.d, lsl #3]
434435
; CHECK-NEXT: ret
435436
%scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
436437
%scalar_offset.splat = shufflevector <vscale x 2 x i64> %scalar_offset.ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
@@ -440,12 +441,11 @@ define void @masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x
440441
ret void
441442
}
442443

443-
; TODO: The generated code is wrong because we've lost the scaling applied to
444-
; constant scalar offset (i.e. i64 1) when it's used to calculate %ptrs.
445444
define void @masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
446445
; CHECK-LABEL: masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets:
447446
; CHECK: // %bb.0:
448-
; CHECK-NEXT: mov w8, #1
447+
; CHECK-NEXT: mov x8, xzr
448+
; CHECK-NEXT: add z0.d, z0.d, #1 // =0x1
449449
; CHECK-NEXT: st1d { z1.d }, p0, [x8, z0.d, lsl #3]
450450
; CHECK-NEXT: ret
451451
%scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0

0 commit comments

Comments
 (0)