Skip to content

Commit 55e6ba4

Browse files
sdesmalen-armwwwatermiao
authored andcommitted
[SME] Stop RA from coalescing COPY instructions that transcend beyond smstart/smstop. (#78294)
This patch introduces a 'COALESCER_BARRIER' which is a pseudo node that expands to a 'nop', but which stops the register allocator from coalescing a COPY node when its use/def crosses a SMSTART or SMSTOP instruction. For example: %0:fpr64 = COPY killed $d0 undef %2.dsub:zpr = COPY %0 // <- Do not coalesce this COPY ADJCALLSTACKDOWN 0, 0 MSRpstatesvcrImm1 1, 0, csr_aarch64_smstartstop, implicit-def dead $d0 $d0 = COPY killed %0 BL @use_f64, csr_aarch64_aapcs If the COPY would be coalesced, that would lead to: $d0 = COPY killed %0 being replaced by: $d0 = COPY killed %2.dsub which means the whole ZPR reg would be live upto the call, causing the MSRpstatesvcrImm1 (smstop) to spill/reload the ZPR register: str q0, [sp] // 16-byte Folded Spill smstop sm ldr z0, [sp] // 16-byte Folded Reload bl use_f64 which would be incorrect for two reasons: 1. The program may load more data than it has allocated. 2. If there are other SVE objects on the stack, the compiler might use the 'mul vl' addressing modes to access the spill location. By disabling the coalescing, we get the desired results: str d0, [sp, #8] // 8-byte Folded Spill smstop sm ldr d0, [sp, #8] // 8-byte Folded Reload bl use_f64 Signed-off-by: chenmiao <[email protected]> Signed-off-by: chenmiao <[email protected]>
1 parent 7dc4bad commit 55e6ba4

11 files changed

+1777
-36
lines changed

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,12 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
14831483
NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
14841484
return true;
14851485
}
1486+
case AArch64::COALESCER_BARRIER_FPR16:
1487+
case AArch64::COALESCER_BARRIER_FPR32:
1488+
case AArch64::COALESCER_BARRIER_FPR64:
1489+
case AArch64::COALESCER_BARRIER_FPR128:
1490+
MI.eraseFromParent();
1491+
return true;
14861492
}
14871493
return false;
14881494
}

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2277,6 +2277,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
22772277
switch ((AArch64ISD::NodeType)Opcode) {
22782278
case AArch64ISD::FIRST_NUMBER:
22792279
break;
2280+
MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
22802281
MAKE_CASE(AArch64ISD::SMSTART)
22812282
MAKE_CASE(AArch64ISD::SMSTOP)
22822283
MAKE_CASE(AArch64ISD::RESTORE_ZA)
@@ -6868,13 +6869,18 @@ void AArch64TargetLowering::saveVarArgRegisters(CCState &CCInfo,
68686869
}
68696870
}
68706871

6872+
static bool isPassedInFPR(EVT VT) {
6873+
return VT.isFixedLengthVector() ||
6874+
(VT.isFloatingPoint() && !VT.isScalableVector());
6875+
}
6876+
68716877
/// LowerCallResult - Lower the result values of a call into the
68726878
/// appropriate copies out of appropriate physical registers.
68736879
SDValue AArch64TargetLowering::LowerCallResult(
68746880
SDValue Chain, SDValue InGlue, CallingConv::ID CallConv, bool isVarArg,
68756881
const SmallVectorImpl<CCValAssign> &RVLocs, const SDLoc &DL,
68766882
SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
6877-
SDValue ThisVal) const {
6883+
SDValue ThisVal, bool RequiresSMChange) const {
68786884
DenseMap<unsigned, SDValue> CopiedRegs;
68796885
// Copy all of the result registers out of their specified physreg.
68806886
for (unsigned i = 0; i != RVLocs.size(); ++i) {
@@ -6919,6 +6925,10 @@ SDValue AArch64TargetLowering::LowerCallResult(
69196925
break;
69206926
}
69216927

6928+
if (RequiresSMChange && isPassedInFPR(VA.getValVT()))
6929+
Val = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL, Val.getValueType(),
6930+
Val);
6931+
69226932
InVals.push_back(Val);
69236933
}
69246934

