Skip to content

[AArch64][SME] Conditionally do smstart/smstop #77113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 18, 2024

Conversation

MDevereau
Copy link
Contributor

This patch adds conditional enabling/disabling of streaming mode for functions which have both the aarch64_pstate_sm_compatible and aarch64_pstate_sm_body attributes.

This combination allows callees to determine if switching streaming mode is required instead of relying on the caller.

This patch adds conditional enabling/disabling of streaming mode
for functions which have both the aarch64_pstate_sm_compatible
and aarch64_pstate_sm_body attributes.

This combination allows callees to determine
if switching streaming mode is required instead
of relying on the caller.
@MDevereau MDevereau marked this pull request as ready for review January 5, 2024 16:01
@llvmbot
Copy link
Member

llvmbot commented Jan 5, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Matthew Devereau (MDevereau)

Changes

This patch adds conditional enabling/disabling of streaming mode for functions which have both the aarch64_pstate_sm_compatible and aarch64_pstate_sm_body attributes.

This combination allows callees to determine if switching streaming mode is required instead of relying on the caller.


Full diff: https://github.com/llvm/llvm-project/pull/77113.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+26-10)
  • (modified) llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h (+5)
  • (added) llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible.ll (+177)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 50658a855cfb37..72ae574efb6191 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4855,14 +4855,14 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
 SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
                                            SMEAttrs Attrs, SDLoc DL,
                                            EVT VT) const {
-  if (Attrs.hasStreamingInterfaceOrBody())
+  if (Attrs.hasStreamingInterfaceOrBody() &&
+      !Attrs.hasStreamingCompatibleInterface())
     return DAG.getConstant(1, DL, VT);
 
-  if (Attrs.hasNonStreamingInterfaceAndBody())
+  if (Attrs.hasNonStreamingInterfaceAndBody() &&
+      !Attrs.hasStreamingCompatibleInterface())
     return DAG.getConstant(0, DL, VT);
 
-  assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface");
-
   SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
                                          getPointerTy(DAG.getDataLayout()));
   Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
@@ -6888,9 +6888,18 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
   // Insert the SMSTART if this is a locally streaming function and
   // make sure it is Glued to the last CopyFromReg value.
   if (IsLocallyStreaming) {
-    Chain =
-        changeStreamingMode(DAG, DL, /*Enable*/ true, DAG.getRoot(), Glue,
-                            DAG.getConstant(0, DL, MVT::i64), /*Entry*/ true);
+    SDValue PStateSM;
+    if (Attrs.hasStreamingCompatibleInterface()) {
+      PStateSM = getPStateSM(DAG, Chain, Attrs, DL, MVT::i64);
+      Register Reg = MF.getRegInfo().createVirtualRegister(
+          getRegClassFor(PStateSM.getValueType().getSimpleVT()));
+      FuncInfo->setPStateSMReg(Reg);
+      Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
+    } else {
+      PStateSM = DAG.getConstant(0, DL, MVT::i64);
+    }
+    Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, PStateSM,
+                                /*Entry*/ true);
 
     // Ensure that the SMSTART happens after the CopyWithChain such that its
     // chain result is used.
@@ -8201,9 +8210,16 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   // Emit SMSTOP before returning from a locally streaming function
   SMEAttrs FuncAttrs(MF.getFunction());
   if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
-    Chain = changeStreamingMode(
-        DAG, DL, /*Enable*/ false, Chain, /*Glue*/ SDValue(),
-        DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
+    SDValue PStateSM;
+    if (FuncAttrs.hasStreamingCompatibleInterface()) {
+      Register Reg = FuncInfo->getPStateSMReg();
+      assert(Reg.isValid() && "PStateSM Register is invalid");
+      PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
+    } else {
+      PStateSM = DAG.getConstant(1, DL, MVT::i64);
+    }
+    Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+                                /*Glue*/ SDValue(), PStateSM, /*Entry*/ true);
     Glue = Chain.getValue(1);
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index cd4a18bfbc23a8..096fde364a2dcc 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -208,6 +208,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
 
   int64_t StackProbeSize = 0;
 
+  Register PStateSMReg = MCRegister::NoRegister;
+
 public:
   AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
 
@@ -216,6 +218,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
         const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
       const override;
 
