Skip to content

Commit c11c2fe

Browse files
authored
[NVPTX] Lower i1 select with logical ops in the general case (#135868)
Update i1 select lowering to use an expansion based on logical ops, unless the selected operands are truncations. This can improve generated code quality by exposing additional potential optimizations.
1 parent e016a90 commit c11c2fe

File tree

5 files changed

+325
-177
lines changed

5 files changed

+325
-177
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2850,6 +2850,40 @@ static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG,
28502850
return DAG.getSelect(DL, Ty, IsInf, X, Sub);
28512851
}
28522852

2853+
static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
2854+
assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
2855+
2856+
SDValue Cond = Op->getOperand(0);
2857+
SDValue TrueVal = Op->getOperand(1);
2858+
SDValue FalseVal = Op->getOperand(2);
2859+
SDLoc DL(Op);
2860+
2861+
// If both operands are truncated, we push the select through the truncates.
2862+
if (TrueVal.getOpcode() == ISD::TRUNCATE &&
2863+
FalseVal.getOpcode() == ISD::TRUNCATE) {
2864+
TrueVal = TrueVal.getOperand(0);
2865+
FalseVal = FalseVal.getOperand(0);
2866+
2867+
EVT VT = TrueVal.getSimpleValueType().bitsLE(FalseVal.getSimpleValueType())
2868+
? TrueVal.getValueType()
2869+
: FalseVal.getValueType();
2870+
TrueVal = DAG.getAnyExtOrTrunc(TrueVal, DL, VT);
2871+
FalseVal = DAG.getAnyExtOrTrunc(FalseVal, DL, VT);
2872+
SDValue Select = DAG.getSelect(DL, VT, Cond, TrueVal, FalseVal);
2873+
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
2874+
}
2875+
2876+
// Otherwise, expand the select into a series of logical operations. These
2877+
// often can be folded into other operations either by us or ptxas.
2878+
TrueVal = DAG.getFreeze(TrueVal);
2879+
FalseVal = DAG.getFreeze(FalseVal);
2880+
SDValue And1 = DAG.getNode(ISD::AND, DL, MVT::i1, Cond, TrueVal);
2881+
SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
2882+
SDValue And2 = DAG.getNode(ISD::AND, DL, MVT::i1, NotCond, FalseVal);
2883+
SDValue Or = DAG.getNode(ISD::OR, DL, MVT::i1, And1, And2);
2884+
return Or;
2885+
}
2886+
28532887
SDValue
28542888
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28552889
switch (Op.getOpcode()) {
@@ -2889,7 +2923,7 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28892923
case ISD::SRL_PARTS:
28902924
return LowerShiftRightParts(Op, DAG);
28912925
case ISD::SELECT:
2892-
return LowerSelect(Op, DAG);
2926+
return lowerSELECT(Op, DAG);
28932927
case ISD::FROUND:
28942928
return LowerFROUND(Op, DAG);
28952929
case ISD::FCOPYSIGN:
@@ -3056,22 +3090,6 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
30563090
MachinePointerInfo(SV));
30573091
}
30583092

3059-
SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
3060-
SDValue Op0 = Op->getOperand(0);
3061-
SDValue Op1 = Op->getOperand(1);
3062-
SDValue Op2 = Op->getOperand(2);
3063-
SDLoc DL(Op.getNode());
3064-
3065-
assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
3066-
3067-
Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
3068-
Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
3069-
SDValue Select = DAG.getNode(ISD::SELECT, DL, MVT::i32, Op0, Op1, Op2);
3070-
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
3071-
3072-
return Trunc;
3073-
}
3074-
30753093
SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
30763094
if (Op.getValueType() == MVT::i1)
30773095
return LowerLOADi1(Op, DAG);

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,6 @@ class NVPTXTargetLowering : public TargetLowering {
324324
SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
325325
SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
326326

327-
SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;
328-
329327
SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
330328

331329
SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/NVPTX/bug22246.ll

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,29 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
12
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 | FileCheck %s
23
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 | %ptxas-verify %}
34

45
target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
56
target triple = "nvptx64-nvidia-cuda"
67