@@ -7596,6 +7606,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
75967606
return ArgReg.Reg == VA.getLocReg();
75977607
});
75987608
} else {
7609+
// Add an extra level of indirection for streaming mode changes by
7610+
// using a pseudo copy node that cannot be rematerialised between a
7611+
// smstart/smstop and the call by the simple register coalescer.
7612+
if (RequiresSMChange && isPassedInFPR(Arg.getValueType()))
7613+
Arg = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL,
7614+
Arg.getValueType(), Arg);
75997615
RegsToPass.emplace_back(VA.getLocReg(), Arg);
76007616
RegsUsed.insert(VA.getLocReg());
76017617
const TargetOptions &Options = DAG.getTarget().Options;
@@ -7829,9 +7845,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
78297845

78307846
// Handle result values, copying them out of physregs into vregs that we
78317847
// return.
7832-
SDValue Result = LowerCallResult(Chain, InGlue, CallConv, IsVarArg, RVLocs,
7833-
DL, DAG, InVals, IsThisReturn,
7834-
IsThisReturn ? OutVals[0] : SDValue());
7848+
SDValue Result = LowerCallResult(
7849+
Chain, InGlue, CallConv, IsVarArg, RVLocs, DL, DAG, InVals, IsThisReturn,
7850+
IsThisReturn ? OutVals[0] : SDValue(), RequiresSMChange);
78357851

78367852
if (!Ins.empty())
78377853
InGlue = Result.getValue(Result->getNumValues() - 1);

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ enum NodeType : unsigned {
5858

5959
CALL_BTI, // Function call followed by a BTI instruction.
6060

61+
COALESCER_BARRIER,
62+
6163
SMSTART,
6264
SMSTOP,
6365
RESTORE_ZA,
@@ -971,7 +973,7 @@ class AArch64TargetLowering : public TargetLowering {
971973
const SmallVectorImpl<CCValAssign> &RVLocs,
972974
const SDLoc &DL, SelectionDAG &DAG,
973975
SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
974-
SDValue ThisVal) const;
976+
SDValue ThisVal, bool RequiresSMChange) const;
975977