+  Register getPStateSMReg() const { return PStateSMReg; };
+  void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
+
   bool isSVECC() const { return IsSVECC; };
   void setIsSVECC(bool s) { IsSVECC = s; };
 
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible.ll b/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible.ll
new file mode 100644
index 00000000000000..16375041e4298f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible.ll
@@ -0,0 +1,177 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+declare void @normal_callee();
+declare void @streaming_callee() "aarch64_pstate_sm_enabled";
+declare void @streaming_compatible_callee() "aarch64_pstate_sm_compatible";
+
+define float @sm_body_sm_compatible_simple() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
+; CHECK-LABEL: sm_body_sm_compatible_simple:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 80
+; CHECK-NEXT:    .cfi_offset w30, -16
+; CHECK-NEXT:    .cfi_offset b8, -24
+; CHECK-NEXT:    .cfi_offset b9, -32
+; CHECK-NEXT:    .cfi_offset b10, -40
+; CHECK-NEXT:    .cfi_offset b11, -48
+; CHECK-NEXT:    .cfi_offset b12, -56
+; CHECK-NEXT:    .cfi_offset b13, -64
+; CHECK-NEXT:    .cfi_offset b14, -72
+; CHECK-NEXT:    .cfi_offset b15, -80
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x8, x0, #0x1
+; CHECK-NEXT:    tbnz w8, #0, .LBB0_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    tbz w8, #0, .LBB0_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB0_4:
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    fmov s0, wzr
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  ret float zeroinitializer
+}
+
+define void @sm_body_caller_sm_compatible_caller_normal_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
+; CHECK-LABEL: sm_body_caller_sm_compatible_caller_normal_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT:    stp x20, x19, [sp, #80] // 16-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 96
+; CHECK-NEXT:    .cfi_offset w19, -8
+; CHECK-NEXT:    .cfi_offset w20, -16
+; CHECK-NEXT:    .cfi_offset w30, -32
+; CHECK-NEXT:    .cfi_offset b8, -40
+; CHECK-NEXT:    .cfi_offset b9, -48
+; CHECK-NEXT:    .cfi_offset b10, -56
+; CHECK-NEXT:    .cfi_offset b11, -64
+; CHECK-NEXT:    .cfi_offset b12, -72
+; CHECK-NEXT:    .cfi_offset b13, -80
+; CHECK-NEXT:    .cfi_offset b14, -88
+; CHECK-NEXT:    .cfi_offset b15, -96
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbnz w19, #0, .LBB1_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x20, x0, #0x1
+; CHECK-NEXT:    tbz w20, #0, .LBB1_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB1_4:
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    tbz w20, #0, .LBB1_6
+; CHECK-NEXT:  // %bb.5:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB1_6:
+; CHECK-NEXT:    tbz w19, #0, .LBB1_8
+; CHECK-NEXT:  // %bb.7:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB1_8:
+; CHECK-NEXT:    ldp x20, x19, [sp, #80] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @normal_callee()
+  ret void
+}
+
+define void @sm_body_caller_sm_compatible_caller_streaming_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
+; CHECK-LABEL: sm_body_caller_sm_compatible_caller_streaming_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 80
+; CHECK-NEXT:    .cfi_offset w19, -8
+; CHECK-NEXT:    .cfi_offset w30, -16
+; CHECK-NEXT:    .cfi_offset b8, -24
+; CHECK-NEXT:    .cfi_offset b9, -32
+; CHECK-NEXT:    .cfi_offset b10, -40
+; CHECK-NEXT:    .cfi_offset b11, -48
+; CHECK-NEXT:    .cfi_offset b12, -56
+; CHECK-NEXT:    .cfi_offset b13, -64
+; CHECK-NEXT:    .cfi_offset b14, -72
+; CHECK-NEXT:    .cfi_offset b15, -80
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbnz w19, #0, .LBB2_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB2_2:
+; CHECK-NEXT:    bl streaming_callee
+; CHECK-NEXT:    tbz w19, #0, .LBB2_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB2_4:
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @streaming_callee()
+  ret void
+}
+
+define void @sm_body_caller_sm_compatible_caller_streaming_compatible_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
+; CHECK-LABEL: sm_body_caller_sm_compatible_caller_streaming_compatible_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 80
+; CHECK-NEXT:    .cfi_offset w19, -8
+; CHECK-NEXT:    .cfi_offset w30, -16
+; CHECK-NEXT:    .cfi_offset b8, -24
+; CHECK-NEXT:    .cfi_offset b9, -32
+; CHECK-NEXT:    .cfi_offset b10, -40
+; CHECK-NEXT:    .cfi_offset b11, -48
+; CHECK-NEXT:    .cfi_offset b12, -56
+; CHECK-NEXT:    .cfi_offset b13, -64
+; CHECK-NEXT:    .cfi_offset b14, -72
+; CHECK-NEXT:    .cfi_offset b15, -80
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbnz w19, #0, .LBB3_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB3_2:
+; CHECK-NEXT:    bl streaming_compatible_callee
+; CHECK-NEXT:    tbz w19, #0, .LBB3_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB3_4:
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @streaming_compatible_callee()
+  ret void
+}

; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB1_2:
; CHECK-NEXT: bl __arm_sme_state
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This second __arm_sme_state seems redundant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -0,0 +1,177 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a test where with multiple basic blocks?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

ret void
}

define void @sm_body_caller_sm_compatible_caller_streaming_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see much value added by this test. Can you remove it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

ret void
}

define void @sm_body_caller_sm_compatible_caller_streaming_compatible_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for this test, I think it can be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

declare void @streaming_callee() "aarch64_pstate_sm_enabled";
declare void @streaming_compatible_callee() "aarch64_pstate_sm_compatible";

define float @sm_body_sm_compatible_simple() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could you add nounwind to these tests so that it avoids generating the .cfi_offset lines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

