Skip to content

[AArch64][SME] Conditionally do smstart/smstop #77113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4852,17 +4852,9 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
}

SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
SMEAttrs Attrs, SDLoc DL,
EVT VT) const {
if (Attrs.hasStreamingInterfaceOrBody())
return DAG.getConstant(1, DL, VT);

if (Attrs.hasNonStreamingInterfaceAndBody())
return DAG.getConstant(0, DL, VT);

assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface");

SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
SDValue Chain, SDLoc DL,
EVT VT) const {
SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
getPointerTy(DAG.getDataLayout()));
Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
Expand Down Expand Up @@ -6888,9 +6880,18 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
// Insert the SMSTART if this is a locally streaming function and
// make sure it is Glued to the last CopyFromReg value.
if (IsLocallyStreaming) {
Chain =
changeStreamingMode(DAG, DL, /*Enable*/ true, DAG.getRoot(), Glue,
DAG.getConstant(0, DL, MVT::i64), /*Entry*/ true);
SDValue PStateSM;
if (Attrs.hasStreamingCompatibleInterface()) {
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
Register Reg = MF.getRegInfo().createVirtualRegister(
getRegClassFor(PStateSM.getValueType().getSimpleVT()));
FuncInfo->setPStateSMReg(Reg);
Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
} else {
PStateSM = DAG.getConstant(0, DL, MVT::i64);
}
Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, PStateSM,
/*Entry*/ true);

// Ensure that the SMSTART happens after the CopyWithChain such that its
// chain result is used.
Expand Down Expand Up @@ -7648,7 +7649,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
std::optional<bool> RequiresSMChange =
CallerAttrs.requiresSMChange(CalleeAttrs);
if (RequiresSMChange) {
PStateSM = getPStateSM(DAG, Chain, CallerAttrs, DL, MVT::i64);
if (CallerAttrs.hasStreamingInterfaceOrBody())
PStateSM = DAG.getConstant(1, DL, MVT::i64);
else if (CallerAttrs.hasNonStreamingInterface())
PStateSM = DAG.getConstant(0, DL, MVT::i64);
else
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
OptimizationRemarkEmitter ORE(&MF.getFunction());
ORE.emit([&]() {
auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
Expand Down Expand Up @@ -8201,9 +8207,17 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
// Emit SMSTOP before returning from a locally streaming function
SMEAttrs FuncAttrs(MF.getFunction());
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
Chain = changeStreamingMode(
DAG, DL, /*Enable*/ false, Chain, /*Glue*/ SDValue(),
DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
if (FuncAttrs.hasStreamingCompatibleInterface()) {
Register Reg = FuncInfo->getPStateSMReg();
assert(Reg.isValid() && "PStateSM Register is invalid");
SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
Chain =
changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
/*Glue*/ SDValue(), PStateSM, /*Entry*/ false);
} else
Chain = changeStreamingMode(
DAG, DL, /*Enable*/ false, Chain,
/*Glue*/ SDValue(), DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
Glue = Chain.getValue(1);
}

Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1289,10 +1289,10 @@ class AArch64TargetLowering : public TargetLowering {
// This function does not handle predicate bitcasts.
SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;

// Returns the runtime value for PSTATE.SM. When the function is streaming-
// compatible, this generates a call to __arm_sme_state.
SDValue getPStateSM(SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs,
SDLoc DL, EVT VT) const;
// Returns the runtime value for PSTATE.SM by generating a call to
// __arm_sme_state.
SDValue getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL,
EVT VT) const;

bool preferScalarizeSplat(SDNode *N) const override;

Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {

int64_t StackProbeSize = 0;

// Holds a register containing pstate.sm. This is set
// on function entry to record the initial pstate of a function.
Register PStateSMReg = MCRegister::NoRegister;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can you add a comment describing what this Register holds and when this value is defined?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);

Expand All @@ -216,6 +220,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
const override;

Register getPStateSMReg() const { return PStateSMReg; };
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };

bool isSVECC() const { return IsSVECC; };
void setIsSVECC(bool s) { IsSVECC = s; };

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s

declare void @normal_callee();
declare void @streaming_callee() "aarch64_pstate_sm_enabled";
declare void @streaming_compatible_callee() "aarch64_pstate_sm_compatible";

define float @sm_body_sm_compatible_simple() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
; CHECK-LABEL: sm_body_sm_compatible_simple:
; CHECK: // %bb.0:
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x8, x0, #0x1
; CHECK-NEXT: tbnz w8, #0, .LBB0_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB0_2:
; CHECK-NEXT: tbnz w8, #0, .LBB0_4
; CHECK-NEXT: // %bb.3:
; CHECK-NEXT: smstop sm
; CHECK-NEXT: .LBB0_4:
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
; CHECK-NEXT: fmov s0, wzr
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
; CHECK-NEXT: ret
ret float zeroinitializer
}

define void @sm_body_caller_sm_compatible_caller_normal_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
; CHECK-LABEL: sm_body_caller_sm_compatible_caller_normal_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x19, x0, #0x1
; CHECK-NEXT: tbnz w19, #0, .LBB1_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB1_2:
; CHECK-NEXT: smstop sm
; CHECK-NEXT: bl normal_callee
; CHECK-NEXT: smstart sm
; CHECK-NEXT: tbnz w19, #0, .LBB1_4
; CHECK-NEXT: // %bb.3:
; CHECK-NEXT: smstop sm
; CHECK-NEXT: .LBB1_4:
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
; CHECK-NEXT: ret
call void @normal_callee()
ret void
}

; Function Attrs: nounwind uwtable vscale_range(1,16)
define void @streaming_body_and_streaming_compatible_interface_multi_basic_block(i32 noundef %x) "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
; CHECK-LABEL: streaming_body_and_streaming_compatible_interface_multi_basic_block:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
; CHECK-NEXT: mov w8, w0
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x19, x0, #0x1
; CHECK-NEXT: tbnz w19, #0, .LBB2_2
; CHECK-NEXT: // %bb.1: // %entry
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB2_2: // %entry
; CHECK-NEXT: cbz w8, .LBB2_6
; CHECK-NEXT: // %bb.3: // %if.else
; CHECK-NEXT: bl streaming_compatible_callee
; CHECK-NEXT: tbnz w19, #0, .LBB2_5
; CHECK-NEXT: // %bb.4: // %if.else
; CHECK-NEXT: smstop sm
; CHECK-NEXT: .LBB2_5: // %if.else
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
; CHECK-NEXT: ret
; CHECK-NEXT: .LBB2_6: // %if.then
; CHECK-NEXT: smstop sm
; CHECK-NEXT: bl normal_callee
; CHECK-NEXT: smstart sm
; CHECK-NEXT: tbnz w19, #0, .LBB2_8
; CHECK-NEXT: // %bb.7: // %if.then
; CHECK-NEXT: smstop sm
; CHECK-NEXT: .LBB2_8: // %if.then
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
; CHECK-NEXT: ret
entry:
%cmp = icmp eq i32 %x, 0
br i1 %cmp, label %if.then, label %if.else

if.then: ; preds = %entry
tail call void @normal_callee()
br label %return

if.else: ; preds = %entry
tail call void @streaming_compatible_callee()
br label %return

return: ; preds = %if.else, %if.then
ret void
}