Skip to content

Commit ff0f201

Browse files
authored
[RISCV] Bitcast fixed length bf16/f16 build_vector to i16 with Zvfbfmin/Zvfhmin+Zfbfmin/Zfhmin. (#106637)
Previously, if Zfbfmin/Zfhmin were enabled, we only handled build_vectors that could be turned into splat_vectors. We promoted them to f32 splats by extending in the scalar domain and narrowing in the vector domain. This patch fixes a crash where we failed to account for whether the f32 vector type fit in LMUL<=8. Because the new lowering occurs after type legalization, we have to be careful to use XLenVT for the scalar integer type and use custom cast nodes.
1 parent 48bc8b0 commit ff0f201

17 files changed

+1667
-1595
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
13221322

13231323
if (VT.getVectorElementType() == MVT::f16 &&
13241324
!Subtarget.hasVInstructionsF16()) {
1325+
setOperationAction(ISD::BITCAST, VT, Custom);
13251326
setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
13261327
setOperationAction(
13271328
{ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT,
@@ -1331,8 +1332,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
13311332
VT, Custom);
13321333
setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
13331334
if (Subtarget.hasStdExtZfhmin()) {
1334-
// FIXME: We should prefer BUILD_VECTOR over SPLAT_VECTOR.
1335-
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1335+
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
13361336
} else {
13371337
// We need to custom legalize f16 build vectors if Zfhmin isn't
13381338
// available.
@@ -1350,10 +1350,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
13501350
}
13511351

13521352
if (VT.getVectorElementType() == MVT::bf16) {
1353+
setOperationAction(ISD::BITCAST, VT, Custom);
13531354
setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
13541355
if (Subtarget.hasStdExtZfbfmin()) {
1355-
// FIXME: We should prefer BUILD_VECTOR over SPLAT_VECTOR.
1356-
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1356+
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
13571357
} else {
13581358
// We need to custom legalize bf16 build vectors if Zfbfmin isn't
13591359
// available.
@@ -4120,38 +4120,54 @@ static SDValue lowerBuildVectorViaPacking(SDValue Op, SelectionDAG &DAG,
41204120
DAG.getBuildVector(WideVecVT, DL, NewOperands));
41214121
}
41224122

4123-
// Convert to an vXf16 build_vector to vXi16 with bitcasts.
4124-
static SDValue lowerBUILD_VECTORvXf16(SDValue Op, SelectionDAG &DAG) {
4125-
MVT VT = Op.getSimpleValueType();
4126-
MVT IVT = VT.changeVectorElementType(MVT::i16);
4127-
SmallVector<SDValue, 16> NewOps(Op.getNumOperands());
4128-
for (unsigned I = 0, E = Op.getNumOperands(); I != E; ++I)
4129-
NewOps[I] = DAG.getBitcast(MVT::i16, Op.getOperand(I));
4130-
SDValue Res = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(Op), IVT, NewOps);
4131-
return DAG.getBitcast(VT, Res);
4132-
}
4133-
41344123
static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
41354124
const RISCVSubtarget &Subtarget) {
41364125
MVT VT = Op.getSimpleValueType();
41374126
assert(VT.isFixedLengthVector() && "Unexpected vector!");
41384127

4139-
// If we don't have scalar f16/bf16, we need to bitcast to an i16 vector.
4140-
if ((VT.getVectorElementType() == MVT::f16 && !Subtarget.hasStdExtZfhmin()) ||
4141-
(VT.getVectorElementType() == MVT::bf16 && !Subtarget.hasStdExtZfbfmin()))
4142-
return lowerBUILD_VECTORvXf16(Op, DAG);
4128+
MVT EltVT = VT.getVectorElementType();
4129+
MVT XLenVT = Subtarget.getXLenVT();
4130+
4131+
SDLoc DL(Op);
4132+
4133+
// Proper support for f16 requires Zvfh. bf16 always requires special
4134+
// handling. We need to cast the scalar to integer and create an integer
4135+
// build_vector.
4136+
if ((EltVT == MVT::f16 && !Subtarget.hasStdExtZvfh()) || EltVT == MVT::bf16) {
4137+
MVT IVT = VT.changeVectorElementType(MVT::i16);
4138+
SmallVector<SDValue, 16> NewOps(Op.getNumOperands());
4139+
for (unsigned I = 0, E = Op.getNumOperands(); I != E; ++I) {
4140+
SDValue Elem = Op.getOperand(I);
4141+
if ((EltVT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) ||
4142+
(EltVT == MVT::f16 && Subtarget.hasStdExtZfhmin())) {
4143+
// Called by LegalizeDAG, we need to use XLenVT operations since we
4144+
// can't create illegal types.
4145+
if (auto *C = dyn_cast<ConstantFPSDNode>(Elem)) {
4146+
// Manually constant fold so the integer build_vector can be lowered
4147+
// better. Waiting for DAGCombine will be too late.
4148+
APInt V =
4149+
C->getValueAPF().bitcastToAPInt().sext(XLenVT.getSizeInBits());
4150+
NewOps[I] = DAG.getConstant(V, DL, XLenVT);
4151+
} else {
4152+
NewOps[I] = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Elem);
4153+
}
4154+
} else {
4155+
// Called by scalar type legalizer, we can use i16.
4156+
NewOps[I] = DAG.getBitcast(MVT::i16, Op.getOperand(I));
4157+
}
4158+
}
4159+
SDValue Res = DAG.getNode(ISD::BUILD_VECTOR, DL, IVT, NewOps);
4160+
return DAG.getBitcast(VT, Res);
4161+
}
41434162

41444163
if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
41454164
ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()))
41464165
return lowerBuildVectorOfConstants(Op, DAG, Subtarget);
41474166

