Skip to content

Commit 4263b2e

Browse files
authored
[NVPTX] Expand EXTLOAD for v8f16 and v8bf16 (#72672)
In triton-lang/triton#2483 I encountered a bug in the NVPTX codegen. Given `load<8 x half>` followed by `fpext to <8 x float>`, we get:
```
ld.shared.v4.b16 {%f1, %f2, %f3, %f4}, [%r15+8];
ld.shared.v4.b16 {%f5, %f6, %f7, %f8}, [%r15];
```
This loads float16 values into float registers without any conversion, so the result is simply garbage. This PR brings `v8f16` and `v8bf16` into line with the other vector types by expanding the extending load into load + cvt. cc @manman-ren @Artem-B @jlebar
1 parent bfbfd1c commit 4263b2e

File tree

3 files changed

+82
-12
lines changed

3 files changed

+82
-12
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
606606
setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
607607
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
608608
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
609+
setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
610+
setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
611+
setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
612+
setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
609613
// Turn FP truncstore into trunc + store.
610614
// FIXME: vector types should also be expanded
611615
setTruncStoreAction(MVT::f32, MVT::f16, Expand);

llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,23 @@ define bfloat @test_select_cc_bf16_f64(double %a, double %b, bfloat %c, bfloat %
207207
%r = select i1 %cc, bfloat %c, bfloat %d
208208
ret bfloat %r
209209
}
210+
211+
; CHECK-LABEL: test_extload_bf16x8
212+
; CHECK: ld.shared.v4.b32 {%r
213+
; CHECK: mov.b32 {%rs
214+
; CHECK: mov.b32 {%rs
215+
; CHECK: mov.b32 {%rs
216+
; CHECK: mov.b32 {%rs
217+
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
218+
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
219+
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
220+
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
221+
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
222+
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
223+
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
224+
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
225+
define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
226+
%load = load <8 x bfloat>, ptr addrspace(3) %arg, align 16
227+
%res = fpext <8 x bfloat> %load to <8 x float>
228+
ret <8 x float> %res
229+
}

llvm/test/CodeGen/NVPTX/vector-loads.ll

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,20 @@ define void @foo_complex(ptr nocapture readonly align 16 dereferenceable(1342177
9999

100100
; CHECK-LABEL: extv8f16_global_a16(
101101
define void @extv8f16_global_a16(ptr addrspace(1) noalias readonly align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
102-
; CHECK: ld.global.v4.b16 {%f
103-
; CHECK: ld.global.v4.b16 {%f
102+
; CHECK: ld.global.v4.b32 {%r
104103
%v = load <8 x half>, ptr addrspace(1) %src, align 16
104+
; CHECK: mov.b32 {%rs
105+
; CHECK: mov.b32 {%rs
106+
; CHECK: mov.b32 {%rs
107+
; CHECK: mov.b32 {%rs
108+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
109+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
110+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
111+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
112+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
113+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
114+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
115+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
105116
%ext = fpext <8 x half> %v to <8 x float>
106117
; CHECK: st.global.v4.f32
107118
; CHECK: st.global.v4.f32
@@ -111,11 +122,23 @@ define void @extv8f16_global_a16(ptr addrspace(1) noalias readonly align 16 %dst
111122

112123
; CHECK-LABEL: extv8f16_global_a4(
113124
define void @extv8f16_global_a4(ptr addrspace(1) noalias readonly align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
114-
; CHECK: ld.global.v2.b16 {%f
115-
; CHECK: ld.global.v2.b16 {%f
116-
; CHECK: ld.global.v2.b16 {%f
117-
; CHECK: ld.global.v2.b16 {%f
125+
; CHECK: ld.global.b32 %r
126+
; CHECK: ld.global.b32 %r
127+
; CHECK: ld.global.b32 %r
128+
; CHECK: ld.global.b32 %r
118129
%v = load <8 x half>, ptr addrspace(1) %src, align 4
130+
; CHECK: mov.b32 {%rs
131+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
132+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
133+
; CHECK: mov.b32 {%rs
134+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
135+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
136+
; CHECK: mov.b32 {%rs
137+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
138+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
139+
; CHECK: mov.b32 {%rs
140+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
141+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
119142
%ext = fpext <8 x half> %v to <8 x float>
120143
; CHECK: st.global.v4.f32
121144
; CHECK: st.global.v4.f32
@@ -126,9 +149,20 @@ define void @extv8f16_global_a4(ptr addrspace(1) noalias readonly align 16 %dst,
126149

127150
; CHECK-LABEL: extv8f16_generic_a16(
128151
define void @extv8f16_generic_a16(ptr noalias readonly align 16 %dst, ptr noalias readonly align 16 %src) #0 {
129-
; CHECK: ld.v4.b16 {%f
130-
; CHECK: ld.v4.b16 {%f
152+
; CHECK: ld.v4.b32 {%r
131153
%v = load <8 x half>, ptr %src, align 16
154+
; CHECK: mov.b32 {%rs
155+
; CHECK: mov.b32 {%rs
156+
; CHECK: mov.b32 {%rs
157+
; CHECK: mov.b32 {%rs
158+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
159+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
160+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
161+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
162+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
163+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
164+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
165+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
132166
%ext = fpext <8 x half> %v to <8 x float>
133167
; CHECK: st.v4.f32
134168
; CHECK: st.v4.f32
@@ -138,11 +172,23 @@ define void @extv8f16_generic_a16(ptr noalias readonly align 16 %dst, ptr noalia
138172

139173
; CHECK-LABEL: extv8f16_generic_a4(
140174
define void @extv8f16_generic_a4(ptr noalias readonly align 16 %dst, ptr noalias readonly align 16 %src) #0 {
141-
; CHECK: ld.v2.b16 {%f
142-
; CHECK: ld.v2.b16 {%f
143-
; CHECK: ld.v2.b16 {%f
144-
; CHECK: ld.v2.b16 {%f
175+
; CHECK: ld.b32 %r
176+
; CHECK: ld.b32 %r
177+
; CHECK: ld.b32 %r
178+
; CHECK: ld.b32 %r
145179
%v = load <8 x half>, ptr %src, align 4
180+
; CHECK: mov.b32 {%rs
181+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
182+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
183+
; CHECK: mov.b32 {%rs
184+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
185+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
186+
; CHECK: mov.b32 {%rs
187+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
188+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
189+
; CHECK: mov.b32 {%rs
190+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
191+
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
146192
%ext = fpext <8 x half> %v to <8 x float>
147193
; CHECK: st.v4.f32
148194
; CHECK: st.v4.f32

0 commit comments

Comments
 (0)