[SelectionDAG][RISCV] Avoid store merging across function calls #130430

Merged 33 commits on Mar 22, 2025.

Commits (changes shown from 27 commits):
3e7e708
[SelectionDAG] Avoid store merging across function calls
mikhailramalho Feb 28, 2025
5b1fc65
Check call_end instead of start
mikhailramalho Mar 4, 2025
be40ec2
Only walk over mem operation chains
mikhailramalho Mar 4, 2025
c3298ff
Don't go over NumConsecutiveStores
mikhailramalho Mar 4, 2025
b56d1b1
Use SDValues
mikhailramalho Mar 4, 2025
82420c7
Added fallthrough
mikhailramalho Mar 4, 2025
96c8e53
Add Visited list to cache the walk
mikhailramalho Mar 5, 2025
3574370
Moved increment
mikhailramalho Mar 5, 2025
f9393d5
Updated test case
mikhailramalho Mar 5, 2025
d86ec01
Enable merge by default for scalars
mikhailramalho Mar 5, 2025
04bca6d
Rewrite walk back algo to keep track of calls found
mikhailramalho Mar 5, 2025
f27092f
Check final type before we prevent merges
mikhailramalho Mar 6, 2025
9faa629
No need to check operands. It's checked in the start of the loop
mikhailramalho Mar 17, 2025
b326da1
Assert operand type
mikhailramalho Mar 17, 2025
c858020
Moved peekThroughBitcasts into an assertion
mikhailramalho Mar 17, 2025
b6b1521
Use getChain instead of accessing the operand 0
mikhailramalho Mar 17, 2025
18e68ea
Make hasCallInLdStChain a member function
mikhailramalho Mar 18, 2025
3bc2b22
Added test case
mikhailramalho Mar 18, 2025
816a235
Merge remote-tracking branch 'origin/main' into dag-spillcost-fix
mikhailramalho Mar 19, 2025
904641f
Removed duplicated test after merge
mikhailramalho Mar 19, 2025
a255f16
No need to declare intrinsics anymore
mikhailramalho Mar 19, 2025
75f4caa
Removed unused args
mikhailramalho Mar 19, 2025
de96633
Address comment
mikhailramalho Mar 19, 2025
69c361c
Address comment
mikhailramalho Mar 19, 2025
0dfd354
Removed todo
mikhailramalho Mar 19, 2025
e73c49d
Simplify interface
mikhailramalho Mar 19, 2025
a88e73b
Merge remote-tracking branch 'origin/main' into dag-spillcost-fix
mikhailramalho Mar 19, 2025
d6c848d
Remove assert that fails when building blender_r
mikhailramalho Mar 20, 2025
0189f30
Address comment
mikhailramalho Mar 20, 2025
99e11ae
Merge remote-tracking branch 'origin/main' into dag-spillcost-fix
mikhailramalho Mar 20, 2025
67b3b65
Update test
mikhailramalho Mar 21, 2025
ed8a5fd
Removed todo
mikhailramalho Mar 21, 2025
21ba7b2
Merge remote-tracking branch 'origin/main' into dag-spillcost-fix
mikhailramalho Mar 21, 2025
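
For context: the pattern this change targets is a pair of adjacent scalar loads and stores separated by a function call. Before this patch, the store merger could combine them into a single vector load/store, forcing a vector spill and reload around the call because RISC-V has no callee-saved vector registers (see the TODO removed from the test below). A minimal hypothetical C++ reproducer, not taken from the PR (copy_pair is made up; g mirrors the callee in the test), might look like:

  // Hypothetical reproducer: two adjacent i64 loads, a call, then two adjacent
  // i64 stores. Merging these into one vector load/store would require
  // spilling the vector value across g().
  extern void g();

  void copy_pair(long long *p, long long *q) {
    long long x0 = p[0];
    long long x1 = p[1];
    g(); // call between the loads and the stores
    q[0] = x0;
    q[1] = x1;
  }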
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3520,6 +3520,10 @@ class TargetLoweringBase {
/// The default implementation just freezes the set of reserved registers.
virtual void finalizeLowering(MachineFunction &MF) const;

/// Returns true if it's profitable to allow merging store of loads when there
/// are function calls between the load and the store.
virtual bool shouldMergeStoreOfLoadsOverCall(EVT, EVT) const { return true; }

//===----------------------------------------------------------------------===//
// GlobalISel Hooks
//===----------------------------------------------------------------------===//
49 changes: 49 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -792,6 +792,10 @@ namespace {
SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
SDNode *RootNode);

/// Helper function for tryStoreMergeOfLoads. Checks if the load/store
/// chain has a call in it. \return True if a call is found.
bool hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld);

/// This is a helper function for mergeConsecutiveStores. Given a list of
/// store candidates, find the first N that are consecutive in memory.
/// Returns 0 if there are not at least 2 consecutive stores to try merging.
@@ -21152,6 +21156,41 @@ bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
return true;
}

bool DAGCombiner::hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld) {
assert(Ld == cast<LoadSDNode>(peekThroughBitcasts(St->getValue())) &&
"Load and store mismatch");

SmallPtrSet<const SDNode *, 32> Visited;
SmallVector<std::pair<const SDNode *, bool>, 8> Worklist;
Worklist.emplace_back(St->getChain().getNode(), false);

while (!Worklist.empty()) {
auto [Node, FoundCall] = Worklist.pop_back_val();
if (!Visited.insert(Node).second || Node->getNumOperands() == 0)
continue;

switch (Node->getOpcode()) {
case ISD::CALLSEQ_END:
Worklist.emplace_back(Node->getOperand(0).getNode(), true);
break;
case ISD::TokenFactor:
for (SDValue Op : Node->ops())
Worklist.emplace_back(Op.getNode(), FoundCall);
break;
case ISD::LOAD:
if (Node == Ld)
return FoundCall;
[[fallthrough]];
default:
assert(Node->getOperand(0).getValueType() == MVT::Other &&
"Invalid chain type");
Worklist.emplace_back(Node->getOperand(0).getNode(), FoundCall);
break;
}
}
return false;
}

unsigned
DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
int64_t ElementSizeBytes) const {
@@ -21598,6 +21637,16 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
}

// Check if there is a call in the load/store chain.
if (!TLI.shouldMergeStoreOfLoadsOverCall(MemVT, JointMemOpVT) &&
hasCallInLdStChain(cast<StoreSDNode>(StoreNodes[0].MemNode),
cast<LoadSDNode>(LoadNodes[0].MemNode))) {
StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
NumConsecutiveStores -= NumElem;
continue;
}

SDLoc LoadDL(LoadNodes[0].MemNode);
SDLoc StoreDL(StoreNodes[0].MemNode);

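A minimal, self-contained sketch of the worklist idea behind hasCallInLdStChain, on a toy node graph rather than LLVM's SelectionDAG types (all types and names below are made up for illustration): walk backwards from the store's chain, forking at TokenFactor-like nodes, and remember whether a call boundary (the CALLSEQ_END analogue) was crossed before the load is reached.

  // Toy sketch only: hypothetical standalone types, not LLVM's SDNode/SDValue.
  #include <set>
  #include <utility>
  #include <vector>

  enum class Kind { Load, CallEnd, TokenFactor, Other };

  struct Node {
    Kind K;
    std::vector<const Node *> ChainOps; // chain predecessors
  };

  // Walk back from Start; report whether a CallEnd was crossed on the path
  // along which Ld is first reached (mirrors the DFS in hasCallInLdStChain).
  bool hasCallOnChainPath(const Node *Start, const Node *Ld) {
    std::set<const Node *> Visited;
    std::vector<std::pair<const Node *, bool>> Worklist{{Start, false}};
    while (!Worklist.empty()) {
      auto [N, FoundCall] = Worklist.back();
      Worklist.pop_back();
      if (!Visited.insert(N).second || N->ChainOps.empty())
        continue;
      switch (N->K) {
      case Kind::CallEnd:
        Worklist.emplace_back(N->ChainOps[0], /*FoundCall=*/true);
        break;
      case Kind::TokenFactor:
        for (const Node *Op : N->ChainOps)
          Worklist.emplace_back(Op, FoundCall);
        break;
      case Kind::Load:
        if (N == Ld)
          return FoundCall;
        [[fallthrough]];
      case Kind::Other:
        Worklist.emplace_back(N->ChainOps[0], FoundCall);
        break;
      }
    }
    return false;
  }

  int main() {
    Node Entry{Kind::Other, {}};
    Node Ld{Kind::Load, {&Entry}};         // the load feeding the store
    Node Call{Kind::CallEnd, {&Ld}};       // a call after the load
    Node StoreChain{Kind::Other, {&Call}}; // the store's chain
    return hasCallOnChainPath(&StoreChain, &Ld) ? 0 : 1; // 0: call found
  }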
7 changes: 7 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1070,6 +1070,13 @@ class RISCVTargetLowering : public TargetLowering {
return false;
}

/// Disables storing and loading vectors by default when there are function
/// calls between the load and store, since these are more expensive than just
/// using scalars
bool shouldMergeStoreOfLoadsOverCall(EVT SrcVT, EVT MergedVT) const override {
return SrcVT.isScalarInteger() == MergedVT.isScalarInteger();

Review thread on shouldMergeStoreOfLoadsOverCall:

Collaborator:
What about scalar FP?

Collaborator:
I think SrcVT can be a scalar FP and MergedVT can be a scalar integer VT. In that case it would still be OK to merge across the call.

Maybe this should be

  return !MergedVT.isVector() || SrcVT.isVector();

@preames (Mar 20, 2025):
I glanced at what we do for this today. Quick summary:
2 x half (no zvfh) -> no merge performed
2 x half (zvfh) -> <2 x half> vector result type
2 x float -> <2 x float> vector result type
2 x float + noimplicitfloat -> two integer loads, no merge

So basically, I think the patch as it is should be a net improvement (by disabling cases 2 and 3). I'm not opposed to your suggested change, just pointing out it's not needed to avoid a regression. (Edit: Except on re-read, it doesn't do that. FP is non-scalar and vector is non-scalar, which is equal under the current check. Yeah, we should do Craig's suggestion here.)

Note that the SDAG code doesn't even try to find a wider legal floating-point type to merge to. The merge type will only be integer or vector.

@topperc (Mar 20, 2025):
> I glanced at what we do for this today. Quick summary:
> 2 x half (no zvfh) -> no merge performed
> 2 x half (zvfh) -> <2 x half> vector result type
> 2 x float -> <2 x float> vector result type
> 2 x float + noimplicitfloat -> two integer loads, no merge

@mikhailramalho said the case he hit the assert on was 2 x float -> i64.

> Note that the SDAG code doesn't even try to find a wider legal floating-point type to merge to. The merge type will only be integer or vector.

A wider legal FP doesn't really make sense. You can't really join them without going through integer.

I think either !MergedVT.isVector() || SrcVT.isVector() or MergedVT.isVector() == SrcVT.isVector() would cover the vector case without picking up scalar FP in scalar integer. And allow the assert that was removed.

Collaborator:
Test cases (aside from the noimplicitfloat one) added in a5c7f81

}

/// For available scheduling models FDIV + two independent FMULs are much
/// faster than two FDIVs.
unsigned combineRepeatedFPDivisors() const override;
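To make the review discussion above concrete, here is a minimal standalone sketch (hypothetical; it does not use LLVM's EVT) of how the check in this revision and the check suggested in review classify the "2 x float merged into i64" case mentioned by @topperc:

  #include <cstdio>

  struct VT { bool IsVector; bool IsScalarInteger; };

  // Check as written in this revision of the patch.
  bool mergeAllowedCommitted(VT Src, VT Merged) {
    return Src.IsScalarInteger == Merged.IsScalarInteger;
  }

  // Check suggested in review.
  bool mergeAllowedSuggested(VT Src, VT Merged) {
    return !Merged.IsVector || Src.IsVector;
  }

  int main() {
    VT F32 = {/*IsVector=*/false, /*IsScalarInteger=*/false}; // scalar FP source
    VT I64 = {/*IsVector=*/false, /*IsScalarInteger=*/true};  // merged integer type
    // The committed check blocks this scalar-FP-to-scalar-integer merge (prints 0);
    // the suggested check still allows it across the call (prints 1).
    std::printf("committed=%d suggested=%d\n",
                mergeAllowedCommitted(F32, I64), mergeAllowedSuggested(F32, I64));
    return 0;
  }

For a vector MergedVT with a scalar SrcVT (the case this patch is really about), both forms return false and block the merge across the call.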
34 changes: 15 additions & 19 deletions llvm/test/CodeGen/RISCV/stores-of-loads-merging.ll
@@ -1,11 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s

declare i32 @llvm.experimental.constrained.fptosi.i32.f64(double, metadata)
declare void @g()

; TODO: Merging scalars into vectors is unprofitable because we have no
; vector CSRs which creates additional spills around the call.
define void @f(ptr %m, ptr %n, ptr %p, ptr %q, ptr %r, ptr %s, double %t) {
; CHECK-LABEL: f:
; CHECK: # %bb.0:
@@ -15,40 +12,40 @@ define void @f(ptr %m, ptr %n, ptr %p, ptr %q, ptr %r, ptr %s, double %t) {
; CHECK-NEXT: sd s0, 32(sp) # 8-byte Folded Spill
; CHECK-NEXT: sd s1, 24(sp) # 8-byte Folded Spill
; CHECK-NEXT: sd s2, 16(sp) # 8-byte Folded Spill
; CHECK-NEXT: sd s3, 8(sp) # 8-byte Folded Spill
; CHECK-NEXT: sd s4, 0(sp) # 8-byte Folded Spill
; CHECK-NEXT: .cfi_offset ra, -8
; CHECK-NEXT: .cfi_offset s0, -16
; CHECK-NEXT: .cfi_offset s1, -24
; CHECK-NEXT: .cfi_offset s2, -32
; CHECK-NEXT: csrr a6, vlenb
; CHECK-NEXT: sub sp, sp, a6
; CHECK-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x30, 0x22, 0x11, 0x01, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 48 + 1 * vlenb
; CHECK-NEXT: .cfi_offset s3, -40
; CHECK-NEXT: .cfi_offset s4, -48
; CHECK-NEXT: mv s0, a5
; CHECK-NEXT: mv s1, a4
; CHECK-NEXT: vsetivli zero, 2, e64, m1, ta, ma
; CHECK-NEXT: vle64.v v8, (a0)
; CHECK-NEXT: vse64.v v8, (a1)
; CHECK-NEXT: vle64.v v8, (a2)
; CHECK-NEXT: addi a0, sp, 16
; CHECK-NEXT: vs1r.v v8, (a0) # Unknown-size Folded Spill
; CHECK-NEXT: ld s3, 0(a2)
; CHECK-NEXT: ld s4, 8(a2)
; CHECK-NEXT: mv s2, a3
; CHECK-NEXT: call g
; CHECK-NEXT: addi a0, sp, 16
; CHECK-NEXT: vl1r.v v8, (a0) # Unknown-size Folded Reload
; CHECK-NEXT: sd s3, 0(s2)
; CHECK-NEXT: sd s4, 8(s2)
; CHECK-NEXT: vsetivli zero, 2, e64, m1, ta, ma
; CHECK-NEXT: vse64.v v8, (s2)
; CHECK-NEXT: vle64.v v8, (s1)
; CHECK-NEXT: vse64.v v8, (s0)
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: add sp, sp, a0
; CHECK-NEXT: .cfi_def_cfa sp, 48
; CHECK-NEXT: ld ra, 40(sp) # 8-byte Folded Reload
; CHECK-NEXT: ld s0, 32(sp) # 8-byte Folded Reload
; CHECK-NEXT: ld s1, 24(sp) # 8-byte Folded Reload
; CHECK-NEXT: ld s2, 16(sp) # 8-byte Folded Reload
; CHECK-NEXT: ld s3, 8(sp) # 8-byte Folded Reload
; CHECK-NEXT: ld s4, 0(sp) # 8-byte Folded Reload
; CHECK-NEXT: .cfi_restore ra
; CHECK-NEXT: .cfi_restore s0
; CHECK-NEXT: .cfi_restore s1
; CHECK-NEXT: .cfi_restore s2
; CHECK-NEXT: .cfi_restore s3
; CHECK-NEXT: .cfi_restore s4
; CHECK-NEXT: addi sp, sp, 48
; CHECK-NEXT: .cfi_def_cfa_offset 0
; CHECK-NEXT: ret
@@ -77,13 +74,13 @@ define void @f(ptr %m, ptr %n, ptr %p, ptr %q, ptr %r, ptr %s, double %t) {
ret void
}

define void @f1(ptr %m, ptr %n, ptr %p, ptr %q, ptr %r, ptr %s, double %t) {
define void @f1(ptr %p, ptr %q, double %t) {
; CHECK-LABEL: f1:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 2, e64, m1, ta, ma
; CHECK-NEXT: vle64.v v8, (a2)
; CHECK-NEXT: vle64.v v8, (a0)
; CHECK-NEXT: fcvt.wu.d a0, fa0, rtz
; CHECK-NEXT: vse64.v v8, (a3)
; CHECK-NEXT: vse64.v v8, (a1)
; CHECK-NEXT: ret
%x0 = load i64, ptr %p
%p.1 = getelementptr i64, ptr %p, i64 1
@@ -92,7 +89,6 @@ define void @f1(ptr %m, ptr %n, ptr %p, ptr %q, ptr %r, ptr %s, double %t) {
store i64 %x0, ptr %q
%q.1 = getelementptr i64, ptr %q, i64 1
store i64 %x1, ptr %q.1

ret void
}

Expand Down