Skip to content

Commit 0ef8e71

Browse files
committed
[RISCV] Custom legalize vXbf16 BUILD_VECTOR without Zfbfmin.
By default, type legalization will try to promote the build_vector, but that generic type legalizer doesn't support that. Bitcast to vXi16 instead. Same as what we do for vXf16 without Zfhmin. Fixes #100846.
1 parent f54ae6d commit 0ef8e71

File tree

2 files changed

+122
-5
lines changed

2 files changed

+122
-5
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,8 +1285,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
12851285

12861286
if (VT.getVectorElementType() == MVT::bf16) {
12871287
setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
1288-
// FIXME: We should prefer BUILD_VECTOR over SPLAT_VECTOR.
1289-
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1288+
if (Subtarget.hasStdExtZfbfmin()) {
1289+
// FIXME: We should prefer BUILD_VECTOR over SPLAT_VECTOR.
1290+
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1291+
} else {
1292+
// We need to custom legalize bf16 build vectors if Zfbfmin isn't
1293+
// available.
1294+
setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
1295+
}
12901296
setOperationAction(
12911297
{ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT,
12921298
Custom);
@@ -3935,9 +3941,9 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
39353941
MVT VT = Op.getSimpleValueType();
39363942
assert(VT.isFixedLengthVector() && "Unexpected vector!");
39373943

3938-
// If we don't have scalar f16, we need to bitcast to an i16 vector.
3939-
if (VT.getVectorElementType() == MVT::f16 &&
3940-
!Subtarget.hasStdExtZfhmin())
3944+
// If we don't have scalar f16/bf16, we need to bitcast to an i16 vector.
3945+
if ((VT.getVectorElementType() == MVT::f16 && !Subtarget.hasStdExtZfhmin()) ||
3946+
(VT.getVectorElementType() == MVT::bf16 && !Subtarget.hasStdExtZfbfmin()))
39413947
return lowerBUILD_VECTORvXf16(Op, DAG);
39423948

39433949
if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=riscv32 -target-abi=ilp32d -mattr=+v,+zfbfmin,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZFBFMIN-ZVFBFMIN
3+
; RUN: llc -mtriple=riscv32 -target-abi=ilp32d -mattr=+v,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZVFBFMIN
4+
; RUN: llc -mtriple=riscv64 -target-abi=lp64d -mattr=+v,+zfbfmin,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZFBFMIN-ZVFBFMIN
5+
; RUN: llc -mtriple=riscv64 -target-abi=lp64d -mattr=+v,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZVFBFMIN
6+
7+
define <8 x bfloat> @splat_v8bf16(ptr %x, bfloat %y) {
8+
; ZFBFMIN-ZVFBFMIN-LABEL: splat_v8bf16:
9+
; ZFBFMIN-ZVFBFMIN: # %bb.0:
10+
; ZFBFMIN-ZVFBFMIN-NEXT: fcvt.s.bf16 fa5, fa0
11+
; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a0, zero, e32, m2, ta, ma
12+
; ZFBFMIN-ZVFBFMIN-NEXT: vfmv.v.f v10, fa5
13+
; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli zero, zero, e16, m1, ta, ma
14+
; ZFBFMIN-ZVFBFMIN-NEXT: vfncvtbf16.f.f.w v8, v10
15+
; ZFBFMIN-ZVFBFMIN-NEXT: ret
16+
;
17+
; ZVFBFMIN-LABEL: splat_v8bf16:
18+
; ZVFBFMIN: # %bb.0:
19+
; ZVFBFMIN-NEXT: fmv.x.w a0, fa0
20+
; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
21+
; ZVFBFMIN-NEXT: vmv.v.x v8, a0
22+
; ZVFBFMIN-NEXT: ret
23+
%a = insertelement <8 x bfloat> poison, bfloat %y, i32 0
24+
%b = shufflevector <8 x bfloat> %a, <8 x bfloat> poison, <8 x i32> zeroinitializer
25+
ret <8 x bfloat> %b
26+
}
27+
28+
define <16 x bfloat> @splat_16bf16(ptr %x, bfloat %y) {
29+
; ZFBFMIN-ZVFBFMIN-LABEL: splat_16bf16:
30+
; ZFBFMIN-ZVFBFMIN: # %bb.0:
31+
; ZFBFMIN-ZVFBFMIN-NEXT: fcvt.s.bf16 fa5, fa0
32+
; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a0, zero, e32, m4, ta, ma
33+
; ZFBFMIN-ZVFBFMIN-NEXT: vfmv.v.f v12, fa5
34+
; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli zero, zero, e16, m2, ta, ma
35+
; ZFBFMIN-ZVFBFMIN-NEXT: vfncvtbf16.f.f.w v8, v12
36+
; ZFBFMIN-ZVFBFMIN-NEXT: ret
37+
;
38+
; ZVFBFMIN-LABEL: splat_16bf16:
39+
; ZVFBFMIN: # %bb.0:
40+
; ZVFBFMIN-NEXT: fmv.x.w a0, fa0
41+
; ZVFBFMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
42+
; ZVFBFMIN-NEXT: vmv.v.x v8, a0
43+
; ZVFBFMIN-NEXT: ret
44+
%a = insertelement <16 x bfloat> poison, bfloat %y, i32 0
45+
%b = shufflevector <16 x bfloat> %a, <16 x bfloat> poison, <16 x i32> zeroinitializer
46+
ret <16 x bfloat> %b
47+
}
48+
49+
define <8 x bfloat> @splat_zero_v8bf16(ptr %x) {
50+
; ZFBFMIN-ZVFBFMIN-LABEL: splat_zero_v8bf16:
51+
; ZFBFMIN-ZVFBFMIN: # %bb.0:
52+
; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a0, zero, e16, m1, ta, ma
53+
; ZFBFMIN-ZVFBFMIN-NEXT: vmv.v.i v8, 0
54+
; ZFBFMIN-ZVFBFMIN-NEXT: ret
55+
;
56+
; ZVFBFMIN-LABEL: splat_zero_v8bf16:
57+
; ZVFBFMIN: # %bb.0:
58+
; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
59+
; ZVFBFMIN-NEXT: vmv.v.i v8, 0
60+
; ZVFBFMIN-NEXT: ret
61+
ret <8 x bfloat> splat (bfloat 0.0)
62+
}
63+
64+
define <16 x bfloat> @splat_zero_16bf16(ptr %x) {
65+
; ZFBFMIN-ZVFBFMIN-LABEL: splat_zero_16bf16:
66+
; ZFBFMIN-ZVFBFMIN: # %bb.0:
67+
; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a0, zero, e16, m2, ta, ma
68+
; ZFBFMIN-ZVFBFMIN-NEXT: vmv.v.i v8, 0
69+
; ZFBFMIN-ZVFBFMIN-NEXT: ret
70+
;
71+
; ZVFBFMIN-LABEL: splat_zero_16bf16:
72+
; ZVFBFMIN: # %bb.0:
73+
; ZVFBFMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
74+
; ZVFBFMIN-NEXT: vmv.v.i v8, 0
75+
; ZVFBFMIN-NEXT: ret
76+
ret <16 x bfloat> splat (bfloat 0.0)
77+
}
78+
79+
define <8 x bfloat> @splat_negzero_v8bf16(ptr %x) {
80+
; ZFBFMIN-ZVFBFMIN-LABEL: splat_negzero_v8bf16:
81+
; ZFBFMIN-ZVFBFMIN: # %bb.0:
82+
; ZFBFMIN-ZVFBFMIN-NEXT: lui a0, 1048568
83+
; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a1, zero, e16, m1, ta, ma
84+
; ZFBFMIN-ZVFBFMIN-NEXT: vmv.v.x v8, a0
85+
; ZFBFMIN-ZVFBFMIN-NEXT: ret
86+
;
87+
; ZVFBFMIN-LABEL: splat_negzero_v8bf16:
88+
; ZVFBFMIN: # %bb.0:
89+
; ZVFBFMIN-NEXT: lui a0, 1048568
90+
; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
91+
; ZVFBFMIN-NEXT: vmv.v.x v8, a0
92+
; ZVFBFMIN-NEXT: ret
93+
ret <8 x bfloat> splat (bfloat -0.0)
94+
}
95+
96+
define <16 x bfloat> @splat_negzero_16bf16(ptr %x) {
97+
; ZFBFMIN-ZVFBFMIN-LABEL: splat_negzero_16bf16:
98+
; ZFBFMIN-ZVFBFMIN: # %bb.0:
99+
; ZFBFMIN-ZVFBFMIN-NEXT: lui a0, 1048568
100+
; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a1, zero, e16, m2, ta, ma
101+
; ZFBFMIN-ZVFBFMIN-NEXT: vmv.v.x v8, a0
102+
; ZFBFMIN-ZVFBFMIN-NEXT: ret
103+
;
104+
; ZVFBFMIN-LABEL: splat_negzero_16bf16:
105+
; ZVFBFMIN: # %bb.0:
106+
; ZVFBFMIN-NEXT: lui a0, 1048568
107+
; ZVFBFMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
108+
; ZVFBFMIN-NEXT: vmv.v.x v8, a0
109+
; ZVFBFMIN-NEXT: ret
110+
ret <16 x bfloat> splat (bfloat -0.0)
111+
}

0 commit comments

Comments
 (0)