Skip to content

Commit f32ebab

Browse files
authored
[NVPTX] Improve folding to mad with immediate 1 (#93628)
Extend NVPTX DAG combining logic to distribute a mul instruction across an add of 1 into a mad where possible. In addition, add support for transposing a mul through a select with an option of 1, if that would allow further mul folding.
1 parent 10436ae commit f32ebab

File tree

2 files changed

+228
-6
lines changed

2 files changed

+228
-6
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5614,17 +5614,103 @@ static SDValue TryMULWIDECombine(SDNode *N,
56145614
return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
56155615
}
56165616

5617+
static bool isConstOne(const SDValue &Operand) {
5618+
const auto *Const = dyn_cast<ConstantSDNode>(Operand);
5619+
return Const && Const->getZExtValue() == 1;
5620+
}
5621+
5622+
/// Matches an ISD::ADD in which one operand is the constant 1.
///
/// \returns the operand of \p Add that is paired with the constant 1, or an
/// empty SDValue when \p Add is not an ADD or neither operand is 1. Operand 0
/// is tested for the constant first.
static SDValue matchMADConstOnePattern(SDValue Add) {
  if (Add->getOpcode() == ISD::ADD) {
    SDValue LHS = Add->getOperand(0);
    SDValue RHS = Add->getOperand(1);

    if (isConstOne(LHS))
      return RHS;
    if (isConstOne(RHS))
      return LHS;
  }

  // Not an add, or an add without a constant-one operand.
  return SDValue();
}
5634+
5635+
/// Folds (mul X, (add Y, 1)) into (NVPTXISD::IMAD X, Y, X).
///
/// \param X   the multiply operand that is not the add.
/// \param Add candidate add-of-one operand of the multiply.
/// \param VT  result type of the new IMAD node.
/// \param DL  debug location for the new node; taken by const reference so
///            the SDLoc is not copied on every call.
/// \param DCI supplies the SelectionDAG in which the node is created.
/// \returns the IMAD node, or an empty SDValue if \p Add does not match the
/// add-of-one pattern.
static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT,
                                  const SDLoc &DL,
                                  TargetLowering::DAGCombinerInfo &DCI) {
  // x * (y + 1) == x * y + x, which is a single multiply-add.
  if (SDValue Y = matchMADConstOnePattern(Add))
    return DCI.DAG.getNode(NVPTXISD::IMAD, DL, VT, X, Y, X);

  return SDValue();
}
5643+
5644+
/// Transposes a multiply through a select that has the constant 1 on one arm:
///   (mul X, (select Cond, 1, Y)) -> (select Cond, X, (mul X, Y))
///   (mul X, (select Cond, Y, 1)) -> (select Cond, (mul X, Y), X)
/// The rewrite is only performed when Y matches the add-of-one pattern.
///
/// \param X      the multiply operand that is not the select.
/// \param Select candidate ISD::SELECT operand of the multiply.
/// \param VT     result type of the new MUL and SELECT nodes.
/// \param DL     debug location for the new nodes; taken by const reference
///               so the SDLoc is not copied on every call.
/// \param DCI    supplies the SelectionDAG in which the nodes are created.
/// \returns the new SELECT node, or an empty SDValue if nothing matched.
static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
                                        const SDLoc &DL,
                                        TargetLowering::DAGCombinerInfo &DCI) {
  if (Select->getOpcode() != ISD::SELECT)
    return SDValue();

  SDValue Cond = Select->getOperand(0);
  SDValue TrueVal = Select->getOperand(1);
  SDValue FalseVal = Select->getOperand(2);

  // The true arm is checked first, so it wins if both arms are the constant 1.
  const bool OneIsTrueArm = isConstOne(TrueVal);
  if (!OneIsTrueArm && !isConstOne(FalseVal))
    return SDValue();

  SDValue Y = OneIsTrueArm ? FalseVal : TrueVal;

  // Do not combine if the resulting sequence is not obviously profitable:
  // the new multiply must itself be of the add-of-one form.
  if (!matchMADConstOnePattern(Y))
    return SDValue();

  SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);

  // The arm that held the constant 1 becomes X (x * 1 == x); the other arm
  // becomes the new multiply.
  if (OneIsTrueArm)
    return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond, X, NewMul);
  return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond, NewMul, X);
}
5672+
5673+
/// Tries the add-of-one and select-of-one multiply folds on the operand pair
/// (\p N0, \p N1) of the MUL node \p N, in both operand orders.
///
/// Only scalar i16, i32 and i64 values are handled. \returns the first
/// successful fold, or an empty SDValue if none applies.
static SDValue
PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
                              TargetLowering::DAGCombinerInfo &DCI) {
  const EVT VT = N0.getValueType();

  const bool IsHandledScalar =
      !VT.isVector() &&
      (VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64);
  if (!IsHandledScalar)
    return SDValue();

  SDLoc DL(N);

  // (mul x, (add y, 1)) -> (mad x, y, x)
  SDValue Folded = combineMADConstOne(N0, N1, VT, DL, DCI);
  if (!Folded)
    Folded = combineMADConstOne(N1, N0, VT, DL, DCI);

  // (mul x, (select y, 1)) -> (select (mul x, y), x)
  if (!Folded)
    Folded = combineMulSelectConstOne(N0, N1, VT, DL, DCI);
  if (!Folded)
    Folded = combineMulSelectConstOne(N1, N0, VT, DL, DCI);

  // Still empty here if no pattern matched.
  return Folded;
}
5700+
56175701
/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
/// \param N        the node being combined.
/// \param DCI      combiner state, forwarded to the helper combines.
/// \param OptLevel codegen optimization level; gates every combine here.
/// \returns the replacement value, or an empty SDValue if nothing matched.
/// NOTE: this span is a diff hunk. Lines after an "NNNN-" marker are the
/// removed body; lines after an "NNNN+" marker are the added body.
56185702
static SDValue PerformMULCombine(SDNode *N,
56195703
TargetLowering::DAGCombinerInfo &DCI,
56205704
CodeGenOptLevel OptLevel) {
5621-
if (OptLevel > CodeGenOptLevel::None) {
5622-
// Try mul.wide combining at OptLevel > 0
5623-
if (SDValue Ret = TryMULWIDECombine(N, DCI))
5624-
return Ret;
5625-
}
// Added body: no MUL combine is attempted at CodeGenOptLevel::None.
5705+
if (OptLevel == CodeGenOptLevel::None)
5706+
return SDValue();
56265707

5627-
return SDValue();
// Added body: mul.wide combining is tried first.
5708+
if (SDValue Ret = TryMULWIDECombine(N, DCI))
5709+
return Ret;
5710+
// Added body: otherwise hand both MUL operands to the operand-based folds.
5711+
SDValue N0 = N->getOperand(0);
5712+
SDValue N1 = N->getOperand(1);
5713+
return PerformMULCombineWithOperands(N, N0, N1, DCI);
56285714
}
56295715

