Skip to content

Commit 480f07f

Browse files
authored
[RISCV] Add fixed length vector patterns for vfwmaccbf16.vv (#108204)
This adds VL patterns for vfwmaccbf16.vv so that we can handle fixed length vectors. It does this by teaching combineOp_VLToVWOp_VL to emit RISCVISD::VFWMADD_VL for bf16. The change in getOrCreateExtendedOp is needed because getNarrowType is based off of the bitwidth so returns f16. We need to explicitly check for bf16. Note that the .vf patterns don't work yet, since the build_vector splat gets lowered to a (vmv_v_x_vl (fmv_x_anyexth x)) instead of a vfmv.v.f, which SplatFP doesn't pick up, see #106637.
1 parent 3cd0137 commit 480f07f

File tree

3 files changed

+485
-4
lines changed

3 files changed

+485
-4
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14480,6 +14480,13 @@ struct NodeExtensionHelper {
1448014480
if (Source.getValueType() == NarrowVT)
1448114481
return Source;
1448214482

14483+
// vfmadd_vl -> vfwmadd_vl can take bf16 operands
14484+
if (Source.getValueType().getVectorElementType() == MVT::bf16) {
14485+
assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
14486+
Root->getOpcode() == RISCVISD::VFMADD_VL);
14487+
return Source;
14488+
}
14489+
1448314490
unsigned ExtOpc = getExtOpc(*SupportsExt);
1448414491

1448514492
// If we need an extension, we should be changing the type.
@@ -15731,7 +15738,7 @@ static SDValue performVFMADD_VLCombine(SDNode *N,
1573115738
return V;
1573215739

1573315740
if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
15734-
!Subtarget.hasVInstructionsF16())
15741+
!Subtarget.hasVInstructionsF16() && !Subtarget.hasStdExtZvfbfwma())
1573515742
return SDValue();
1573615743

1573715744
// FIXME: Ignore strict opcodes for now.

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2009,13 +2009,18 @@ multiclass VPatWidenFPMulAccVL_VV_VF<SDNode vop, string instruction_name> {
20092009
}
20102010
}
20112011

2012-
multiclass VPatWidenFPMulAccVL_VV_VF_RM<SDNode vop, string instruction_name> {
2013-
foreach vtiToWti = AllWidenableFloatVectors in {
2012+
multiclass VPatWidenFPMulAccVL_VV_VF_RM<SDNode vop, string instruction_name,
2013+
list<VTypeInfoToWide> vtiToWtis =
2014+
AllWidenableFloatVectors> {
2015+
foreach vtiToWti = vtiToWtis in {
20142016
defvar vti = vtiToWti.Vti;
20152017
defvar wti = vtiToWti.Wti;
20162018
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
20172019
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
2018-
GetVTypePredicates<wti>.Predicates) in {
2020+
GetVTypePredicates<wti>.Predicates,
2021+
!if(!eq(vti.Scalar, bf16),
2022+
[HasStdExtZvfbfwma],
2023+
[])) in {
20192024
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
20202025
(vti.Vector vti.RegClass:$rs2),
20212026
(wti.Vector wti.RegClass:$rd), (vti.Mask V0),
@@ -2451,6 +2456,8 @@ defm : VPatFPMulAccVL_VV_VF_RM<riscv_vfnmsub_vl_oneuse, "PseudoVFNMSAC">;
24512456

24522457
// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
24532458
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACC">;
2459+
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACCBF16",
2460+
AllWidenableBFloatToFloatVectors>;
24542461
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwnmadd_vl, "PseudoVFWNMACC">;
24552462
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmsub_vl, "PseudoVFWMSAC">;
24562463
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwnmsub_vl, "PseudoVFWNMSAC">;

0 commit comments

Comments
 (0)