Skip to content

Commit 70aeb89

Browse files
authored
Calculate KnownBits from Metadata correctly for vector loads (#128908)
Calculate KnownBits correctly from metadata for vector loads. --------- Signed-off-by: John Lu <[email protected]>
1 parent 23bf98e commit 70aeb89

File tree

3 files changed

+132
-28
lines changed

3 files changed

+132
-28
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,14 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
690690
assert(NVT.getSizeInBits() == VT.getSizeInBits() &&
691691
"Can only promote loads to same size type");
692692

693+
// If the range metadata type does not match the legalized memory
694+
// operation type, remove the range metadata.
695+
if (const MDNode *MD = LD->getRanges()) {
696+
ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
697+
if (Lower->getBitWidth() != NVT.getScalarSizeInBits() ||
698+
!NVT.isInteger())
699+
LD->getMemOperand()->clearRanges();
700+
}
693701
SDValue Res = DAG.getLoad(NVT, dl, Chain, Ptr, LD->getMemOperand());
694702
RVal = DAG.getNode(ISD::BITCAST, dl, VT, Res);
695703
RChain = Res.getValue(1);

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4004,39 +4004,20 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
40044004
}
40054005
}
40064006
} else if (Op.getResNo() == 0) {
4007-
KnownBits Known0(!LD->getMemoryVT().isScalableVT()
4008-
? LD->getMemoryVT().getFixedSizeInBits()
4009-
: BitWidth);
4010-
EVT VT = Op.getValueType();
4011-
// Fill in any known bits from range information. There are 3 types being
4012-
// used. The results VT (same vector elt size as BitWidth), the loaded
4013-
// MemoryVT (which may or may not be vector) and the range VTs original
4014-
// type. The range metadata needs the full range (i.e.
4015-
// MemoryVT().getSizeInBits()), which is truncated to the correct elt size
4016-
// if it is known. These are then extended to the original VT sizes below.
4017-
if (const MDNode *MD = LD->getRanges()) {
4018-
computeKnownBitsFromRangeMetadata(*MD, Known0);
4019-
if (VT.isVector()) {
4020-
// Handle truncation to the first demanded element.
4021-
// TODO: Figure out which demanded elements are covered
4022-
if (DemandedElts != 1 || !getDataLayout().isLittleEndian())
4023-
break;
4024-
Known0 = Known0.trunc(BitWidth);
4025-
}
4026-
}
4027-
4028-
if (LD->getMemoryVT().isVector())
4029-
Known0 = Known0.trunc(LD->getMemoryVT().getScalarSizeInBits());
4007+
unsigned ScalarMemorySize = LD->getMemoryVT().getScalarSizeInBits();
4008+
KnownBits KnownScalarMemory(ScalarMemorySize);
4009+
if (const MDNode *MD = LD->getRanges())
4010+
computeKnownBitsFromRangeMetadata(*MD, KnownScalarMemory);
40304011

4031-
// Extend the Known bits from memory to the size of the result.
4012+
// Extend the Known bits from memory to the size of the scalar result.
40324013
if (ISD::isZEXTLoad(Op.getNode()))
4033-
Known = Known0.zext(BitWidth);
4014+
Known = KnownScalarMemory.zext(BitWidth);
40344015
else if (ISD::isSEXTLoad(Op.getNode()))
4035-
Known = Known0.sext(BitWidth);
4016+
Known = KnownScalarMemory.sext(BitWidth);
40364017
else if (ISD::isEXTLoad(Op.getNode()))
4037-
Known = Known0.anyext(BitWidth);
4018+
Known = KnownScalarMemory.anyext(BitWidth);
40384019
else
4039-
Known = Known0;
4020+
Known = KnownScalarMemory;
40404021
assert(Known.getBitWidth() == BitWidth);
40414022
return Known;
40424023
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; Ensure that range metadata is handled correctly for vector loads.
3+
; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 < %s | FileCheck %s
4+
5+
define <2 x i16> @test_add2x16(ptr %a_ptr, ptr %b_ptr) {
6+
; CHECK-LABEL: test_add2x16:
7+
; CHECK: ; %bb.0:
8+
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
9+
; CHECK-NEXT: v_mov_b32_e32 v0, 0x300030
10+
; CHECK-NEXT: s_setpc_b64 s[30:31]
11+
%a = load <2 x i16>, ptr %a_ptr, !range !0, !noundef !{}
12+
%b = load <2 x i16>, ptr %b_ptr, !range !1, !noundef !{}
13+
%result = add <2 x i16> %a, %b
14+
ret <2 x i16> %result
15+
}
16+
17+
define <2 x i32> @test_add2x32(ptr %a_ptr, ptr %b_ptr) {
18+
; CHECK-LABEL: test_add2x32:
19+
; CHECK: ; %bb.0:
20+
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
21+
; CHECK-NEXT: flat_load_dword v4, v[2:3]
22+
; CHECK-NEXT: flat_load_dword v5, v[0:1]
23+
; CHECK-NEXT: v_mov_b32_e32 v1, 48
24+
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
25+
; CHECK-NEXT: v_or_b32_e32 v0, v5, v4
26+
; CHECK-NEXT: s_setpc_b64 s[30:31]
27+
%a = load <2 x i32>, ptr %a_ptr, !range !2, !noundef !{}
28+
%b = load <2 x i32>, ptr %b_ptr, !range !3, !noundef !{}
29+
%result = add <2 x i32> %a, %b
30+
ret <2 x i32> %result
31+
}
32+
33+
define <2 x i64> @test_add2x64(ptr %a_ptr, ptr %b_ptr) {
34+
; CHECK-LABEL: test_add2x64:
35+
; CHECK: ; %bb.0:
36+
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
37+
; CHECK-NEXT: flat_load_dwordx4 v[4:7], v[0:1]
38+
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
39+
; CHECK-NEXT: flat_load_dwordx4 v[6:9], v[2:3]
40+
; CHECK-NEXT: ; kill: killed $vgpr2 killed $vgpr3
41+
; CHECK-NEXT: ; kill: killed $vgpr0 killed $vgpr1
42+
; CHECK-NEXT: v_mov_b32_e32 v2, 48
43+
; CHECK-NEXT: v_mov_b32_e32 v3, 0
44+
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
45+
; CHECK-NEXT: v_or_b32_e32 v1, v5, v7
46+
; CHECK-NEXT: v_or_b32_e32 v0, v4, v6
47+
; CHECK-NEXT: s_setpc_b64 s[30:31]
48+
%a = load <2 x i64>, ptr %a_ptr, !range !4, !noundef !{}
49+
%b = load <2 x i64>, ptr %b_ptr, !range !5, !noundef !{}
50+
%result = add <2 x i64> %a, %b
51+
ret <2 x i64> %result
52+
}
53+
54+
define <3 x i16> @test_add3x16(ptr %a_ptr, ptr %b_ptr) {
55+
; CHECK-LABEL: test_add3x16:
56+
; CHECK: ; %bb.0:
57+
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
58+
; CHECK-NEXT: flat_load_dwordx2 v[4:5], v[0:1]
59+
; CHECK-NEXT: flat_load_dwordx2 v[6:7], v[2:3]
60+
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
61+
; CHECK-NEXT: v_or_b32_e32 v1, v5, v7
62+
; CHECK-NEXT: v_or_b32_e32 v0, v4, v6
63+
; CHECK-NEXT: s_setpc_b64 s[30:31]
64+
%a = load <3 x i16>, ptr %a_ptr, !range !0, !noundef !{}
65+
%b = load <3 x i16>, ptr %b_ptr, !range !1, !noundef !{}
66+
%result = add <3 x i16> %a, %b
67+
ret <3 x i16> %result
68+
}
69+
70+
define <3 x i32> @test_add3x32(ptr %a_ptr, ptr %b_ptr) {
71+
; CHECK-LABEL: test_add3x32:
72+
; CHECK: ; %bb.0:
73+
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
74+
; CHECK-NEXT: flat_load_dword v4, v[2:3]
75+
; CHECK-NEXT: flat_load_dword v5, v[0:1]
76+
; CHECK-NEXT: v_mov_b32_e32 v1, 48
77+
; CHECK-NEXT: v_mov_b32_e32 v2, 48
78+
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
79+
; CHECK-NEXT: v_or_b32_e32 v0, v5, v4
80+
; CHECK-NEXT: s_setpc_b64 s[30:31]
81+
%a = load <3 x i32>, ptr %a_ptr, !range !2, !noundef !{}
82+
%b = load <3 x i32>, ptr %b_ptr, !range !3, !noundef !{}
83+
%result = add <3 x i32> %a, %b
84+
ret <3 x i32> %result
85+
}
86+
87+
define <3 x i64> @test_add3x64(ptr %a_ptr, ptr %b_ptr) {
88+
; CHECK-LABEL: test_add3x64:
89+
; CHECK: ; %bb.0:
90+
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
91+
; CHECK-NEXT: flat_load_dwordx4 v[4:7], v[0:1]
92+
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
93+
; CHECK-NEXT: flat_load_dwordx4 v[6:9], v[2:3]
94+
; CHECK-NEXT: ; kill: killed $vgpr2 killed $vgpr3
95+
; CHECK-NEXT: ; kill: killed $vgpr0 killed $vgpr1
96+
; CHECK-NEXT: v_mov_b32_e32 v2, 48
97+
; CHECK-NEXT: v_mov_b32_e32 v3, 0
98+
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
99+
; CHECK-NEXT: v_or_b32_e32 v1, v5, v7
100+
; CHECK-NEXT: v_or_b32_e32 v0, v4, v6
101+
; CHECK-NEXT: v_mov_b32_e32 v4, 48
102+
; CHECK-NEXT: v_mov_b32_e32 v5, 0
103+
; CHECK-NEXT: s_setpc_b64 s[30:31]
104+
%a = load <3 x i64>, ptr %a_ptr, !range !4, !noundef !{}
105+
%b = load <3 x i64>, ptr %b_ptr, !range !5, !noundef !{}
106+
%result = add <3 x i64> %a, %b
107+
ret <3 x i64> %result
108+
}
109+
110+
!0 = !{i16 16, i16 17 }
111+
!1 = !{i16 32, i16 33 }
112+
!2 = !{i32 16, i32 17 }
113+
!3 = !{i32 32, i32 33 }
114+
!4 = !{i64 16, i64 17 }
115+
!5 = !{i64 32, i64 33 }

0 commit comments

Comments
 (0)