Skip to content

Commit 6af2f22

Browse files
authored
[RISCV] Restrict combineOp_VLToVWOp_VL w/ bf16 to vfwmadd_vl with zvfbfwma (#108798)
We currently make sure to check that if folding an op to an f16 widening op that we have zvfh. We need to do the same for bf16 vectors, but with the further restriction that we can only combine vfmadd_vl to vfwmadd_vl (to get vfwmaccbf16.v{v,f}). The added test case currently crashes because we try to fold an add to a bf16 widening add, which doesn't exist in zvfbfmin or zvfbfwma This moves the checks into the extension support checks to keep it one place.
1 parent 884ff9e commit 6af2f22

File tree

2 files changed

+64
-12
lines changed

2 files changed

+64
-12
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14740,6 +14740,19 @@ struct NodeExtensionHelper {
1474014740
EnforceOneUse = false;
1474114741
}
1474214742

14743+
bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
14744+
const RISCVSubtarget &Subtarget) {
14745+
// Any f16 extension will neeed zvfh
14746+
if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
14747+
return false;
14748+
// The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
14749+
// zvfbfwma
14750+
if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
14751+
Root->getOpcode() != RISCVISD::VFMADD_VL))
14752+
return false;
14753+
return true;
14754+
}
14755+
1474314756
/// Helper method to set the various fields of this struct based on the
1474414757
/// type of \p Root.
1474514758
void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
@@ -14775,9 +14788,14 @@ struct NodeExtensionHelper {
1477514788
case RISCVISD::VSEXT_VL:
1477614789
SupportsSExt = true;
1477714790
break;
14778-
case RISCVISD::FP_EXTEND_VL:
14791+
case RISCVISD::FP_EXTEND_VL: {
14792+
MVT NarrowEltVT =
14793+
OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
14794+
if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
14795+
break;
1477914796
SupportsFPExt = true;
1478014797
break;
14798+
}
1478114799
case ISD::SPLAT_VECTOR:
1478214800
case RISCVISD::VMV_V_X_VL:
1478314801
fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
@@ -14792,6 +14810,10 @@ struct NodeExtensionHelper {
1479214810
if (Op.getOpcode() != ISD::FP_EXTEND)
1479314811
break;
1479414812

14813+
if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
14814+
Subtarget))
14815+
break;
14816+
1479514817
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1479614818
unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
1479714819
if (NarrowSize != ScalarBits)
@@ -15774,10 +15796,6 @@ static SDValue performVFMADD_VLCombine(SDNode *N,
1577415796
if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
1577515797
return V;
1577615798

15777-
if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
15778-
!Subtarget.hasVInstructionsF16() && !Subtarget.hasStdExtZvfbfwma())
15779-
return SDValue();
15780-
1578115799
// FIXME: Ignore strict opcodes for now.
1578215800
if (N->isTargetStrictFPOpcode())
1578315801
return SDValue();
@@ -17522,12 +17540,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1752217540
case RISCVISD::FSUB_VL:
1752317541
case RISCVISD::FMUL_VL:
1752417542
case RISCVISD::VFWADD_W_VL:
17525-
case RISCVISD::VFWSUB_W_VL: {
17526-
if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
17527-
!Subtarget.hasVInstructionsF16())
17528-
return SDValue();
17543+
case RISCVISD::VFWSUB_W_VL:
1752917544
return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
17530-
}
1753117545
case ISD::LOAD:
1753217546
case ISD::STORE: {
1753317547
if (DCI.isAfterLegalizeDAG())

llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,44 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2-
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh | FileCheck %s --check-prefixes=ZVFH
3-
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfhmin | FileCheck %s --check-prefixes=ZVFHMIN
2+
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfmin | FileCheck %s --check-prefixes=CHECK,ZVFH
3+
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfhmin,+zvfbfmin | FileCheck %s --check-prefixes=CHECK,ZVFHMIN
4+
5+
define <vscale x 2 x float> @vfwadd_same_operand_nxv2bf16(<vscale x 2 x bfloat> %arg, i32 signext %vl) {
6+
; CHECK-LABEL: vfwadd_same_operand_nxv2bf16:
7+
; CHECK: # %bb.0: # %bb
8+
; CHECK-NEXT: slli a0, a0, 32
9+
; CHECK-NEXT: srli a0, a0, 32
10+
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
11+
; CHECK-NEXT: vfwcvtbf16.f.f.v v9, v8
12+
; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
13+
; CHECK-NEXT: vfadd.vv v8, v9, v9
14+
; CHECK-NEXT: ret
15+
bb:
16+
%tmp = call <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2bf16(<vscale x 2 x bfloat> %arg, <vscale x 2 x i1> splat (i1 true), i32 %vl)
17+
%tmp2 = call <vscale x 2 x float> @llvm.vp.fadd.nxv2f32(<vscale x 2 x float> %tmp, <vscale x 2 x float> %tmp, <vscale x 2 x i1> splat (i1 true), i32 %vl)
18+
ret <vscale x 2 x float> %tmp2
19+
}
20+
21+
; Make sure we don't widen vfmadd.vv -> vfwmaccbf16.vv if there's other
22+
; unwidenable uses
23+
define <vscale x 2 x float> @vfwadd_same_operand_nxv2bf16_multiuse(<vscale x 2 x bfloat> %arg, <vscale x 2 x float> %acc, i32 signext %vl, ptr %p) {
24+
; CHECK-LABEL: vfwadd_same_operand_nxv2bf16_multiuse:
25+
; CHECK: # %bb.0: # %bb
26+
; CHECK-NEXT: slli a0, a0, 32
27+
; CHECK-NEXT: srli a0, a0, 32
28+
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
29+
; CHECK-NEXT: vfwcvtbf16.f.f.v v10, v8
30+
; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
31+
; CHECK-NEXT: vfadd.vv v8, v10, v10
32+
; CHECK-NEXT: vfmadd.vv v10, v10, v9
33+
; CHECK-NEXT: vs1r.v v10, (a1)
34+
; CHECK-NEXT: ret
35+
bb:
36+
%tmp = call <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2bf16(<vscale x 2 x bfloat> %arg, <vscale x 2 x i1> splat (i1 true), i32 %vl)
37+
%tmp2 = call <vscale x 2 x float> @llvm.vp.fadd.nxv2f32(<vscale x 2 x float> %tmp, <vscale x 2 x float> %tmp, <vscale x 2 x i1> splat (i1 true), i32 %vl)
38+
%tmp3 = call <vscale x 2 x float> @llvm.vp.fma.nxv2f32(<vscale x 2 x float> %tmp, <vscale x 2 x float> %tmp, <vscale x 2 x float> %acc, <vscale x 2 x i1> splat (i1 true), i32 %vl)
39+
store <vscale x 2 x float> %tmp3, ptr %p
40+
ret <vscale x 2 x float> %tmp2
41+
}
442

543
define <vscale x 2 x float> @vfwadd_same_operand(<vscale x 2 x half> %arg, i32 signext %vl) {
644
; ZVFH-LABEL: vfwadd_same_operand:

0 commit comments

Comments
 (0)