56305716
/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 | FileCheck %s
3+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 | FileCheck %s
4+
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 | %ptxas-verify %}
5+
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 | %ptxas-verify %}
6+
7+
; (mul (add %n, 1), %m) with the add of 1 as the first mul operand; the CHECK
; lines expect a single mad.lo.s32 and no separate add or mul.
define i32 @test1(i32 %n, i32 %m) {
8+
;
9+
; CHECK-LABEL: test1(
10+
; CHECK: {
11+
; CHECK-NEXT: .reg .b32 %r<4>;
12+
; CHECK-EMPTY:
13+
; CHECK-NEXT: // %bb.0:
14+
; CHECK-NEXT: ld.param.u32 %r1, [test1_param_0];
15+
; CHECK-NEXT: ld.param.u32 %r2, [test1_param_1];
16+
; CHECK-NEXT: mad.lo.s32 %r3, %r2, %r1, %r2;
17+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r3;
18+
; CHECK-NEXT: ret;
19+
%add = add i32 %n, 1
20+
%mul = mul i32 %add, %m
21+
ret i32 %mul
22+
}
23+
24+
; (mul %m, (add %n, 1)) with the add of 1 as the second mul operand; the CHECK
; lines expect a single mad.lo.s32 and no separate add or mul.
define i32 @test1_rev(i32 %n, i32 %m) {
25+
;
26+
; CHECK-LABEL: test1_rev(
27+
; CHECK: {
28+
; CHECK-NEXT: .reg .b32 %r<4>;
29+
; CHECK-EMPTY:
30+
; CHECK-NEXT: // %bb.0:
31+
; CHECK-NEXT: ld.param.u32 %r1, [test1_rev_param_0];
32+
; CHECK-NEXT: ld.param.u32 %r2, [test1_rev_param_1];
33+
; CHECK-NEXT: mad.lo.s32 %r3, %r2, %r1, %r2;
34+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r3;
35+
; CHECK-NEXT: ret;
36+
%add = add i32 %n, 1
37+
%mul = mul i32 %m, %add
38+
ret i32 %mul
39+
}
40+
41+
; Transpose (mul (select)) if it can then be folded to mad.
; Here the constant 1 is the select's true value and the select is the first
; mul operand; the CHECK lines expect mad.lo.s32 followed by selp.
42+
define i32 @test2(i32 %n, i32 %m, i32 %s) {
43+
;
44+
; CHECK-LABEL: test2(
45+
; CHECK: {
46+
; CHECK-NEXT: .reg .pred %p<2>;
47+
; CHECK-NEXT: .reg .b32 %r<6>;
48+
; CHECK-EMPTY:
49+
; CHECK-NEXT: // %bb.0:
50+
; CHECK-NEXT: ld.param.u32 %r1, [test2_param_0];
51+
; CHECK-NEXT: ld.param.u32 %r2, [test2_param_1];
52+
; CHECK-NEXT: ld.param.u32 %r3, [test2_param_2];
53+
; CHECK-NEXT: setp.lt.s32 %p1, %r3, 1;
54+
; CHECK-NEXT: mad.lo.s32 %r4, %r2, %r1, %r2;
55+
; CHECK-NEXT: selp.b32 %r5, %r2, %r4, %p1;
56+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
57+
; CHECK-NEXT: ret;
58+
%add = add i32 %n, 1
59+
%cond = icmp slt i32 %s, 1
60+
%sel = select i1 %cond, i32 1, i32 %add
61+
%mul = mul i32 %sel, %m
62+
ret i32 %mul
63+
}
64+
65+
;; Transpose (mul (select)) if it can then be folded to mad.
;; Here the constant 1 is the select's false value and the select is the
;; first mul operand; the CHECK lines expect mad.lo.s32 followed by selp.
66+
define i32 @test2_rev1(i32 %n, i32 %m, i32 %s) {
67+
;
68+
; CHECK-LABEL: test2_rev1(
69+
; CHECK: {
70+
; CHECK-NEXT: .reg .pred %p<2>;
71+
; CHECK-NEXT: .reg .b32 %r<6>;
72+
; CHECK-EMPTY:
73+
; CHECK-NEXT: // %bb.0:
74+
; CHECK-NEXT: ld.param.u32 %r1, [test2_rev1_param_0];
75+
; CHECK-NEXT: ld.param.u32 %r2, [test2_rev1_param_1];
76+
; CHECK-NEXT: ld.param.u32 %r3, [test2_rev1_param_2];
77+
; CHECK-NEXT: setp.lt.s32 %p1, %r3, 1;
78+
; CHECK-NEXT: mad.lo.s32 %r4, %r2, %r1, %r2;
79+
; CHECK-NEXT: selp.b32 %r5, %r4, %r2, %p1;
80+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
81+
; CHECK-NEXT: ret;
82+
%add = add i32 %n, 1
83+
%cond = icmp slt i32 %s, 1
84+
%sel = select i1 %cond, i32 %add, i32 1
85+
%mul = mul i32 %sel, %m
86+
ret i32 %mul
87+
}
88+
89+
;; Transpose (mul (select)) if it can then be folded to mad.
;; Here the constant 1 is the select's false value and the select is the
;; second mul operand; the CHECK lines expect mad.lo.s32 followed by selp.
90+
define i32 @test2_rev2(i32 %n, i32 %m, i32 %s) {
91+
;
92+
; CHECK-LABEL: test2_rev2(
93+
; CHECK: {
94+
; CHECK-NEXT: .reg .pred %p<2>;
95+
; CHECK-NEXT: .reg .b32 %r<6>;
96+
; CHECK-EMPTY:
97+
; CHECK-NEXT: // %bb.0:
98+
; CHECK-NEXT: ld.param.u32 %r1, [test2_rev2_param_0];
99+
; CHECK-NEXT: ld.param.u32 %r2, [test2_rev2_param_1];
100+
; CHECK-NEXT: ld.param.u32 %r3, [test2_rev2_param_2];
101+
; CHECK-NEXT: setp.lt.s32 %p1, %r3, 1;
102+
; CHECK-NEXT: mad.lo.s32 %r4, %r2, %r1, %r2;
103+
; CHECK-NEXT: selp.b32 %r5, %r4, %r2, %p1;
104+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
105+
; CHECK-NEXT: ret;
106+
%add = add i32 %n, 1
107+
%cond = icmp slt i32 %s, 1
108+
%sel = select i1 %cond, i32 %add, i32 1
109+
%mul = mul i32 %m, %sel
110+
ret i32 %mul
111+
}
112+
113+
;; Leave (mul (select)) intact if transposing it is not profitable.
;; Here the non-constant select arm is (add %n, 3), not an add of 1, so the
;; CHECK lines expect the original add / selp / mul.lo sequence and no mad.
114+
define i32 @test3(i32 %n, i32 %m, i32 %s) {
115+
;
116+
; CHECK-LABEL: test3(
117+
; CHECK: {
118+
; CHECK-NEXT: .reg .pred %p<2>;
119+
; CHECK-NEXT: .reg .b32 %r<7>;
120+
; CHECK-EMPTY:
121+
; CHECK-NEXT: // %bb.0:
122+
; CHECK-NEXT: ld.param.u32 %r1, [test3_param_0];
123+
; CHECK-NEXT: add.s32 %r2, %r1, 3;
124+
; CHECK-NEXT: ld.param.u32 %r3, [test3_param_1];
125+
; CHECK-NEXT: ld.param.u32 %r4, [test3_param_2];
126+
; CHECK-NEXT: setp.lt.s32 %p1, %r4, 1;
127+
; CHECK-NEXT: selp.b32 %r5, 1, %r2, %p1;
128+
; CHECK-NEXT: mul.lo.s32 %r6, %r5, %r3;
129+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r6;
130+
; CHECK-NEXT: ret;
131+
%add = add i32 %n, 3
132+
%cond = icmp slt i32 %s, 1
133+
%sel = select i1 %cond, i32 1, i32 %add
134+
%mul = mul i32 %sel, %m
135+
ret i32 %mul
136+
}

0 commit comments

Comments
 (0)