Skip to content

Commit 528943f

Browse files
author
Dinar Temirbulatov
authored
[AArch64][SME] Allow memory operations lowering to custom SME functions. (#79263)
This change allows to lower memcpy, memset, memmove to custom SME version provided by LibRT.
1 parent 5601e35 commit 528943f

File tree

4 files changed

+380
-4
lines changed

4 files changed

+380
-4
lines changed

llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ using namespace llvm;
1515

1616
#define DEBUG_TYPE "aarch64-selectiondag-info"
1717

18+
static cl::opt<bool>
19+
LowerToSMERoutines("aarch64-lower-to-sme-routines", cl::Hidden,
20+
cl::desc("Enable AArch64 SME memory operations "
21+
"to lower to librt functions"),
22+
cl::init(true));
23+
1824
SDValue AArch64SelectionDAGInfo::EmitMOPS(AArch64ISD::NodeType SDOpcode,
1925
SelectionDAG &DAG, const SDLoc &DL,
2026
SDValue Chain, SDValue Dst,
@@ -76,15 +82,79 @@ SDValue AArch64SelectionDAGInfo::EmitMOPS(AArch64ISD::NodeType SDOpcode,
7682
}
7783
}
7884

85+
SDValue AArch64SelectionDAGInfo::EmitStreamingCompatibleMemLibCall(
86+
SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
87+
SDValue Size, RTLIB::Libcall LC) const {
88+
const AArch64Subtarget &STI =
89+
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
90+
const AArch64TargetLowering *TLI = STI.getTargetLowering();
91+
SDValue Symbol;
92+
TargetLowering::ArgListEntry DstEntry;
93+
DstEntry.Ty = PointerType::getUnqual(*DAG.getContext());
94+
DstEntry.Node = Dst;
95+
TargetLowering::ArgListTy Args;
96+
Args.push_back(DstEntry);
97+
EVT PointerVT = TLI->getPointerTy(DAG.getDataLayout());
98+
99+
switch (LC) {
100+
case RTLIB::MEMCPY: {
101+
TargetLowering::ArgListEntry Entry;
102+
Entry.Ty = PointerType::getUnqual(*DAG.getContext());
103+
Symbol = DAG.getExternalSymbol("__arm_sc_memcpy", PointerVT);
104+
Entry.Node = Src;
105+
Args.push_back(Entry);
106+
break;
107+
}
108+
case RTLIB::MEMMOVE: {
109+
TargetLowering::ArgListEntry Entry;
110+
Entry.Ty = PointerType::getUnqual(*DAG.getContext());
111+
Symbol = DAG.getExternalSymbol("__arm_sc_memmove", PointerVT);
112+
Entry.Node = Src;
113+
Args.push_back(Entry);
114+
break;
115+
}
116+
case RTLIB::MEMSET: {
117+
TargetLowering::ArgListEntry Entry;
118+
Entry.Ty = Type::getInt32Ty(*DAG.getContext());
119+
Symbol = DAG.getExternalSymbol("__arm_sc_memset", PointerVT);
120+
Src = DAG.getZExtOrTrunc(Src, DL, MVT::i32);
121+
Entry.Node = Src;
122+
Args.push_back(Entry);
123+
break;
124+
}
125+
default:
126+
return SDValue();
127+
}
128+
129+
TargetLowering::ArgListEntry SizeEntry;
130+
SizeEntry.Node = Size;
131+
SizeEntry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
132+
Args.push_back(SizeEntry);
133+
assert(Symbol->getOpcode() == ISD::ExternalSymbol &&
134+
"Function name is not set");
135+
136+
TargetLowering::CallLoweringInfo CLI(DAG);
137+
PointerType *RetTy = PointerType::getUnqual(*DAG.getContext());
138+
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
139+
TLI->getLibcallCallingConv(LC), RetTy, Symbol, std::move(Args));
140+
return TLI->LowerCallTo(CLI).second;
141+
}
142+
79143
SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
80144
SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
81145
SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
82146
MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
83147
const AArch64Subtarget &STI =
84148
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
149+
85150
if (STI.hasMOPS())
86151
return EmitMOPS(AArch64ISD::MOPS_MEMCOPY, DAG, DL, Chain, Dst, Src, Size,
87152
Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
153+
154+
SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
155+
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
156+
return EmitStreamingCompatibleMemLibCall(DAG, DL, Chain, Dst, Src, Size,
157+
RTLIB::MEMCPY);
88158
return SDValue();
89159
}
90160

@@ -95,10 +165,14 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
95165
const AArch64Subtarget &STI =
96166
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
97167

98-
if (STI.hasMOPS()) {
168+
if (STI.hasMOPS())
99169
return EmitMOPS(AArch64ISD::MOPS_MEMSET, DAG, dl, Chain, Dst, Src, Size,
100170
Alignment, isVolatile, DstPtrInfo, MachinePointerInfo{});
101-
}
171+
172+
SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
173+
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
174+
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
175+
RTLIB::MEMSET);
102176
return SDValue();
103177
}
104178

@@ -108,10 +182,15 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
108182
MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
109183
const AArch64Subtarget &STI =
110184
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
111-
if (STI.hasMOPS()) {
185+
186+
if (STI.hasMOPS())
112187
return EmitMOPS(AArch64ISD::MOPS_MEMMOVE, DAG, dl, Chain, Dst, Src, Size,
113188
Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
114-
}
189+
190+
SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
191+
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
192+
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
193+
RTLIB::MEMMOVE);
115194
return SDValue();
116195
}
117196

llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ class AArch64SelectionDAGInfo : public SelectionDAGTargetInfo {
4747
SDValue Chain, SDValue Op1, SDValue Op2,
4848
MachinePointerInfo DstPtrInfo,
4949
bool ZeroData) const override;
50+
51+
SDValue EmitStreamingCompatibleMemLibCall(SelectionDAG &DAG, const SDLoc &DL,
52+
SDValue Chain, SDValue Dst,
53+
SDValue Src, SDValue Size,
54+
RTLIB::Libcall LC) const;
5055
};
5156
}
5257

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
5353
if (FuncName == "__arm_tpidr2_restore")
5454
Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
5555
SMEAttrs::SME_ABI_Routine;
56+
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
57+
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
58+
Bitmask |= SMEAttrs::SM_Compatible;
5659
}
5760

5861
SMEAttrs::SMEAttrs(const AttributeList &Attrs) {

0 commit comments

Comments
 (0)