Skip to content

[AArch64][SME] Remove implicit-def's on smstart #69012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7376,6 +7376,22 @@ static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
return ZExtBool;
}

void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
SDNode *Node) const {
// Live-in physreg copies that are glued to SMSTART are applied as
// implicit-def's in the InstrEmitter. Here we remove them, allowing the
// register allocator to pass call args in callee saved regs, without extra
// copies to avoid these fake clobbers of actually-preserved GPRs.
if (MI.getOpcode() == AArch64::MSRpstatesvcrImm1 ||
MI.getOpcode() == AArch64::MSRpstatePseudo)
for (unsigned I = MI.getNumOperands() - 1; I > 0; --I)
if (MachineOperand &MO = MI.getOperand(I);
MO.isReg() && MO.isImplicit() && MO.isDef() &&
(AArch64::GPR32RegClass.contains(MO.getReg()) ||
AArch64::GPR64RegClass.contains(MO.getReg())))
MI.removeOperand(I);
}

SDValue AArch64TargetLowering::changeStreamingMode(
SelectionDAG &DAG, SDLoc DL, bool Enable,
SDValue Chain, SDValue InGlue, SDValue PStateSM, bool Entry) const {
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,9 @@ class AArch64TargetLowering : public TargetLowering {
const SDLoc &DL, SelectionDAG &DAG,
SmallVectorImpl<SDValue> &InVals) const override;

void AdjustInstrPostInstrSelection(MachineInstr &MI,
SDNode *Node) const override;

SDValue LowerCall(CallLoweringInfo & /*CLI*/,
SmallVectorImpl<SDValue> &InVals) const override;

Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def : Pat<(i64 (int_aarch64_sme_get_tpidr2)),
def MSRpstatePseudo :
Pseudo<(outs),
(ins svcr_op:$pstatefield, timm0_1:$imm, GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>,
Sched<[WriteSys]>;
Sched<[WriteSys]> {
let hasPostISelHook = 1;
}

def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
(MSRpstatePseudo svcr_op:$pstate, 0b1, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/SMEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def MSRpstatesvcrImm1
let Inst{11-9} = pstatefield;
let Inst{8} = imm;
let Inst{7-5} = 0b011; // op2
let hasPostISelHook = 1;
}

def : InstAlias<"smstart", (MSRpstatesvcrImm1 0b011, 0b1)>;
Expand Down
53 changes: 53 additions & 0 deletions llvm/test/CodeGen/AArch64/sme-streaming-compatible-interface.ll
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,56 @@ define void @disable_tailcallopt() "aarch64_pstate_sm_compatible" nounwind {
tail call void @normal_callee();
ret void;
}

define void @call_to_non_streaming_pass_args(ptr nocapture noundef readnone %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2) "aarch64_pstate_sm_compatible" {
; CHECK-LABEL: call_to_non_streaming_pass_args:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub sp, sp, #112
; CHECK-NEXT: stp d15, d14, [sp, #32] // 16-byte Folded Spill
; CHECK-NEXT: stp d13, d12, [sp, #48] // 16-byte Folded Spill
; CHECK-NEXT: stp d11, d10, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: stp d9, d8, [sp, #80] // 16-byte Folded Spill
; CHECK-NEXT: stp x30, x19, [sp, #96] // 16-byte Folded Spill
; CHECK-NEXT: .cfi_def_cfa_offset 112
; CHECK-NEXT: .cfi_offset w19, -8
; CHECK-NEXT: .cfi_offset w30, -16
; CHECK-NEXT: .cfi_offset b8, -24
; CHECK-NEXT: .cfi_offset b9, -32
; CHECK-NEXT: .cfi_offset b10, -40
; CHECK-NEXT: .cfi_offset b11, -48
; CHECK-NEXT: .cfi_offset b12, -56
; CHECK-NEXT: .cfi_offset b13, -64
; CHECK-NEXT: .cfi_offset b14, -72
; CHECK-NEXT: .cfi_offset b15, -80
; CHECK-NEXT: stp d2, d3, [sp, #16] // 16-byte Folded Spill
; CHECK-NEXT: mov x8, x1
; CHECK-NEXT: mov x9, x0
; CHECK-NEXT: stp s0, s1, [sp, #8] // 8-byte Folded Spill
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x19, x0, #0x1
; CHECK-NEXT: tbz w19, #0, .LBB10_2
; CHECK-NEXT: // %bb.1: // %entry
; CHECK-NEXT: smstop sm
; CHECK-NEXT: .LBB10_2: // %entry
; CHECK-NEXT: ldp s0, s1, [sp, #8] // 8-byte Folded Reload
; CHECK-NEXT: mov x0, x9
; CHECK-NEXT: ldp d2, d3, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: mov x1, x8
; CHECK-NEXT: bl bar
; CHECK-NEXT: tbz w19, #0, .LBB10_4
; CHECK-NEXT: // %bb.3: // %entry
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB10_4: // %entry
; CHECK-NEXT: ldp x30, x19, [sp, #96] // 16-byte Folded Reload
; CHECK-NEXT: ldp d9, d8, [sp, #80] // 16-byte Folded Reload
; CHECK-NEXT: ldp d11, d10, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: ldp d13, d12, [sp, #48] // 16-byte Folded Reload
; CHECK-NEXT: ldp d15, d14, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #112
; CHECK-NEXT: ret
entry:
call void @bar(ptr noundef nonnull %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2)
ret void
}

declare void @bar(ptr noundef, i64 noundef, i64 noundef, i32 noundef, i32 noundef, float noundef, float noundef, double noundef, double noundef)
41 changes: 33 additions & 8 deletions llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
Original file line number Diff line number Diff line change
Expand Up @@ -368,15 +368,11 @@ define i8 @call_to_non_streaming_pass_sve_objects(ptr nocapture noundef readnone
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
; CHECK-NEXT: stp x29, x30, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-3
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: addvl x9, sp, #2
; CHECK-NEXT: addvl x10, sp, #1
; CHECK-NEXT: mov x11, sp
; CHECK-NEXT: rdsvl x3, #1
; CHECK-NEXT: addvl x0, sp, #2
; CHECK-NEXT: addvl x1, sp, #1
; CHECK-NEXT: mov x2, sp
; CHECK-NEXT: smstop sm
; CHECK-NEXT: mov x0, x9
; CHECK-NEXT: mov x1, x10
; CHECK-NEXT: mov x2, x11
; CHECK-NEXT: mov x3, x8
; CHECK-NEXT: bl foo
; CHECK-NEXT: smstart sm
; CHECK-NEXT: ptrue p0.b
Expand All @@ -400,8 +396,37 @@ entry:
ret i8 %vecext
}

define void @call_to_non_streaming_pass_args(ptr nocapture noundef readnone %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2) #0 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a similar test to sme-streaming-compatible-interface.ll?

; CHECK-LABEL: call_to_non_streaming_pass_args:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub sp, sp, #112
; CHECK-NEXT: stp d15, d14, [sp, #32] // 16-byte Folded Spill
; CHECK-NEXT: stp d13, d12, [sp, #48] // 16-byte Folded Spill
; CHECK-NEXT: stp d11, d10, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: stp d9, d8, [sp, #80] // 16-byte Folded Spill
; CHECK-NEXT: str x30, [sp, #96] // 8-byte Folded Spill
; CHECK-NEXT: stp d2, d3, [sp, #16] // 16-byte Folded Spill
; CHECK-NEXT: stp s0, s1, [sp, #8] // 8-byte Folded Spill
; CHECK-NEXT: smstop sm
; CHECK-NEXT: ldp s0, s1, [sp, #8] // 8-byte Folded Reload
; CHECK-NEXT: ldp d2, d3, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: bl bar
; CHECK-NEXT: smstart sm
; CHECK-NEXT: ldp d9, d8, [sp, #80] // 16-byte Folded Reload
; CHECK-NEXT: ldr x30, [sp, #96] // 8-byte Folded Reload
; CHECK-NEXT: ldp d11, d10, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: ldp d13, d12, [sp, #48] // 16-byte Folded Reload
; CHECK-NEXT: ldp d15, d14, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #112
; CHECK-NEXT: ret
entry:
call void @bar(ptr noundef nonnull %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2)
ret void
}

declare i64 @llvm.aarch64.sme.cntsb()

declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef)
declare void @bar(ptr noundef, i64 noundef, i64 noundef, i32 noundef, i32 noundef, float noundef, float noundef, double noundef, double noundef)

attributes #0 = { nounwind vscale_range(1,16) "aarch64_pstate_sm_enabled" }