41484167
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
41494168

4150-
SDLoc DL(Op);
41514169
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
41524170

4153-
MVT XLenVT = Subtarget.getXLenVT();
4154-
41554171
if (VT.getVectorElementType() == MVT::i1) {
41564172
// A BUILD_VECTOR can be lowered as a SETCC. For each fixed-length mask
41574173
// vector type, we have a legal equivalently-sized i8 type, so we can use
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=riscv32 -target-abi=ilp32d -mattr=+v,+zvfbfmin,+f,+d -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV32,RV32ZVFBFMIN,RV32-NO-ZFBFMIN
3+
; RUN: llc -mtriple=riscv64 -target-abi=lp64d -mattr=+v,+zvfbfmin,+f,+d -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64,RV64ZVFBFMIN,RV64-NO-ZFBFMIN
4+
; RUN: llc -mtriple=riscv32 -target-abi=ilp32d -mattr=+v,+zfbfmin,+zvfbfmin,+f,+d -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV32,RV32ZVFBFMIN,RV32-ZFBFMIN
5+
; RUN: llc -mtriple=riscv64 -target-abi=lp64d -mattr=+v,+zfbfmin,+zvfbfmin,+f,+d -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64,RV64ZVFBFMIN,RV64-ZFBFMIN
6+
7+
define <4 x bfloat> @splat_idx_v4bf16(<4 x bfloat> %v, i64 %idx) {
8+
; RV32-NO-ZFBFMIN-LABEL: splat_idx_v4bf16:
9+
; RV32-NO-ZFBFMIN: # %bb.0:
10+
; RV32-NO-ZFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
11+
; RV32-NO-ZFBFMIN-NEXT: vrgather.vx v9, v8, a0
12+
; RV32-NO-ZFBFMIN-NEXT: vmv1r.v v8, v9
13+
; RV32-NO-ZFBFMIN-NEXT: ret
14+
;
15+
; RV64-NO-ZFBFMIN-LABEL: splat_idx_v4bf16:
16+
; RV64-NO-ZFBFMIN: # %bb.0:
17+
; RV64-NO-ZFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
18+
; RV64-NO-ZFBFMIN-NEXT: vrgather.vx v9, v8, a0
19+
; RV64-NO-ZFBFMIN-NEXT: vmv1r.v v8, v9
20+
; RV64-NO-ZFBFMIN-NEXT: ret
21+
;
22+
; RV32-ZFBFMIN-LABEL: splat_idx_v4bf16:
23+
; RV32-ZFBFMIN: # %bb.0:
24+
; RV32-ZFBFMIN-NEXT: addi sp, sp, -48
25+
; RV32-ZFBFMIN-NEXT: .cfi_def_cfa_offset 48
26+
; RV32-ZFBFMIN-NEXT: sw ra, 44(sp) # 4-byte Folded Spill
27+
; RV32-ZFBFMIN-NEXT: .cfi_offset ra, -4
28+
; RV32-ZFBFMIN-NEXT: csrr a1, vlenb
29+
; RV32-ZFBFMIN-NEXT: slli a1, a1, 1
30+
; RV32-ZFBFMIN-NEXT: sub sp, sp, a1
31+
; RV32-ZFBFMIN-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x30, 0x22, 0x11, 0x02, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 48 + 2 * vlenb
32+
; RV32-ZFBFMIN-NEXT: addi a1, sp, 32
33+
; RV32-ZFBFMIN-NEXT: vs1r.v v8, (a1) # Unknown-size Folded Spill
34+
; RV32-ZFBFMIN-NEXT: andi a0, a0, 3
35+
; RV32-ZFBFMIN-NEXT: li a1, 2
36+
; RV32-ZFBFMIN-NEXT: call __mulsi3
37+
; RV32-ZFBFMIN-NEXT: addi a1, sp, 16
38+
; RV32-ZFBFMIN-NEXT: add a0, a1, a0
39+
; RV32-ZFBFMIN-NEXT: addi a2, sp, 32
40+
; RV32-ZFBFMIN-NEXT: vl1r.v v8, (a2) # Unknown-size Folded Reload
41+
; RV32-ZFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
42+
; RV32-ZFBFMIN-NEXT: vse16.v v8, (a1)
43+
; RV32-ZFBFMIN-NEXT: flh fa5, 0(a0)
44+
; RV32-ZFBFMIN-NEXT: fmv.x.h a0, fa5
45+
; RV32-ZFBFMIN-NEXT: vmv.v.x v8, a0
46+
; RV32-ZFBFMIN-NEXT: csrr a0, vlenb
47+
; RV32-ZFBFMIN-NEXT: slli a0, a0, 1
48+
; RV32-ZFBFMIN-NEXT: add sp, sp, a0
49+
; RV32-ZFBFMIN-NEXT: lw ra, 44(sp) # 4-byte Folded Reload
50+
; RV32-ZFBFMIN-NEXT: addi sp, sp, 48
51+
; RV32-ZFBFMIN-NEXT: ret
52+
;
53+
; RV64-ZFBFMIN-LABEL: splat_idx_v4bf16:
54+
; RV64-ZFBFMIN: # %bb.0:
55+
; RV64-ZFBFMIN-NEXT: addi sp, sp, -48
56+
; RV64-ZFBFMIN-NEXT: .cfi_def_cfa_offset 48
57+
; RV64-ZFBFMIN-NEXT: sd ra, 40(sp) # 8-byte Folded Spill
58+
; RV64-ZFBFMIN-NEXT: .cfi_offset ra, -8
59+
; RV64-ZFBFMIN-NEXT: csrr a1, vlenb
60+
; RV64-ZFBFMIN-NEXT: slli a1, a1, 1
61+
; RV64-ZFBFMIN-NEXT: sub sp, sp, a1
62+
; RV64-ZFBFMIN-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x30, 0x22, 0x11, 0x02, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 48 + 2 * vlenb
63+
; RV64-ZFBFMIN-NEXT: addi a1, sp, 32
64+
; RV64-ZFBFMIN-NEXT: vs1r.v v8, (a1) # Unknown-size Folded Spill
65+
; RV64-ZFBFMIN-NEXT: andi a0, a0, 3
66+
; RV64-ZFBFMIN-NEXT: li a1, 2
67+
; RV64-ZFBFMIN-NEXT: call __muldi3
68+
; RV64-ZFBFMIN-NEXT: addi a1, sp, 16
69+
; RV64-ZFBFMIN-NEXT: add a0, a1, a0
70+
; RV64-ZFBFMIN-NEXT: addi a2, sp, 32
71+
; RV64-ZFBFMIN-NEXT: vl1r.v v8, (a2) # Unknown-size Folded Reload
72+
; RV64-ZFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
73+
; RV64-ZFBFMIN-NEXT: vse16.v v8, (a1)
74+
; RV64-ZFBFMIN-NEXT: flh fa5, 0(a0)
75+
; RV64-ZFBFMIN-NEXT: fmv.x.h a0, fa5
76+
; RV64-ZFBFMIN-NEXT: vmv.v.x v8, a0
77+
; RV64-ZFBFMIN-NEXT: csrr a0, vlenb
78+
; RV64-ZFBFMIN-NEXT: slli a0, a0, 1
79+
; RV64-ZFBFMIN-NEXT: add sp, sp, a0
80+
; RV64-ZFBFMIN-NEXT: ld ra, 40(sp) # 8-byte Folded Reload
81+
; RV64-ZFBFMIN-NEXT: addi sp, sp, 48
82+
; RV64-ZFBFMIN-NEXT: ret
83+
%x = extractelement <4 x bfloat> %v, i64 %idx
84+
%ins = insertelement <4 x bfloat> poison, bfloat %x, i32 0
85+
%splat = shufflevector <4 x bfloat> %ins, <4 x bfloat> poison, <4 x i32> zeroinitializer
86+
ret <4 x bfloat> %splat
87+
}
88+
89+
define <2 x bfloat> @buildvec_v2bf16(bfloat %a, bfloat %b) {
90+
; RV32-NO-ZFBFMIN-LABEL: buildvec_v2bf16:
91+
; RV32-NO-ZFBFMIN: # %bb.0:
92+
; RV32-NO-ZFBFMIN-NEXT: fmv.x.w a0, fa1
93+
; RV32-NO-ZFBFMIN-NEXT: fmv.x.w a1, fa0
94+
; RV32-NO-ZFBFMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
95+
; RV32-NO-ZFBFMIN-NEXT: vmv.v.x v8, a1
96+
; RV32-NO-ZFBFMIN-NEXT: vslide1down.vx v8, v8, a0
97+
; RV32-NO-ZFBFMIN-NEXT: ret
98+
;
99+
; RV64-NO-ZFBFMIN-LABEL: buildvec_v2bf16:
100+
; RV64-NO-ZFBFMIN: # %bb.0:
101+
; RV64-NO-ZFBFMIN-NEXT: fmv.x.w a0, fa1
102+
; RV64-NO-ZFBFMIN-NEXT: fmv.x.w a1, fa0
103+
; RV64-NO-ZFBFMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
104+
; RV64-NO-ZFBFMIN-NEXT: vmv.v.x v8, a1
105+
; RV64-NO-ZFBFMIN-NEXT: vslide1down.vx v8, v8, a0
106+
; RV64-NO-ZFBFMIN-NEXT: ret
107+
;
108+
; RV32-ZFBFMIN-LABEL: buildvec_v2bf16:
109+
; RV32-ZFBFMIN: # %bb.0:
110+
; RV32-ZFBFMIN-NEXT: fmv.x.h a0, fa1
111+
; RV32-ZFBFMIN-NEXT: fmv.x.h a1, fa0
112+
; RV32-ZFBFMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
113+
; RV32-ZFBFMIN-NEXT: vmv.v.x v8, a1
114+
; RV32-ZFBFMIN-NEXT: vslide1down.vx v8, v8, a0
115+
; RV32-ZFBFMIN-NEXT: ret
116+
;
117+
; RV64-ZFBFMIN-LABEL: buildvec_v2bf16:
118+
; RV64-ZFBFMIN: # %bb.0:
119+
; RV64-ZFBFMIN-NEXT: fmv.x.h a0, fa1
120+
; RV64-ZFBFMIN-NEXT: fmv.x.h a1, fa0
121+
; RV64-ZFBFMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
122+
; RV64-ZFBFMIN-NEXT: vmv.v.x v8, a1
123+
; RV64-ZFBFMIN-NEXT: vslide1down.vx v8, v8, a0
124+
; RV64-ZFBFMIN-NEXT: ret
125+
%v1 = insertelement <2 x bfloat> poison, bfloat %a, i64 0
126+
%v2 = insertelement <2 x bfloat> %v1, bfloat %b, i64 1
127+
ret <2 x bfloat> %v2
128+
}
129+
130+
define <2 x bfloat> @vid_v2bf16() {
131+
; CHECK-LABEL: vid_v2bf16:
132+
; CHECK: # %bb.0:
133+
; CHECK-NEXT: lui a0, 260096
134+
; CHECK-NEXT: vsetivli zero, 2, e32, m1, ta, ma
135+
; CHECK-NEXT: vmv.s.x v8, a0
136+
; CHECK-NEXT: ret
137+
ret <2 x bfloat> <bfloat 0.0, bfloat 1.0>
138+
}
139+
140+
define <2 x bfloat> @vid_addend1_v2bf16() {
141+
; CHECK-LABEL: vid_addend1_v2bf16:
142+
; CHECK: # %bb.0:
143+
; CHECK-NEXT: lui a0, 262148
144+
; CHECK-NEXT: addi a0, a0, -128
145+
; CHECK-NEXT: vsetivli zero, 2, e32, m1, ta, ma
146+
; CHECK-NEXT: vmv.s.x v8, a0
147+
; CHECK-NEXT: ret
148+
ret <2 x bfloat> <bfloat 1.0, bfloat 2.0>
149+
}
150+
151+
define <2 x bfloat> @vid_denominator2_v2bf16() {
152+
; CHECK-LABEL: vid_denominator2_v2bf16:
153+
; CHECK: # %bb.0:
154+
; CHECK-NEXT: lui a0, 260100
155+
; CHECK-NEXT: addi a0, a0, -256
156+
; CHECK-NEXT: vsetivli zero, 2, e32, m1, ta, ma
157+
; CHECK-NEXT: vmv.s.x v8, a0
158+
; CHECK-NEXT: ret
159+
ret <2 x bfloat> <bfloat 0.5, bfloat 1.0>
160+
}
161+
162+
define <2 x bfloat> @vid_step2_v2bf16() {
163+
; CHECK-LABEL: vid_step2_v2bf16:
164+
; CHECK: # %bb.0:
165+
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
166+
; CHECK-NEXT: vid.v v8
167+
; CHECK-NEXT: vsll.vi v8, v8, 14
168+
; CHECK-NEXT: ret
169+
ret <2 x bfloat> <bfloat 0.0, bfloat 2.0>
170+
}
171+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
172+
; RV32: {{.*}}
173+
; RV32ZVFBFMIN: {{.*}}
174+
; RV64: {{.*}}
175+
; RV64ZVFBFMIN: {{.*}}

0 commit comments

Comments
 (0)