7-
; CHECK-LABEL: _Z3foobbbPb
88
define void @_Z3foobbbPb(i1 zeroext %p1, i1 zeroext %p2, i1 zeroext %p3, ptr nocapture %output) {
9+
; CHECK-LABEL: _Z3foobbbPb(
10+
; CHECK: {
11+
; CHECK-NEXT: .reg .pred %p<2>;
12+
; CHECK-NEXT: .reg .b16 %rs<7>;
13+
; CHECK-NEXT: .reg .b64 %rd<2>;
14+
; CHECK-EMPTY:
15+
; CHECK-NEXT: // %bb.0: // %entry
16+
; CHECK-NEXT: ld.param.u8 %rs1, [_Z3foobbbPb_param_0];
17+
; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
18+
; CHECK-NEXT: setp.ne.b16 %p1, %rs2, 0;
19+
; CHECK-NEXT: ld.param.u8 %rs3, [_Z3foobbbPb_param_1];
20+
; CHECK-NEXT: ld.param.u8 %rs4, [_Z3foobbbPb_param_2];
21+
; CHECK-NEXT: selp.b16 %rs5, %rs3, %rs4, %p1;
22+
; CHECK-NEXT: and.b16 %rs6, %rs5, 1;
23+
; CHECK-NEXT: ld.param.u64 %rd1, [_Z3foobbbPb_param_3];
24+
; CHECK-NEXT: st.u8 [%rd1], %rs6;
25+
; CHECK-NEXT: ret;
926
entry:
10-
; CHECK: selp.b32 %r{{[0-9]+}}, %r{{[0-9]+}}, %r{{[0-9]+}}, %p{{[0-9]+}}
1127
%.sink.v = select i1 %p1, i1 %p2, i1 %p3
1228
%frombool5 = zext i1 %.sink.v to i8
1329
store i8 %frombool5, ptr %output, align 1

llvm/test/CodeGen/NVPTX/i1-select.ll

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 | %ptxas-verify %}
4+
5+
target triple = "nvptx-nvidia-cuda"
6+
7+
define i32 @test_select_i1_trunc(i32 %a, i32 %b, i32 %c, i32 %true, i32 %false) {
8+
; CHECK-LABEL: test_select_i1_trunc(
9+
; CHECK: {
10+
; CHECK-NEXT: .reg .pred %p<3>;
11+
; CHECK-NEXT: .reg .b32 %r<10>;
12+
; CHECK-EMPTY:
13+
; CHECK-NEXT: // %bb.0:
14+
; CHECK-NEXT: ld.param.u32 %r1, [test_select_i1_trunc_param_0];
15+
; CHECK-NEXT: and.b32 %r2, %r1, 1;
16+
; CHECK-NEXT: setp.ne.b32 %p1, %r2, 0;
17+
; CHECK-NEXT: ld.param.u32 %r3, [test_select_i1_trunc_param_1];
18+
; CHECK-NEXT: ld.param.u32 %r4, [test_select_i1_trunc_param_2];
19+
; CHECK-NEXT: ld.param.u32 %r5, [test_select_i1_trunc_param_3];
20+
; CHECK-NEXT: selp.b32 %r6, %r3, %r4, %p1;
21+
; CHECK-NEXT: and.b32 %r7, %r6, 1;
22+
; CHECK-NEXT: setp.ne.b32 %p2, %r7, 0;
23+
; CHECK-NEXT: ld.param.u32 %r8, [test_select_i1_trunc_param_4];
24+
; CHECK-NEXT: selp.b32 %r9, %r5, %r8, %p2;
25+
; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
26+
; CHECK-NEXT: ret;
27+
%a_trunc = trunc i32 %a to i1
28+
%b_trunc = trunc i32 %b to i1
29+
%c_trunc = trunc i32 %c to i1
30+
%select_i1 = select i1 %a_trunc, i1 %b_trunc, i1 %c_trunc
31+
%select_ret = select i1 %select_i1, i32 %true, i32 %false
32+
ret i32 %select_ret
33+
}
34+
35+
define i32 @test_select_i1_trunc_2(i64 %a, i16 %b, i32 %c, i32 %true, i32 %false) {
36+
; CHECK-LABEL: test_select_i1_trunc_2(
37+
; CHECK: {
38+
; CHECK-NEXT: .reg .pred %p<3>;
39+
; CHECK-NEXT: .reg .b16 %rs<5>;
40+
; CHECK-NEXT: .reg .b32 %r<4>;
41+
; CHECK-NEXT: .reg .b64 %rd<3>;
42+
; CHECK-EMPTY:
43+
; CHECK-NEXT: // %bb.0:
44+
; CHECK-NEXT: ld.param.u64 %rd1, [test_select_i1_trunc_2_param_0];
45+
; CHECK-NEXT: and.b64 %rd2, %rd1, 1;
46+
; CHECK-NEXT: setp.ne.b64 %p1, %rd2, 0;
47+
; CHECK-NEXT: ld.param.u16 %rs1, [test_select_i1_trunc_2_param_1];
48+
; CHECK-NEXT: ld.param.u16 %rs2, [test_select_i1_trunc_2_param_2];
49+
; CHECK-NEXT: ld.param.u32 %r1, [test_select_i1_trunc_2_param_3];
50+
; CHECK-NEXT: selp.b16 %rs3, %rs1, %rs2, %p1;
51+
; CHECK-NEXT: and.b16 %rs4, %rs3, 1;
52+
; CHECK-NEXT: setp.ne.b16 %p2, %rs4, 0;
53+
; CHECK-NEXT: ld.param.u32 %r2, [test_select_i1_trunc_2_param_4];
54+
; CHECK-NEXT: selp.b32 %r3, %r1, %r2, %p2;
55+
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
56+
; CHECK-NEXT: ret;
57+
%a_trunc = trunc i64 %a to i1
58+
%b_trunc = trunc i16 %b to i1
59+
%c_trunc = trunc i32 %c to i1
60+
%select_i1 = select i1 %a_trunc, i1 %b_trunc, i1 %c_trunc
61+
%select_ret = select i1 %select_i1, i32 %true, i32 %false
62+
ret i32 %select_ret
63+
}
64+
65+
define i32 @test_select_i1_basic(i32 %v1, i32 %v2, i32 %v3, i32 %true, i32 %false) {
66+
; CHECK-LABEL: test_select_i1_basic(
67+
; CHECK: {
68+
; CHECK-NEXT: .reg .pred %p<4>;
69+
; CHECK-NEXT: .reg .b32 %r<12>;
70+
; CHECK-EMPTY:
71+
; CHECK-NEXT: // %bb.0:
72+
; CHECK-NEXT: ld.param.u32 %r1, [test_select_i1_basic_param_0];
73+
; CHECK-NEXT: ld.param.u32 %r2, [test_select_i1_basic_param_1];
74+
; CHECK-NEXT: or.b32 %r4, %r1, %r2;
75+
; CHECK-NEXT: setp.ne.s32 %p1, %r1, 0;
76+
; CHECK-NEXT: ld.param.u32 %r5, [test_select_i1_basic_param_2];
77+
; CHECK-NEXT: setp.eq.s32 %p2, %r5, 0;
78+
; CHECK-NEXT: ld.param.u32 %r7, [test_select_i1_basic_param_3];
79+
; CHECK-NEXT: setp.eq.s32 %p3, %r4, 0;
80+
; CHECK-NEXT: ld.param.u32 %r8, [test_select_i1_basic_param_4];
81+
; CHECK-NEXT: selp.b32 %r9, %r7, %r8, %p2;
82+
; CHECK-NEXT: selp.b32 %r10, %r9, %r8, %p1;
83+
; CHECK-NEXT: selp.b32 %r11, %r7, %r10, %p3;
84+
; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
85+
; CHECK-NEXT: ret;
86+
%b1 = icmp eq i32 %v1, 0
87+
%b2 = icmp eq i32 %v2, 0
88+
%b3 = icmp eq i32 %v3, 0
89+
%select_i1 = select i1 %b1, i1 %b2, i1 %b3
90+
%select_ret = select i1 %select_i1, i32 %true, i32 %false
91+
ret i32 %select_ret
92+
}
93+
94+
define i32 @test_select_i1_basic_folding(i32 %v1, i32 %v2, i32 %v3, i32 %true, i32 %false) {
95+
; CHECK-LABEL: test_select_i1_basic_folding(
96+
; CHECK: {
97+
; CHECK-NEXT: .reg .pred %p<13>;
98+
; CHECK-NEXT: .reg .b32 %r<7>;
99+
; CHECK-EMPTY:
100+
; CHECK-NEXT: // %bb.0:
101+
; CHECK-NEXT: ld.param.u32 %r1, [test_select_i1_basic_folding_param_0];
102+
; CHECK-NEXT: setp.eq.s32 %p1, %r1, 0;
103+
; CHECK-NEXT: ld.param.u32 %r2, [test_select_i1_basic_folding_param_1];
104+
; CHECK-NEXT: setp.ne.s32 %p2, %r2, 0;
105+
; CHECK-NEXT: setp.eq.s32 %p3, %r2, 0;
106+
; CHECK-NEXT: ld.param.u32 %r3, [test_select_i1_basic_folding_param_2];
107+
; CHECK-NEXT: setp.eq.s32 %p4, %r3, 0;
108+
; CHECK-NEXT: ld.param.u32 %r4, [test_select_i1_basic_folding_param_3];
109+
; CHECK-NEXT: xor.pred %p6, %p1, %p3;
110+
; CHECK-NEXT: ld.param.u32 %r5, [test_select_i1_basic_folding_param_4];
111+
; CHECK-NEXT: and.pred %p7, %p6, %p4;
112+
; CHECK-NEXT: and.pred %p9, %p2, %p4;
113+
; CHECK-NEXT: and.pred %p10, %p3, %p7;
114+
; CHECK-NEXT: or.pred %p11, %p10, %p9;
115+
; CHECK-NEXT: xor.pred %p12, %p11, %p3;
116+
; CHECK-NEXT: selp.b32 %r6, %r4, %r5, %p12;
117+
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
118+
; CHECK-NEXT: ret;
119+
%b1 = icmp eq i32 %v1, 0
120+
%b2 = icmp eq i32 %v2, 0
121+
%b3 = icmp eq i32 %v3, 0
122+
%b4 = xor i1 %b1, %b2
123+
%b5 = and i1 %b4, %b3
124+
%select_i1 = select i1 %b2, i1 %b5, i1 %b3
125+
%b6 = xor i1 %select_i1, %b2
126+
%select_ret = select i1 %b6, i32 %true, i32 %false
127+
ret i32 %select_ret
128+
}

0 commit comments

Comments
 (0)