976978
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
977979
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,8 @@ bool AArch64RegisterInfo::shouldCoalesce(
987987
MachineInstr *MI, const TargetRegisterClass *SrcRC, unsigned SubReg,
988988
const TargetRegisterClass *DstRC, unsigned DstSubReg,
989989
const TargetRegisterClass *NewRC, LiveIntervals &LIS) const {
990+
MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
991+
990992
if (MI->isCopy() &&
991993
((DstRC->getID() == AArch64::GPR64RegClassID) ||
992994
(DstRC->getID() == AArch64::GPR64commonRegClassID)) &&
@@ -995,5 +997,38 @@ bool AArch64RegisterInfo::shouldCoalesce(
995997
// which implements a 32 to 64 bit zero extension
996998
// which relies on the upper 32 bits being zeroed.
997999
return false;
1000+
1001+
auto IsCoalescerBarrier = [](const MachineInstr &MI) {
1002+
switch (MI.getOpcode()) {
1003+
case AArch64::COALESCER_BARRIER_FPR16:
1004+
case AArch64::COALESCER_BARRIER_FPR32:
1005+
case AArch64::COALESCER_BARRIER_FPR64:
1006+
case AArch64::COALESCER_BARRIER_FPR128:
1007+
return true;
1008+
default:
1009+
return false;
1010+
}
1011+
};
1012+
1013+
// For calls that temporarily have to toggle streaming mode as part of the
1014+
// call-sequence, we need to be more careful when coalescing copy instructions
1015+
// so that we don't end up coalescing the NEON/FP result or argument register
1016+
// with a whole Z-register, such that after coalescing the register allocator
1017+
// will try to spill/reload the entire Z register.
1018+
//
1019+
// We do this by checking if the node has any defs/uses that are
1020+
// COALESCER_BARRIER pseudos. These are 'nops' in practice, but they exist to
1021+
// instruct the coalescer to avoid coalescing the copy.
1022+
if (MI->isCopy() && SubReg != DstSubReg &&
1023+
(AArch64::ZPRRegClass.hasSubClassEq(DstRC) ||
1024+
AArch64::ZPRRegClass.hasSubClassEq(SrcRC))) {
1025+
unsigned SrcReg = MI->getOperand(1).getReg();
1026+
if (any_of(MRI.def_instructions(SrcReg), IsCoalescerBarrier))
1027+
return false;
1028+
unsigned DstReg = MI->getOperand(0).getReg();
1029+
if (any_of(MRI.use_nodbg_instructions(DstReg), IsCoalescerBarrier))
1030+
return false;
1031+
}
1032+
9981033
return true;
9991034
}

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
2222
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
2323
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
2424
SDNPOptInGlue]>;
25+
def AArch64CoalescerBarrier
26+
: SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, []>;
2527

2628
//===----------------------------------------------------------------------===//
2729
// Instruction naming conventions.
@@ -183,6 +185,26 @@ def : Pat<(int_aarch64_sme_set_tpidr2 i64:$val),
183185
(MSR 0xde85, GPR64:$val)>;
184186
def : Pat<(i64 (int_aarch64_sme_get_tpidr2)),
185187
(MRS 0xde85)>;
188+
189+
multiclass CoalescerBarrierPseudo<RegisterClass rc, list<ValueType> vts> {
190+
def NAME : Pseudo<(outs rc:$dst), (ins rc:$src), []>, Sched<[]> {
191+
let Constraints = "$dst = $src";
192+
}
193+
foreach vt = vts in {
194+
def : Pat<(vt (AArch64CoalescerBarrier (vt rc:$src))),
195+
(!cast<Instruction>(NAME) rc:$src)>;
196+
}
197+
}
198+
199+
multiclass CoalescerBarriers {
200+
defm _FPR16 : CoalescerBarrierPseudo<FPR16, [bf16, f16]>;
201+
defm _FPR32 : CoalescerBarrierPseudo<FPR32, [f32]>;
202+
defm _FPR64 : CoalescerBarrierPseudo<FPR64, [f64, v8i8, v4i16, v2i32, v1i64, v4f16, v2f32, v1f64, v4bf16]>;
203+
defm _FPR128 : CoalescerBarrierPseudo<FPR128, [f128, v16i8, v8i16, v4i32, v2i64, v8f16, v4f32, v2f64, v8bf16]>;
204+
}
205+
206+
defm COALESCER_BARRIER : CoalescerBarriers;
207+
186208
} // End let Predicates = [HasSME]
187209

188210
// Pseudo to match to smstart/smstop. This expands:

llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ define double @nonstreaming_caller_streaming_callee(double %x) nounwind noinline
2323
; CHECK-FISEL-NEXT: bl streaming_callee
2424
; CHECK-FISEL-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
2525
; CHECK-FISEL-NEXT: smstop sm
26+
; CHECK-FISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
2627
; CHECK-FISEL-NEXT: adrp x8, .LCPI0_0
2728
; CHECK-FISEL-NEXT: ldr d0, [x8, :lo12:.LCPI0_0]
28-
; CHECK-FISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
2929
; CHECK-FISEL-NEXT: fadd d0, d1, d0
3030
; CHECK-FISEL-NEXT: ldr x30, [sp, #80] // 8-byte Folded Reload
3131
; CHECK-FISEL-NEXT: ldp d9, d8, [sp, #64] // 16-byte Folded Reload
@@ -49,9 +49,9 @@ define double @nonstreaming_caller_streaming_callee(double %x) nounwind noinline
4949
; CHECK-GISEL-NEXT: bl streaming_callee
5050
; CHECK-GISEL-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
5151
; CHECK-GISEL-NEXT: smstop sm
52-
; CHECK-GISEL-NEXT: mov x8, #4631107791820423168
53-
; CHECK-GISEL-NEXT: fmov d0, x8
5452
; CHECK-GISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
53+
; CHECK-GISEL-NEXT: mov x8, #4631107791820423168 // =0x4045000000000000
54+
; CHECK-GISEL-NEXT: fmov d0, x8
5555
; CHECK-GISEL-NEXT: fadd d0, d1, d0
5656
; CHECK-GISEL-NEXT: ldr x30, [sp, #80] // 8-byte Folded Reload
5757
; CHECK-GISEL-NEXT: ldp d9, d8, [sp, #64] // 16-byte Folded Reload
@@ -82,9 +82,9 @@ define double @streaming_caller_nonstreaming_callee(double %x) nounwind noinline
8282
; CHECK-COMMON-NEXT: bl normal_callee
8383
; CHECK-COMMON-NEXT: str d0, [sp, #88] // 8-byte Folded Spill
8484
; CHECK-COMMON-NEXT: smstart sm
85+
; CHECK-COMMON-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
8586
; CHECK-COMMON-NEXT: mov x8, #4631107791820423168 // =0x4045000000000000
8687
; CHECK-COMMON-NEXT: fmov d0, x8
87-
; CHECK-COMMON-NEXT: ldr d1, [sp, #88] // 8-byte Folded Reload
8888
; CHECK-COMMON-NEXT: fadd d0, d1, d0
8989
; CHECK-COMMON-NEXT: ldr x30, [sp, #80] // 8-byte Folded Reload
9090
; CHECK-COMMON-NEXT: ldp d9, d8, [sp, #64] // 16-byte Folded Reload
@@ -110,14 +110,16 @@ define double @locally_streaming_caller_normal_callee(double %x) nounwind noinli
110110
; CHECK-COMMON-NEXT: str x30, [sp, #96] // 8-byte Folded Spill
111111
; CHECK-COMMON-NEXT: str d0, [sp, #24] // 8-byte Folded Spill
112112
; CHECK-COMMON-NEXT: smstart sm
113+
; CHECK-COMMON-NEXT: ldr d0, [sp, #24] // 8-byte Folded Reload
114+
; CHECK-COMMON-NEXT: str d0, [sp, #24] // 8-byte Folded Spill
113115
; CHECK-COMMON-NEXT: smstop sm
114116
; CHECK-COMMON-NEXT: ldr d0, [sp, #24] // 8-byte Folded Reload
115117
; CHECK-COMMON-NEXT: bl normal_callee
116118
; CHECK-COMMON-NEXT: str d0, [sp, #16] // 8-byte Folded Spill
117119
; CHECK-COMMON-NEXT: smstart sm
118-
; CHECK-COMMON-NEXT: mov x8, #4631107791820423168
119-
; CHECK-COMMON-NEXT: fmov d0, x8
120120
; CHECK-COMMON-NEXT: ldr d1, [sp, #16] // 8-byte Folded Reload
121+
; CHECK-COMMON-NEXT: mov x8, #4631107791820423168 // =0x4045000000000000
122+
; CHECK-COMMON-NEXT: fmov d0, x8
121123
; CHECK-COMMON-NEXT: fadd d0, d1, d0
122124
; CHECK-COMMON-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
123125
; CHECK-COMMON-NEXT: smstop sm
@@ -319,9 +321,9 @@ define fp128 @f128_call_sm(fp128 %a, fp128 %b) "aarch64_pstate_sm_enabled" nounw
319321
; CHECK-COMMON-NEXT: stp d11, d10, [sp, #64] // 16-byte Folded Spill
320322
; CHECK-COMMON-NEXT: stp d9, d8, [sp, #80] // 16-byte Folded Spill
321323
; CHECK-COMMON-NEXT: str x30, [sp, #96] // 8-byte Folded Spill
322-
; CHECK-COMMON-NEXT: stp q0, q1, [sp] // 32-byte Folded Spill
324+
; CHECK-COMMON-NEXT: stp q1, q0, [sp] // 32-byte Folded Spill
323325
; CHECK-COMMON-NEXT: smstop sm
324-
; CHECK-COMMON-NEXT: ldp q0, q1, [sp] // 32-byte Folded Reload
326+
; CHECK-COMMON-NEXT: ldp q1, q0, [sp] // 32-byte Folded Reload
325327
; CHECK-COMMON-NEXT: bl __addtf3
326328
; CHECK-COMMON-NEXT: str q0, [sp, #16] // 16-byte Folded Spill
327329
; CHECK-COMMON-NEXT: smstart sm
@@ -374,14 +376,15 @@ define double @frem_call_za(double %a, double %b) "aarch64_pstate_za_shared" nou
374376
define float @frem_call_sm(float %a, float %b) "aarch64_pstate_sm_enabled" nounwind {
375377
; CHECK-COMMON-LABEL: frem_call_sm:
376378
; CHECK-COMMON: // %bb.0:
377-
; CHECK-COMMON-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
378-
; CHECK-COMMON-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
379-
; CHECK-COMMON-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
380-
; CHECK-COMMON-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
381-
; CHECK-COMMON-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
382-
; CHECK-COMMON-NEXT: stp s0, s1, [sp, #72] // 8-byte Folded Spill
379+
; CHECK-COMMON-NEXT: sub sp, sp, #96
380+
; CHECK-COMMON-NEXT: stp d15, d14, [sp, #16] // 16-byte Folded Spill
381+
; CHECK-COMMON-NEXT: stp d13, d12, [sp, #32] // 16-byte Folded Spill
382+
; CHECK-COMMON-NEXT: stp d11, d10, [sp, #48] // 16-byte Folded Spill
383+
; CHECK-COMMON-NEXT: stp d9, d8, [sp, #64] // 16-byte Folded Spill
384+
; CHECK-COMMON-NEXT: str x30, [sp, #80] // 8-byte Folded Spill
385+
; CHECK-COMMON-NEXT: stp s1, s0, [sp, #8] // 8-byte Folded Spill
383386
; CHECK-COMMON-NEXT: smstop sm
384-
; CHECK-COMMON-NEXT: ldp s0, s1, [sp, #72] // 8-byte Folded Reload
387+
; CHECK-COMMON-NEXT: ldp s1, s0, [sp, #8] // 8-byte Folded Reload
385388
; CHECK-COMMON-NEXT: bl fmodf
386389
; CHECK-COMMON-NEXT: str s0, [sp, #76] // 4-byte Folded Spill
387390
; CHECK-COMMON-NEXT: smstart sm
@@ -408,7 +411,9 @@ define float @frem_call_sm_compat(float %a, float %b) "aarch64_pstate_sm_compati
408411
; CHECK-COMMON-NEXT: stp x30, x19, [sp, #80] // 16-byte Folded Spill
409412
; CHECK-COMMON-NEXT: stp s0, s1, [sp, #8] // 8-byte Folded Spill
410413
; CHECK-COMMON-NEXT: bl __arm_sme_state
414+
; CHECK-COMMON-NEXT: ldp s2, s0, [sp, #8] // 8-byte Folded Reload
411415
; CHECK-COMMON-NEXT: and x19, x0, #0x1
416+
; CHECK-COMMON-NEXT: stp s2, s0, [sp, #8] // 8-byte Folded Spill
412417
; CHECK-COMMON-NEXT: tbz w19, #0, .LBB12_2
413418
; CHECK-COMMON-NEXT: // %bb.1:
414419
; CHECK-COMMON-NEXT: smstop sm

0 commit comments

Comments
 (0)