; CHECK-NEXT: .cfi_offset b15, -80
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x19, x0, #0x1
; CHECK-NEXT: tbnz w19, #0, .LBB1_2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
; CHECK-NEXT: tbnz w19, #0, .LBB1_2
; CHECK-NEXT: tbz w8, #0, .LBB0_2

I think this is the wrong way around, tbnz tests the given bit and branches when that bit is non-zero. When PSTATE.SM = 1 on entry, the bit will be 1. This code will now branch to .LBB1_2 and will avoid the smstart sm for the body.

Likewise, below the tbz w19, #0 tests if PSTATE.SM bit was 0 on entry the function and branches to .LBB1_4 if it was, and therefore does not invoke the smstop sm.

Copy link
Contributor Author

@MDevereau MDevereau Jan 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the wrong way around, tbnz tests the given bit and branches when that bit is non-zero. When PSTATE.SM = 1 on entry, the bit will be 1. This code will now branch to .LBB1_2 and will avoid the smstart sm for the body.

That's what we're trying to do, right? If we know we're in streaming mode on function entry we don't need to start it and can jump straight to the streaming mode body. If we know we're not in streaming mode we start streaming mode to execute the streaming body in streaming mode.

Likewise, below the tbz w19, #0 tests if PSTATE.SM bit was 0 on entry the function and branches to .LBB1_4 if it was, and therefore does not invoke the smstop sm.

I don't think this is the wrong way around, but it is incorrect. I think what's happened has LowerCall emits an smstop/smstart pair to comply with the sm_body attribute. But it's still using the initial __arm_sme_state call to gauge what mode to exit in.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what we're trying to do, right?

Yes, you're right. I got this one the wrong way around myself :)

The return is indeed wrong. That should also be using a tbnz, because if PSTATE.SM was 0 on entry, it will be forced to do an smstop.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should also be using a tbnz

Yeah that's correct, the call has nothing to do with it on second thought. I'll fix this then

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sdesmalen-arm Should be sorted now.

Comment on lines 4855 to 4857
SDValue AArch64TargetLowering::getPStateSM(
SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs, SDLoc DL, EVT VT,
bool AllowStreamingCompatibleInterface) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this extra operand rather confusing.

It seems that getPStateSM has only two uses. What if you rename the function to getRuntimePStateSM and change the use in LowerCall to:

if (CallerAttrs.hasStreamingInterfaceOrBody())
  PStateSM = DAG.getConstant(1, DL, MVT::i64);
else if (CallerAttrs.hasNonStreamingInterface())
  PStateSM = DAG.getConstant(0, DL, MVT::i64);
else
  PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);

That would simplify this function a bit more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed the extra operand and separated out the DAG.getConstants from this function.

Comment on lines 1292 to 1293
// Returns the runtime value for PSTATE.SM. When the function is streaming-
// compatible, this generates a call to __arm_sme_state.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Returns the runtime value for PSTATE.SM. When the function is streaming-
// compatible, this generates a call to __arm_sme_state.
// Returns the runtime value for PSTATE.SM by generating a call to __arm_sme_state.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -208,6 +208,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {

int64_t StackProbeSize = 0;

Register PStateSMReg = MCRegister::NoRegister;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can you add a comment describing what this Register holds and when this value is defined?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 68 to 77
// If the caller has a streaming body and streaming compatible interface,
// we will have already conditionally enabled streaming mode on function
// entry. We need to disable streaming mode when a callee does not have A
// streaming interface, body, or streaming compatible interface.
if (hasStreamingBody() && hasStreamingCompatibleInterface())
return (!Callee.hasStreamingInterfaceOrBody() &&
!Callee.hasStreamingCompatibleInterface())
? std::optional<bool>(false)
: std::nullopt;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I comment out this code, the tests still pass.

I'm also not sure what this change is supposed to do. Isn't this case already covered by:

  if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
    return std::nullopt;

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this had an interaction with the some of earlier tests that were removed and certain caller/callee combinations weren't correctly return as needing a streaming-mode change, or things we're incorrectly falling through to getPStateSM which has now been changed. This whole function is quite weird in general, I don't really understand the advantage a 3rd state provided by the std::optional class brings. Either it needs a streaming mode change or it doesn't. Either way this snippet just be removed for now

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The three states are to describe:

  • std::nullopt = No change needed
  • std::optional<bool>(true) = Change needed, smstart
  • std::optional<bool>(false) = Change needed, smstop
    I do agree the interface is confusing and could be better named.

@@ -0,0 +1,124 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Just one final little nit on the name of this file, can you rename it to: sme-streaming-body-streaming-compatible-interface.ll

to sme-streaming-body-streaming-compatible-interface.ll
@MDevereau MDevereau merged commit 51e3d2f into llvm:main Jan 18, 2024
@MDevereau MDevereau deleted the smstartstop branch January 18, 2024 09:17
ampandey-1995 pushed a commit to ampandey-1995/llvm-project that referenced this pull request Jan 19, 2024
This patch adds conditional enabling/disabling of streaming mode for
functions which have both the aarch64_pstate_sm_compatible and
aarch64_pstate_sm_body attributes.

This combination allows callees to determine if switching streaming mode
is required instead of relying on the caller.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants