Skip to content

Commit ea610b0

Browse files
[LLVM][SVE] Honour calling convention when using SVE for fixed length vectors.
1 parent b8e0601 commit ea610b0

File tree

4 files changed

+112
-0
lines changed

4 files changed

+112
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26718,3 +26718,99 @@ bool AArch64TargetLowering::preferScalarizeSplat(SDNode *N) const {
2671826718
unsigned AArch64TargetLowering::getMinimumJumpTableEntries() const {
2671926719
return Subtarget->getMinimumJumpTableEntries();
2672026720
}
26721+
26722+
MVT AArch64TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
26723+
CallingConv::ID CC,
26724+
EVT VT) const {
26725+
bool NonUnitFixedLengthVector =
26726+
VT.isFixedLengthVector() && !VT.getVectorElementCount().isScalar();
26727+
if (!NonUnitFixedLengthVector || !Subtarget->useSVEForFixedLengthVectors())
26728+
return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
26729+
26730+
EVT VT1;
26731+
MVT RegisterVT;
26732+
unsigned NumIntermediates;
26733+
getVectorTypeBreakdownForCallingConv(Context, CC, VT, VT1, NumIntermediates,
26734+
RegisterVT);
26735+
return RegisterVT;
26736+
}
26737+
26738+
unsigned AArch64TargetLowering::getNumRegistersForCallingConv(
26739+
LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
26740+
bool NonUnitFixedLengthVector =
26741+
VT.isFixedLengthVector() && !VT.getVectorElementCount().isScalar();
26742+
if (!NonUnitFixedLengthVector || !Subtarget->useSVEForFixedLengthVectors())
26743+
return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
26744+
26745+
EVT VT1;
26746+
MVT VT2;
26747+
unsigned NumIntermediates;
26748+
return getVectorTypeBreakdownForCallingConv(Context, CC, VT, VT1,
26749+
NumIntermediates, VT2);
26750+
}
26751+
26752+
unsigned AArch64TargetLowering::getVectorTypeBreakdownForCallingConv(
26753+
LLVMContext &Context, CallingConv::ID CC, EVT VT, EVT &IntermediateVT,
26754+
unsigned &NumIntermediates, MVT &RegisterVT) const {
26755+
int NumRegs = TargetLowering::getVectorTypeBreakdownForCallingConv(
26756+
Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT);
26757+
if (!RegisterVT.isFixedLengthVector() ||
26758+
RegisterVT.getFixedSizeInBits() <= 128)
26759+
return NumRegs;
26760+
26761+
assert(Subtarget->useSVEForFixedLengthVectors() && "Unexpected mode!");
26762+
assert(IntermediateVT == RegisterVT && "Unexpected VT mismatch!");
26763+
assert(RegisterVT.getFixedSizeInBits() % 128 == 0 && "Unexpected size!");
26764+
26765+
// A size mismatch here implies either type promotion or widening and would
26766+
// have resulted in scalarisation if larger vectors had not be available.
26767+
if (RegisterVT.getSizeInBits() * NumRegs != VT.getSizeInBits()) {
26768+
EVT EltTy = VT.getVectorElementType();
26769+
EVT NewVT = EVT::getVectorVT(Context, EltTy, ElementCount::getFixed(1));
26770+
if (!isTypeLegal(NewVT))
26771+
NewVT = EltTy;
26772+
26773+
IntermediateVT = NewVT;
26774+
NumIntermediates = VT.getVectorNumElements();
26775+
RegisterVT = getRegisterType(Context, NewVT);
26776+
return NumIntermediates;
26777+
}
26778+
26779+
// SVE VLS support does not introduce a new ABI so we should use NEON sized
26780+
// types for vector arguments and returns.
26781+
26782+
unsigned NumSubRegs = RegisterVT.getFixedSizeInBits() / 128;
26783+
NumIntermediates *= NumSubRegs;
26784+
NumRegs *= NumSubRegs;
26785+
26786+
switch (RegisterVT.getVectorElementType().SimpleTy) {
26787+
default:
26788+
llvm_unreachable("unexpected element type for vector");
26789+
case MVT::i8:
26790+
IntermediateVT = RegisterVT = MVT::v16i8;
26791+
break;
26792+
case MVT::i16:
26793+
IntermediateVT = RegisterVT = MVT::v8i16;
26794+
break;
26795+
case MVT::i32:
26796+
IntermediateVT = RegisterVT = MVT::v4i32;
26797+
break;
26798+
case MVT::i64:
26799+
IntermediateVT = RegisterVT = MVT::v2i64;
26800+
break;
26801+
case MVT::f16:
26802+
IntermediateVT = RegisterVT = MVT::v8f16;
26803+
break;
26804+
case MVT::f32:
26805+
IntermediateVT = RegisterVT = MVT::v4f32;
26806+
break;
26807+
case MVT::f64:
26808+
IntermediateVT = RegisterVT = MVT::v2f64;
26809+
break;
26810+
case MVT::bf16:
26811+
IntermediateVT = RegisterVT = MVT::v8bf16;
26812+
break;
26813+
}
26814+
26815+
return NumRegs;
26816+
}

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,18 @@ class AArch64TargetLowering : public TargetLowering {
954954
// used for 64bit and 128bit vectors as well.
955955
bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
956956

957+
// Follow NEON ABI rules even when using SVE for fixed length vectors.
958+
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
959+
EVT VT) const override;
960+
unsigned getNumRegistersForCallingConv(LLVMContext &Context,
961+
CallingConv::ID CC,
962+
EVT VT) const override;
963+
unsigned getVectorTypeBreakdownForCallingConv(LLVMContext &Context,
964+
CallingConv::ID CC, EVT VT,
965+
EVT &IntermediateVT,
966+
unsigned &NumIntermediates,
967+
MVT &RegisterVT) const override;
968+
957969
private:
958970
/// Keep a pointer to the AArch64Subtarget around so that we can
959971
/// make the right decision when generating code for different targets.

llvm/test/CodeGen/AArch64/sve-fixed-length-function-calls.ll

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
22
; RUN: llc < %s | FileCheck %s
3+
; RUN: llc -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s
4+
; RUN: llc -aarch64-sve-vector-bits-min=512 < %s | FileCheck %s
35

46
target triple = "aarch64-unknown-linux-gnu"
57

llvm/test/CodeGen/AArch64/sve-fixed-length-functions.ll

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
22
; RUN: llc < %s | FileCheck %s
3+
; RUN: llc -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s
4+
; RUN: llc -aarch64-sve-vector-bits-min=512 < %s | FileCheck %s
35

46
target triple = "aarch64-unknown-linux-gnu"
57

0 commit comments

Comments
 (0)