
Commit f172c31

[RISCV] Lower memory ops and VP splat for zvfhmin and zvfbfmin (#109387)
We can lower f16/bf16 memory ops without promotion through the existing custom lowering. Some of the zero-strided VP loads get combined into a VP splat, so we also need to handle that lowering for f16/bf16 with zvfhmin/zvfbfmin. This patch copies the lowering from ISD::SPLAT_VECTOR over to lowerScalarSplat, which is used by the VP splat lowering.
1 parent eba21ac commit f172c31
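
For context on the combine the message mentions, here is a minimal IR sketch (not a test from this patch; the function name and types are illustrative): a @llvm.experimental.vp.strided.load with stride 0 reads the same element for every active lane, so the backend can rewrite it as a scalar load plus a @llvm.experimental.vp.splat, which is why the VP splat lowering must also handle f16/bf16 under zvfhmin/zvfbfmin.

; Illustrative sketch, assuming the zero-stride combine fires.
define <vscale x 1 x bfloat> @zero_stride_bf16(ptr %p, <vscale x 1 x i1> %m, i32 zeroext %evl) {
  ; Stride 0: every active lane reads the same bf16 element at %p.
  %v = call <vscale x 1 x bfloat> @llvm.experimental.vp.strided.load.nxv1bf16.p0.i64(ptr %p, i64 0, <vscale x 1 x i1> %m, i32 %evl)
  ret <vscale x 1 x bfloat> %v
}
declare <vscale x 1 x bfloat> @llvm.experimental.vp.strided.load.nxv1bf16.p0.i64(ptr, i64, <vscale x 1 x i1>, i32)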

File tree

12 files changed: +1548 −155 lines

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 28 additions & 5 deletions
@@ -1082,10 +1082,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                            VT, Custom);
         MVT EltVT = VT.getVectorElementType();
         if (isTypeLegal(EltVT))
-          setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+          setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT}, VT,
+                             Custom);
         else
-          setOperationAction(ISD::SPLAT_VECTOR, EltVT, Custom);
-        setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
+          setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT},
+                             EltVT, Custom);
+        setOperationAction({ISD::LOAD, ISD::STORE, ISD::MLOAD, ISD::MSTORE,
+                            ISD::MGATHER, ISD::MSCATTER, ISD::VP_LOAD,
+                            ISD::VP_STORE, ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
+                            ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER,
+                            ISD::VP_SCATTER},
+                           VT, Custom);
 
         setOperationAction(ISD::FNEG, VT, Expand);
         setOperationAction(ISD::FABS, VT, Expand);
@@ -4449,11 +4456,27 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
   bool HasPassthru = Passthru && !Passthru.isUndef();
   if (!HasPassthru && !Passthru)
     Passthru = DAG.getUNDEF(VT);
-  if (VT.isFloatingPoint())
-    return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Passthru, Scalar, VL);
 
+  MVT EltVT = VT.getVectorElementType();
   MVT XLenVT = Subtarget.getXLenVT();
 
+  if (VT.isFloatingPoint()) {
+    if ((EltVT == MVT::f16 && !Subtarget.hasStdExtZvfh()) ||
+        EltVT == MVT::bf16) {
+      if ((EltVT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) ||
+          (EltVT == MVT::f16 && Subtarget.hasStdExtZfhmin()))
+        Scalar = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Scalar);
+      else
+        Scalar = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Scalar);
+      MVT IVT = VT.changeVectorElementType(MVT::i16);
+      Passthru = DAG.getNode(ISD::BITCAST, DL, IVT, Passthru);
+      SDValue Splat =
+          lowerScalarSplat(Passthru, Scalar, VL, IVT, DL, DAG, Subtarget);
+      return DAG.getNode(ISD::BITCAST, DL, VT, Splat);
+    }
+    return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Passthru, Scalar, VL);
+  }
+
   // Simplest case is that the operand needs to be promoted to XLenVT.
   if (Scalar.getValueType().bitsLE(XLenVT)) {
     // If the operand is a constant, sign extend to increase our chances
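
The lowerScalarSplat change above is what makes a bf16/f16 VP splat selectable without zvfh: the scalar is moved to the integer domain (FMV_X_ANYEXTH when the scalar extension is available, otherwise a bitcast), splatted as i16, and the result bitcast back. A hedged sketch of the shape this produces (the exact registers and vsetvli operands are assumptions, not checked output from this patch):

; Sketch only. With +zvfbfmin,+zfbfmin, roughly:
;   fmv.x.h  a1, fa0                       ; FMV_X_ANYEXTH
;   vsetvli  zero, a0, e16, mf4, ta, ma
;   vmv.v.x  v8, a1                        ; i16 splat, bitcast back to bf16
define <vscale x 1 x bfloat> @vp_splat_bf16(bfloat %x, <vscale x 1 x i1> %m, i32 zeroext %evl) {
  %v = call <vscale x 1 x bfloat> @llvm.experimental.vp.splat.nxv1bf16(bfloat %x, <vscale x 1 x i1> %m, i32 %evl)
  ret <vscale x 1 x bfloat> %v
}
declare <vscale x 1 x bfloat> @llvm.experimental.vp.splat.nxv1bf16(bfloat, <vscale x 1 x i1>, i32)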

llvm/test/CodeGen/RISCV/rvv/masked-load-fp.ll

Lines changed: 70 additions & 2 deletions
@@ -1,6 +1,19 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+
+define <vscale x 1 x bfloat> @masked_load_nxv1bf16(ptr %a, <vscale x 1 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv1bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x bfloat> @llvm.masked.load.nxv1bf16(ptr %a, i32 2, <vscale x 1 x i1> %mask, <vscale x 1 x bfloat> undef)
+  ret <vscale x 1 x bfloat> %load
+}
+declare <vscale x 1 x bfloat> @llvm.masked.load.nxv1bf16(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x bfloat>)
 
 define <vscale x 1 x half> @masked_load_nxv1f16(ptr %a, <vscale x 1 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv1f16:
@@ -35,6 +48,17 @@ define <vscale x 1 x double> @masked_load_nxv1f64(ptr %a, <vscale x 1 x i1> %mas
 }
 declare <vscale x 1 x double> @llvm.masked.load.nxv1f64(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x double>)
 
+define <vscale x 2 x bfloat> @masked_load_nxv2bf16(ptr %a, <vscale x 2 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv2bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 2 x bfloat> @llvm.masked.load.nxv2bf16(ptr %a, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x bfloat> undef)
+  ret <vscale x 2 x bfloat> %load
+}
+declare <vscale x 2 x bfloat> @llvm.masked.load.nxv2bf16(ptr, i32, <vscale x 2 x i1>, <vscale x 2 x bfloat>)
+
 define <vscale x 2 x half> @masked_load_nxv2f16(ptr %a, <vscale x 2 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv2f16:
 ; CHECK:       # %bb.0:
@@ -68,6 +92,17 @@ define <vscale x 2 x double> @masked_load_nxv2f64(ptr %a, <vscale x 2 x i1> %mas
 }
 declare <vscale x 2 x double> @llvm.masked.load.nxv2f64(ptr, i32, <vscale x 2 x i1>, <vscale x 2 x double>)
 
+define <vscale x 4 x bfloat> @masked_load_nxv4bf16(ptr %a, <vscale x 4 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv4bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m1, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 4 x bfloat> @llvm.masked.load.nxv4bf16(ptr %a, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x bfloat> undef)
+  ret <vscale x 4 x bfloat> %load
+}
+declare <vscale x 4 x bfloat> @llvm.masked.load.nxv4bf16(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x bfloat>)
+
 define <vscale x 4 x half> @masked_load_nxv4f16(ptr %a, <vscale x 4 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv4f16:
 ; CHECK:       # %bb.0:
@@ -101,6 +136,17 @@ define <vscale x 4 x double> @masked_load_nxv4f64(ptr %a, <vscale x 4 x i1> %mas
 }
 declare <vscale x 4 x double> @llvm.masked.load.nxv4f64(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x double>)
 
+define <vscale x 8 x bfloat> @masked_load_nxv8bf16(ptr %a, <vscale x 8 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv8bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 8 x bfloat> @llvm.masked.load.nxv8bf16(ptr %a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x bfloat> undef)
+  ret <vscale x 8 x bfloat> %load
+}
+declare <vscale x 8 x bfloat> @llvm.masked.load.nxv8bf16(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x bfloat>)
+
 define <vscale x 8 x half> @masked_load_nxv8f16(ptr %a, <vscale x 8 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv8f16:
 ; CHECK:       # %bb.0:
@@ -134,6 +180,17 @@ define <vscale x 8 x double> @masked_load_nxv8f64(ptr %a, <vscale x 8 x i1> %mas
 }
 declare <vscale x 8 x double> @llvm.masked.load.nxv8f64(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x double>)
 
+define <vscale x 16 x bfloat> @masked_load_nxv16bf16(ptr %a, <vscale x 16 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv16bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 16 x bfloat> @llvm.masked.load.nxv16bf16(ptr %a, i32 2, <vscale x 16 x i1> %mask, <vscale x 16 x bfloat> undef)
+  ret <vscale x 16 x bfloat> %load
+}
+declare <vscale x 16 x bfloat> @llvm.masked.load.nxv16bf16(ptr, i32, <vscale x 16 x i1>, <vscale x 16 x bfloat>)
+
 define <vscale x 16 x half> @masked_load_nxv16f16(ptr %a, <vscale x 16 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv16f16:
 ; CHECK:       # %bb.0:
@@ -156,6 +213,17 @@ define <vscale x 16 x float> @masked_load_nxv16f32(ptr %a, <vscale x 16 x i1> %m
 }
 declare <vscale x 16 x float> @llvm.masked.load.nxv16f32(ptr, i32, <vscale x 16 x i1>, <vscale x 16 x float>)
 
+define <vscale x 32 x bfloat> @masked_load_nxv32bf16(ptr %a, <vscale x 32 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv32bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m8, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 32 x bfloat> @llvm.masked.load.nxv32bf16(ptr %a, i32 2, <vscale x 32 x i1> %mask, <vscale x 32 x bfloat> undef)
+  ret <vscale x 32 x bfloat> %load
+}
+declare <vscale x 32 x bfloat> @llvm.masked.load.nxv32bf16(ptr, i32, <vscale x 32 x i1>, <vscale x 32 x bfloat>)
+
 define <vscale x 32 x half> @masked_load_nxv32f16(ptr %a, <vscale x 32 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv32f16:
 ; CHECK:       # %bb.0:
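
The same setOperationAction hunk also marks MGATHER/MSCATTER and the VP gather/scatter nodes custom for these types; those tests are outside this excerpt. A minimal gather sketch (the intrinsic mangling follows the usual convention; the selected instruction is an assumption, not generated CHECK output):

define <vscale x 1 x bfloat> @mgather_bf16(<vscale x 1 x ptr> %ptrs, <vscale x 1 x i1> %m) {
  ; On rv64 this should select an indexed load such as vluxei64.v (assumption).
  %v = call <vscale x 1 x bfloat> @llvm.masked.gather.nxv1bf16.nxv1p0(<vscale x 1 x ptr> %ptrs, i32 2, <vscale x 1 x i1> %m, <vscale x 1 x bfloat> undef)
  ret <vscale x 1 x bfloat> %v
}
declare <vscale x 1 x bfloat> @llvm.masked.gather.nxv1bf16.nxv1p0(<vscale x 1 x ptr>, i32, <vscale x 1 x i1>, <vscale x 1 x bfloat>)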

llvm/test/CodeGen/RISCV/rvv/masked-store-fp.ll

Lines changed: 70 additions & 2 deletions
@@ -1,6 +1,19 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+
+define void @masked_store_nxv1bf16(<vscale x 1 x bfloat> %val, ptr %a, <vscale x 1 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv1bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1bf16.p0(<vscale x 1 x bfloat> %val, ptr %a, i32 2, <vscale x 1 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv1bf16.p0(<vscale x 1 x bfloat>, ptr, i32, <vscale x 1 x i1>)
 
 define void @masked_store_nxv1f16(<vscale x 1 x half> %val, ptr %a, <vscale x 1 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv1f16:
@@ -35,6 +48,17 @@ define void @masked_store_nxv1f64(<vscale x 1 x double> %val, ptr %a, <vscale x
 }
 declare void @llvm.masked.store.nxv1f64.p0(<vscale x 1 x double>, ptr, i32, <vscale x 1 x i1>)
 
+define void @masked_store_nxv2bf16(<vscale x 2 x bfloat> %val, ptr %a, <vscale x 2 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv2bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv2bf16.p0(<vscale x 2 x bfloat> %val, ptr %a, i32 2, <vscale x 2 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv2bf16.p0(<vscale x 2 x bfloat>, ptr, i32, <vscale x 2 x i1>)
+
 define void @masked_store_nxv2f16(<vscale x 2 x half> %val, ptr %a, <vscale x 2 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv2f16:
 ; CHECK:       # %bb.0:
@@ -68,6 +92,17 @@ define void @masked_store_nxv2f64(<vscale x 2 x double> %val, ptr %a, <vscale x
 }
 declare void @llvm.masked.store.nxv2f64.p0(<vscale x 2 x double>, ptr, i32, <vscale x 2 x i1>)
 
+define void @masked_store_nxv4bf16(<vscale x 4 x bfloat> %val, ptr %a, <vscale x 4 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv4bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m1, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv4bf16.p0(<vscale x 4 x bfloat> %val, ptr %a, i32 2, <vscale x 4 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv4bf16.p0(<vscale x 4 x bfloat>, ptr, i32, <vscale x 4 x i1>)
+
 define void @masked_store_nxv4f16(<vscale x 4 x half> %val, ptr %a, <vscale x 4 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv4f16:
 ; CHECK:       # %bb.0:
@@ -101,6 +136,17 @@ define void @masked_store_nxv4f64(<vscale x 4 x double> %val, ptr %a, <vscale x
 }
 declare void @llvm.masked.store.nxv4f64.p0(<vscale x 4 x double>, ptr, i32, <vscale x 4 x i1>)
 
+define void @masked_store_nxv8bf16(<vscale x 8 x bfloat> %val, ptr %a, <vscale x 8 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv8bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv8bf16.p0(<vscale x 8 x bfloat> %val, ptr %a, i32 2, <vscale x 8 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv8bf16.p0(<vscale x 8 x bfloat>, ptr, i32, <vscale x 8 x i1>)
+
 define void @masked_store_nxv8f16(<vscale x 8 x half> %val, ptr %a, <vscale x 8 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv8f16:
 ; CHECK:       # %bb.0:
@@ -134,6 +180,17 @@ define void @masked_store_nxv8f64(<vscale x 8 x double> %val, ptr %a, <vscale x
 }
 declare void @llvm.masked.store.nxv8f64.p0(<vscale x 8 x double>, ptr, i32, <vscale x 8 x i1>)
 
+define void @masked_store_nxv16bf16(<vscale x 16 x bfloat> %val, ptr %a, <vscale x 16 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv16bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv16bf16.p0(<vscale x 16 x bfloat> %val, ptr %a, i32 2, <vscale x 16 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv16bf16.p0(<vscale x 16 x bfloat>, ptr, i32, <vscale x 16 x i1>)
+
 define void @masked_store_nxv16f16(<vscale x 16 x half> %val, ptr %a, <vscale x 16 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv16f16:
 ; CHECK:       # %bb.0:
@@ -156,6 +213,17 @@ define void @masked_store_nxv16f32(<vscale x 16 x float> %val, ptr %a, <vscale x
 }
 declare void @llvm.masked.store.nxv16f32.p0(<vscale x 16 x float>, ptr, i32, <vscale x 16 x i1>)
 
+define void @masked_store_nxv32bf16(<vscale x 32 x bfloat> %val, ptr %a, <vscale x 32 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv32bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m8, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv32bf16.p0(<vscale x 32 x bfloat> %val, ptr %a, i32 2, <vscale x 32 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv32bf16.p0(<vscale x 32 x bfloat>, ptr, i32, <vscale x 32 x i1>)
+
 define void @masked_store_nxv32f16(<vscale x 32 x half> %val, ptr %a, <vscale x 32 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv32f16:
 ; CHECK:       # %bb.0:
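
EXPERIMENTAL_VP_STRIDED_LOAD/STORE are enabled by the same hunk as well; a strided-store sketch under the same caveats (illustrative names and mangling; the vsse16.v selection is an assumption, not generated CHECK output):

define void @strided_store_bf16(<vscale x 1 x bfloat> %v, ptr %p, i64 %stride, <vscale x 1 x i1> %m, i32 zeroext %evl) {
  ; Expected to select a strided store such as vsse16.v (assumption).
  call void @llvm.experimental.vp.strided.store.nxv1bf16.p0.i64(<vscale x 1 x bfloat> %v, ptr %p, i64 %stride, <vscale x 1 x i1> %m, i32 %evl)
  ret void
}
declare void @llvm.experimental.vp.strided.store.nxv1bf16.p0.i64(<vscale x 1 x bfloat>, ptr, i64, <vscale x 1 x i1>, i32)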
