Skip to content

Commit 6115b1b

Browse files
authored
[NVPTX] Add lowering for bitcasts float<->v4i8 (#69960)
... and move the bitcast-from-constant patterns for integer-based types into a better-suited location. This solves the mystery of why we sometimes used `mov.u32` and sometimes `mov.b32` for loading constants: now they should all use `mov.b32`.
1 parent 853fb0a commit 6115b1b

File tree

6 files changed

+64
-33
lines changed

6 files changed

+64
-33
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3105,9 +3105,7 @@ def BITCONVERT_32_F2I : F_BITCONVERT<"32", f32, i32>;
31053105
def BITCONVERT_64_I2F : F_BITCONVERT<"64", i64, f64>;
31063106
def BITCONVERT_64_F2I : F_BITCONVERT<"64", f64, i64>;
31073107

3108-
foreach vt = [v2f16, v2bf16, v2i16] in {
3109-
def: Pat<(vt (bitconvert (i32 UInt32Const:$a))),
3110-
(IMOVB32ri UInt32Const:$a)>;
3108+
foreach vt = [v2f16, v2bf16, v2i16, v4i8] in {
31113109
def: Pat<(vt (bitconvert (f32 Float32Regs:$a))),
31123110
(BITCONVERT_32_F2I Float32Regs:$a)>;
31133111
def: Pat<(f32 (bitconvert (vt Int32Regs:$a))),
@@ -3123,6 +3121,8 @@ def: Pat<(i16 (bitconvert (vt Int16Regs:$a))),
31233121
}
31243122

31253123
foreach ta = [v2f16, v2bf16, v2i16, v4i8, i32] in {
3124+
def: Pat<(ta (bitconvert (i32 UInt32Const:$a))),
3125+
(IMOVB32ri UInt32Const:$a)>;
31263126
foreach tb = [v2f16, v2bf16, v2i16, v4i8, i32] in {
31273127
if !ne(ta, tb) then {
31283128
def: Pat<(ta (bitconvert (tb Int32Regs:$a))),

llvm/test/CodeGen/NVPTX/access-non-generic.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ define void @nested_const_expr() {
107107
; PTX-LABEL: nested_const_expr(
108108
; store 1 to bitcast(gep(addrspacecast(array), 0, 1))
109109
store i32 1, ptr getelementptr ([10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i64 0, i64 1), align 4
110-
; PTX: mov.u32 %r1, 1;
110+
; PTX: mov.b32 %r1, 1;
111111
; PTX-NEXT: st.shared.u32 [array+4], %r1;
112112
ret void
113113
}

llvm/test/CodeGen/NVPTX/i8x4-instructions.ll

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
1414
define <4 x i8> @test_ret_const() #0 {
1515
; CHECK-LABEL: test_ret_const(
1616
; CHECK: {
17-
; CHECK-NEXT: .reg .b32 %r<3>;
17+
; CHECK-NEXT: .reg .b32 %r<2>;
1818
; CHECK-EMPTY:
1919
; CHECK-NEXT: // %bb.0:
20-
; CHECK-NEXT: mov.u32 %r1, -66911489;
20+
; CHECK-NEXT: mov.b32 %r1, -66911489;
2121
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r1;
2222
; CHECK-NEXT: ret;
2323
ret <4 x i8> <i8 -1, i8 2, i8 3, i8 -4>
@@ -1110,40 +1110,71 @@ define <4 x i64> @test_zext_2xi64(<4 x i8> %a) #0 {
11101110
ret <4 x i64> %r
11111111
}
11121112

1113-
define <4 x i8> @test_bitcast_i32_to_2xi8(i32 %a) #0 {
1114-
; CHECK-LABEL: test_bitcast_i32_to_2xi8(
1113+
define <4 x i8> @test_bitcast_i32_to_4xi8(i32 %a) #0 {
1114+
; CHECK-LABEL: test_bitcast_i32_to_4xi8(
11151115
; CHECK: {
11161116
; CHECK-NEXT: .reg .b32 %r<3>;
11171117
; CHECK-EMPTY:
11181118
; CHECK-NEXT: // %bb.0:
1119-
; CHECK-NEXT: ld.param.u32 %r1, [test_bitcast_i32_to_2xi8_param_0];
1119+
; CHECK-NEXT: ld.param.u32 %r1, [test_bitcast_i32_to_4xi8_param_0];
11201120
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r1;
11211121
; CHECK-NEXT: ret;
11221122
%r = bitcast i32 %a to <4 x i8>
11231123
ret <4 x i8> %r
11241124
}
11251125

1126-
define i32 @test_bitcast_2xi8_to_i32(<4 x i8> %a) #0 {
1127-
; CHECK-LABEL: test_bitcast_2xi8_to_i32(
1126+
define <4 x i8> @test_bitcast_float_to_4xi8(float %a) #0 {
1127+
; CHECK-LABEL: test_bitcast_float_to_4xi8(
1128+
; CHECK: {
1129+
; CHECK-NEXT: .reg .b32 %r<2>;
1130+
; CHECK-NEXT: .reg .f32 %f<2>;
1131+
; CHECK-EMPTY:
1132+
; CHECK-NEXT: // %bb.0:
1133+
; CHECK-NEXT: ld.param.f32 %f1, [test_bitcast_float_to_4xi8_param_0];
1134+
; CHECK-NEXT: mov.b32 %r1, %f1;
1135+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r1;
1136+
; CHECK-NEXT: ret;
1137+
%r = bitcast float %a to <4 x i8>
1138+
ret <4 x i8> %r
1139+
}
1140+
1141+
define i32 @test_bitcast_4xi8_to_i32(<4 x i8> %a) #0 {
1142+
; CHECK-LABEL: test_bitcast_4xi8_to_i32(
11281143
; CHECK: {
11291144
; CHECK-NEXT: .reg .b32 %r<3>;
11301145
; CHECK-EMPTY:
11311146
; CHECK-NEXT: // %bb.0:
1132-
; CHECK-NEXT: ld.param.u32 %r2, [test_bitcast_2xi8_to_i32_param_0];
1147+
; CHECK-NEXT: ld.param.u32 %r2, [test_bitcast_4xi8_to_i32_param_0];
11331148
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r2;
11341149
; CHECK-NEXT: ret;
11351150
%r = bitcast <4 x i8> %a to i32
11361151
ret i32 %r
11371152
}
11381153

1139-
define <2 x half> @test_bitcast_2xi8_to_2xhalf(i8 %a) #0 {
1140-
; CHECK-LABEL: test_bitcast_2xi8_to_2xhalf(
1154+
define float @test_bitcast_4xi8_to_float(<4 x i8> %a) #0 {
1155+
; CHECK-LABEL: test_bitcast_4xi8_to_float(
1156+
; CHECK: {
1157+
; CHECK-NEXT: .reg .b32 %r<3>;
1158+
; CHECK-NEXT: .reg .f32 %f<2>;
1159+
; CHECK-EMPTY:
1160+
; CHECK-NEXT: // %bb.0:
1161+
; CHECK-NEXT: ld.param.u32 %r2, [test_bitcast_4xi8_to_float_param_0];
1162+
; CHECK-NEXT: mov.b32 %f1, %r2;
1163+
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f1;
1164+
; CHECK-NEXT: ret;
1165+
%r = bitcast <4 x i8> %a to float
1166+
ret float %r
1167+
}
1168+
1169+
1170+
define <2 x half> @test_bitcast_4xi8_to_2xhalf(i8 %a) #0 {
1171+
; CHECK-LABEL: test_bitcast_4xi8_to_2xhalf(
11411172
; CHECK: {
11421173
; CHECK-NEXT: .reg .b16 %rs<2>;
11431174
; CHECK-NEXT: .reg .b32 %r<6>;
11441175
; CHECK-EMPTY:
11451176
; CHECK-NEXT: // %bb.0:
1146-
; CHECK-NEXT: ld.param.u8 %rs1, [test_bitcast_2xi8_to_2xhalf_param_0];
1177+
; CHECK-NEXT: ld.param.u8 %rs1, [test_bitcast_4xi8_to_2xhalf_param_0];
11471178
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
11481179
; CHECK-NEXT: bfi.b32 %r2, 5, %r1, 8, 8;
11491180
; CHECK-NEXT: bfi.b32 %r3, 6, %r2, 16, 8;
@@ -1207,14 +1238,14 @@ define <4 x i8> @test_insertelement(<4 x i8> %a, i8 %x) #0 {
12071238
ret <4 x i8> %i
12081239
}
12091240

1210-
define <4 x i8> @test_fptosi_2xhalf_to_2xi8(<4 x half> %a) #0 {
1211-
; CHECK-LABEL: test_fptosi_2xhalf_to_2xi8(
1241+
define <4 x i8> @test_fptosi_4xhalf_to_4xi8(<4 x half> %a) #0 {
1242+
; CHECK-LABEL: test_fptosi_4xhalf_to_4xi8(
12121243
; CHECK: {
12131244
; CHECK-NEXT: .reg .b16 %rs<13>;
12141245
; CHECK-NEXT: .reg .b32 %r<15>;
12151246
; CHECK-EMPTY:
12161247
; CHECK-NEXT: // %bb.0:
1217-
; CHECK-NEXT: ld.param.v2.u32 {%r3, %r4}, [test_fptosi_2xhalf_to_2xi8_param_0];
1248+
; CHECK-NEXT: ld.param.v2.u32 {%r3, %r4}, [test_fptosi_4xhalf_to_4xi8_param_0];
12181249
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r3;
12191250
; CHECK-NEXT: cvt.rzi.s16.f16 %rs3, %rs2;
12201251
; CHECK-NEXT: cvt.rzi.s16.f16 %rs4, %rs1;
@@ -1238,14 +1269,14 @@ define <4 x i8> @test_fptosi_2xhalf_to_2xi8(<4 x half> %a) #0 {
12381269
ret <4 x i8> %r
12391270
}
12401271

1241-
define <4 x i8> @test_fptoui_2xhalf_to_2xi8(<4 x half> %a) #0 {
1242-
; CHECK-LABEL: test_fptoui_2xhalf_to_2xi8(
1272+
define <4 x i8> @test_fptoui_4xhalf_to_4xi8(<4 x half> %a) #0 {
1273+
; CHECK-LABEL: test_fptoui_4xhalf_to_4xi8(
12431274
; CHECK: {
12441275
; CHECK-NEXT: .reg .b16 %rs<13>;
12451276
; CHECK-NEXT: .reg .b32 %r<15>;
12461277
; CHECK-EMPTY:
12471278
; CHECK-NEXT: // %bb.0:
1248-
; CHECK-NEXT: ld.param.v2.u32 {%r3, %r4}, [test_fptoui_2xhalf_to_2xi8_param_0];
1279+
; CHECK-NEXT: ld.param.v2.u32 {%r3, %r4}, [test_fptoui_4xhalf_to_4xi8_param_0];
12491280
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r3;
12501281
; CHECK-NEXT: cvt.rzi.u16.f16 %rs3, %rs2;
12511282
; CHECK-NEXT: cvt.rzi.u16.f16 %rs4, %rs1;

llvm/test/CodeGen/NVPTX/named-barriers.ll

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
; Use bar.sync to arrive at a pre-computed barrier number and
77
; wait for all threads in CTA to also arrive:
88
define ptx_device void @test_barrier_named_cta() {
9-
; CHECK: mov.u32 %r[[REG0:[0-9]+]], 0;
9+
; CHECK: mov.b32 %r[[REG0:[0-9]+]], 0;
1010
; CHECK: bar.sync %r[[REG0]];
11-
; CHECK: mov.u32 %r[[REG1:[0-9]+]], 10;
11+
; CHECK: mov.b32 %r[[REG1:[0-9]+]], 10;
1212
; CHECK: bar.sync %r[[REG1]];
13-
; CHECK: mov.u32 %r[[REG2:[0-9]+]], 15;
13+
; CHECK: mov.b32 %r[[REG2:[0-9]+]], 15;
1414
; CHECK: bar.sync %r[[REG2]];
1515
; CHECK: ret;
1616
call void @llvm.nvvm.barrier.n(i32 0)
@@ -22,14 +22,14 @@ define ptx_device void @test_barrier_named_cta() {
2222
; Use bar.sync to arrive at a pre-computed barrier number and
2323
; wait for fixed number of cooperating threads to arrive:
2424
define ptx_device void @test_barrier_named() {
25-
; CHECK: mov.u32 %r[[REG0A:[0-9]+]], 32;
26-
; CHECK: mov.u32 %r[[REG0B:[0-9]+]], 0;
25+
; CHECK: mov.b32 %r[[REG0A:[0-9]+]], 32;
26+
; CHECK: mov.b32 %r[[REG0B:[0-9]+]], 0;
2727
; CHECK: bar.sync %r[[REG0B]], %r[[REG0A]];
28-
; CHECK: mov.u32 %r[[REG1A:[0-9]+]], 352;
29-
; CHECK: mov.u32 %r[[REG1B:[0-9]+]], 10;
28+
; CHECK: mov.b32 %r[[REG1A:[0-9]+]], 352;
29+
; CHECK: mov.b32 %r[[REG1B:[0-9]+]], 10;
3030
; CHECK: bar.sync %r[[REG1B]], %r[[REG1A]];
31-
; CHECK: mov.u32 %r[[REG2A:[0-9]+]], 992;
32-
; CHECK: mov.u32 %r[[REG2B:[0-9]+]], 15;
31+
; CHECK: mov.b32 %r[[REG2A:[0-9]+]], 992;
32+
; CHECK: mov.b32 %r[[REG2B:[0-9]+]], 15;
3333
; CHECK: bar.sync %r[[REG2B]], %r[[REG2A]];
3434
; CHECK: ret;
3535
call void @llvm.nvvm.barrier(i32 0, i32 32)

llvm/test/CodeGen/NVPTX/reg-types.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ entry:
4343
; CHECK: mov.u16 [[R4:%rs[0-9]]], 4;
4444
; CHECK-NEXT: st.u16 {{.*}}, [[R4]]
4545
store i32 5, ptr %s32, align 4
46-
; CHECK: mov.u32 [[R5:%r[0-9]]], 5;
46+
; CHECK: mov.b32 [[R5:%r[0-9]]], 5;
4747
; CHECK-NEXT: st.u32 {{.*}}, [[R5]]
4848
store i32 6, ptr %u32, align 4
49-
; CHECK: mov.u32 [[R6:%r[0-9]]], 6;
49+
; CHECK: mov.b32 [[R6:%r[0-9]]], 6;
5050
; CHECK-NEXT: st.u32 {{.*}}, [[R6]]
5151
store i64 7, ptr %s64, align 8
5252
; CHECK: mov.u64 [[R7:%rd[0-9]]], 7;

llvm/test/CodeGen/NVPTX/shift-parts.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
; CHECK: shift_parts_left_128
55
define void @shift_parts_left_128(ptr %val, ptr %amtptr) {
66
; CHECK: shl.b64
7-
; CHECK: mov.u32
7+
; CHECK: mov.b32
88
; CHECK: sub.s32
99
; CHECK: shr.u64
1010
; CHECK: or.b64

0 commit comments

Comments
 (0)