Skip to content

Commit 9fd2e2c

Browse files
authored
[DAG][AArch64] Support masked loads/stores with nontemporal flags (#87608)
SVE has some non-temporal masked loads and stores. The metadata coming from the nodes is not copied to the MMO at the moment though, meaning it will generate a normal instruction. This patch ensures that the right flags are set if the instruction has non-temporal metadata.
1 parent ac321cb commit 9fd2e2c

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11824,8 +11824,8 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) {
1182411824
!MST->isCompressingStore() && !MST->isTruncatingStore())
1182511825
return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
1182611826
MST->getBasePtr(), MST->getPointerInfo(),
11827-
MST->getOriginalAlign(), MachineMemOperand::MOStore,
11828-
MST->getAAInfo());
11827+
MST->getOriginalAlign(),
11828+
MST->getMemOperand()->getFlags(), MST->getAAInfo());
1182911829

1183011830
// Try transforming N to an indexed store.
1183111831
if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
@@ -11962,7 +11962,7 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
1196211962
SDValue NewLd = DAG.getLoad(
1196311963
N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(),
1196411964
MLD->getPointerInfo(), MLD->getOriginalAlign(),
11965-
MachineMemOperand::MOLoad, MLD->getAAInfo(), MLD->getRanges());
11965+
MLD->getMemOperand()->getFlags(), MLD->getAAInfo(), MLD->getRanges());
1196611966
return CombineTo(N, NewLd, NewLd.getValue(1));
1196711967
}
1196811968

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4754,8 +4754,12 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
47544754

47554755
EVT VT = Src0.getValueType();
47564756

4757+
auto MMOFlags = MachineMemOperand::MOStore;
4758+
if (I.hasMetadata(LLVMContext::MD_nontemporal))
4759+
MMOFlags |= MachineMemOperand::MONonTemporal;
4760+
47574761
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
4758-
MachinePointerInfo(PtrOperand), MachineMemOperand::MOStore,
4762+
MachinePointerInfo(PtrOperand), MMOFlags,
47594763
LocationSize::beforeOrAfterPointer(), Alignment, I.getAAMetadata());
47604764
SDValue StoreNode =
47614765
DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask, VT, MMO,
@@ -4924,8 +4928,12 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
49244928

49254929
SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode();
49264930

4931+
auto MMOFlags = MachineMemOperand::MOLoad;
4932+
if (I.hasMetadata(LLVMContext::MD_nontemporal))
4933+
MMOFlags |= MachineMemOperand::MONonTemporal;
4934+
49274935
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
4928-
MachinePointerInfo(PtrOperand), MachineMemOperand::MOLoad,
4936+
MachinePointerInfo(PtrOperand), MMOFlags,
49294937
LocationSize::beforeOrAfterPointer(), Alignment, AAInfo, Ranges);
49304938

49314939
SDValue Load =

llvm/test/CodeGen/AArch64/sve-nontemporal-masked-ldst.ll

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ define <4 x i32> @masked_load_v4i32(ptr %a, <4 x i1> %mask) nounwind {
99
; CHECK-NEXT: shl v0.4s, v0.4s, #31
1010
; CHECK-NEXT: cmlt v0.4s, v0.4s, #0
1111
; CHECK-NEXT: cmpne p0.s, p0/z, z0.s, #0
12-
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0]
12+
; CHECK-NEXT: ldnt1w { z0.s }, p0/z, [x0]
1313
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
1414
; CHECK-NEXT: ret
1515
%load = call <4 x i32> @llvm.masked.load.v4i32(ptr %a, i32 1, <4 x i1> %mask, <4 x i32> undef), !nontemporal !0
@@ -25,7 +25,7 @@ define void @masked_store_v4i32(<4 x i32> %x, ptr %a, <4 x i1> %mask) nounwind {
2525
; CHECK-NEXT: shl v1.4s, v1.4s, #31
2626
; CHECK-NEXT: cmlt v1.4s, v1.4s, #0
2727
; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, #0
28-
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
28+
; CHECK-NEXT: stnt1w { z0.s }, p0, [x0]
2929
; CHECK-NEXT: ret
3030
call void @llvm.masked.store.v4i32.p0(<4 x i32> %x, ptr %a, i32 1, <4 x i1> %mask), !nontemporal !0
3131
ret void
@@ -43,7 +43,8 @@ define <4 x i32> @load_v4i32(ptr %a) nounwind {
4343
define void @store_v4i32(<4 x i32> %x, ptr %a) nounwind {
4444
; CHECK-LABEL: store_v4i32:
4545
; CHECK: // %bb.0:
46-
; CHECK-NEXT: str q0, [x0]
46+
; CHECK-NEXT: mov d1, v0.d[1]
47+
; CHECK-NEXT: stnp d0, d1, [x0]
4748
; CHECK-NEXT: ret
4849
call void @llvm.masked.store.v4i32.p0(<4 x i32> %x, ptr %a, i32 1, <4 x i1> <i1 1, i1 1, i1 1, i1 1>), !nontemporal !0
4950
ret void
@@ -52,7 +53,7 @@ define void @store_v4i32(<4 x i32> %x, ptr %a) nounwind {
5253
define <vscale x 4 x i32> @masked_load_nxv4i32(ptr %a, <vscale x 4 x i1> %mask) nounwind {
5354
; CHECK-LABEL: masked_load_nxv4i32:
5455
; CHECK: // %bb.0:
55-
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0]
56+
; CHECK-NEXT: ldnt1w { z0.s }, p0/z, [x0]
5657
; CHECK-NEXT: ret
5758
%load = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(ptr %a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef), !nontemporal !0
5859
ret <vscale x 4 x i32> %load
@@ -61,7 +62,7 @@ define <vscale x 4 x i32> @masked_load_nxv4i32(ptr %a, <vscale x 4 x i1> %mask)
6162
define void @masked_store_nxv4i32(<vscale x 4 x i32> %x, ptr %a, <vscale x 4 x i1> %mask) nounwind {
6263
; CHECK-LABEL: masked_store_nxv4i32:
6364
; CHECK: // %bb.0:
64-
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
65+
; CHECK-NEXT: stnt1w { z0.s }, p0, [x0]
6566
; CHECK-NEXT: ret
6667
call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> %x, ptr %a, i32 1, <vscale x 4 x i1> %mask), !nontemporal !0
6768
ret void

0 commit comments

Comments
 (0)