[RISCV] Add fixed length vector patterns for vfwmaccbf16.vv #108204
Conversation
This adds VL patterns for vfwmaccbf16.vv so that we can handle fixed-length vectors. It does this by teaching combineOp_VLToVWOp_VL to emit RISCVISD::VFWMADD_VL for bf16. The change in getOrCreateExtendedOp is needed because getNarrowType is based on the bit width and therefore returns f16; we need to check for bf16 explicitly. Note that the .vf patterns don't work yet, since the build_vector splat gets lowered to a vmv.v.x rather than a vfmv.v.f, which SplatFP doesn't pick up; see llvm#106637.
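As a minimal sketch (the function name here is illustrative), the fixed-length pattern these VL patterns target mirrors the cases in the new test file below:

define <4 x float> @vfwmaccbf16_sketch(<4 x float> %acc, <4 x bfloat> %b, <4 x bfloat> %c) {
  ; fma of two bf16 operands widened to f32; combineOp_VLToVWOp_VL folds the fpexts
  ; into RISCVISD::VFWMADD_VL so that vfwmaccbf16.vv can be selected.
  %b.ext = fpext <4 x bfloat> %b to <4 x float>
  %c.ext = fpext <4 x bfloat> %c to <4 x float>
  %res = call <4 x float> @llvm.fma.v4f32(<4 x float> %b.ext, <4 x float> %c.ext, <4 x float> %acc)
  ret <4 x float> %res
}

With -mattr=+v,+zvfbfwma this now selects vfwmaccbf16.vv directly rather than two vfwcvtbf16.f.f.v conversions followed by vfmacc.vv.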
@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes

This adds VL patterns for vfwmaccbf16.vv so that we can handle fixed-length vectors. It does this by teaching combineOp_VLToVWOp_VL to emit RISCVISD::VFWMADD_VL for bf16. The change in getOrCreateExtendedOp is needed because getNarrowType is based on the bit width and therefore returns f16; we need to check for bf16 explicitly. Note that the .vf patterns don't work yet, since the build_vector splat gets lowered to a vmv.v.x rather than a vfmv.v.f, which SplatFP doesn't pick up; see #106637.

Patch is 23.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108204.diff

3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 23f2b0e96495e9..ddd460416d6290 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14454,6 +14454,13 @@ struct NodeExtensionHelper {
if (Source.getValueType() == NarrowVT)
return Source;
+ // vfmadd_vl -> vfwmadd_vl can take bf16 operands
+ if (Source.getValueType().getVectorElementType() == MVT::bf16) {
+ assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
+ Root->getOpcode() == RISCVISD::VFMADD_VL);
+ return Source;
+ }
+
unsigned ExtOpc = getExtOpc(*SupportsExt);
// If we need an extension, we should be changing the type.
@@ -15705,7 +15712,7 @@ static SDValue performVFMADD_VLCombine(SDNode *N,
return V;
if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
- !Subtarget.hasVInstructionsF16())
+ !Subtarget.hasVInstructionsF16() && !Subtarget.hasStdExtZvfbfwma())
return SDValue();
// FIXME: Ignore strict opcodes for now.
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 699536b1869692..9afbe567193607 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2009,13 +2009,18 @@ multiclass VPatWidenFPMulAccVL_VV_VF<SDNode vop, string instruction_name> {
}
}
-multiclass VPatWidenFPMulAccVL_VV_VF_RM<SDNode vop, string instruction_name> {
- foreach vtiToWti = AllWidenableFloatVectors in {
+multiclass VPatWidenFPMulAccVL_VV_VF_RM<SDNode vop, string instruction_name,
+ list<VTypeInfoToWide> vtiToWtis =
+ AllWidenableFloatVectors> {
+ foreach vtiToWti = vtiToWtis in {
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
- GetVTypePredicates<wti>.Predicates) in {
+ GetVTypePredicates<wti>.Predicates,
+ !if(!eq(vti.Scalar, bf16),
+ [HasStdExtZvfbfwma],
+ [])) in {
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
(vti.Vector vti.RegClass:$rs2),
(wti.Vector wti.RegClass:$rd), (vti.Mask V0),
@@ -2451,6 +2456,8 @@ defm : VPatFPMulAccVL_VV_VF_RM<riscv_vfnmsub_vl_oneuse, "PseudoVFNMSAC">;
// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACC">;
+defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACCBF16",
+ AllWidenableBFloatToFloatVectors>;
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwnmadd_vl, "PseudoVFWNMACC">;
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmsub_vl, "PseudoVFWMSAC">;
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwnmsub_vl, "PseudoVFWNMSAC">;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
new file mode 100644
index 00000000000000..62a479bdedf649
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
@@ -0,0 +1,467 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
+; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefixes=ZVFBFMIN,ZVFBMIN32
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefixes=ZVFBFMIN,ZVFBMIN64
+
+define <1 x float> @vfwmaccbf16_vv_v1f32(<1 x float> %a, <1 x bfloat> %b, <1 x bfloat> %c) {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v1f32:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: addi sp, sp, -16
+; ZVFBFWMA-NEXT: .cfi_def_cfa_offset 16
+; ZVFBFWMA-NEXT: fcvt.s.bf16 fa5, fa0
+; ZVFBFWMA-NEXT: fsw fa5, 8(sp)
+; ZVFBFWMA-NEXT: addi a0, sp, 8
+; ZVFBFWMA-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; ZVFBFWMA-NEXT: vle32.v v9, (a0)
+; ZVFBFWMA-NEXT: fcvt.s.bf16 fa5, fa1
+; ZVFBFWMA-NEXT: fsw fa5, 12(sp)
+; ZVFBFWMA-NEXT: addi a0, sp, 12
+; ZVFBFWMA-NEXT: vle32.v v10, (a0)
+; ZVFBFWMA-NEXT: vfmacc.vv v8, v9, v10
+; ZVFBFWMA-NEXT: addi sp, sp, 16
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBMIN32-LABEL: vfwmaccbf16_vv_v1f32:
+; ZVFBMIN32: # %bb.0:
+; ZVFBMIN32-NEXT: addi sp, sp, -32
+; ZVFBMIN32-NEXT: .cfi_def_cfa_offset 32
+; ZVFBMIN32-NEXT: sw ra, 28(sp) # 4-byte Folded Spill
+; ZVFBMIN32-NEXT: sw s0, 24(sp) # 4-byte Folded Spill
+; ZVFBMIN32-NEXT: fsd fs0, 16(sp) # 8-byte Folded Spill
+; ZVFBMIN32-NEXT: .cfi_offset ra, -4
+; ZVFBMIN32-NEXT: .cfi_offset s0, -8
+; ZVFBMIN32-NEXT: .cfi_offset fs0, -16
+; ZVFBMIN32-NEXT: csrr a0, vlenb
+; ZVFBMIN32-NEXT: slli a0, a0, 1
+; ZVFBMIN32-NEXT: sub sp, sp, a0
+; ZVFBMIN32-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x20, 0x22, 0x11, 0x02, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 32 + 2 * vlenb
+; ZVFBMIN32-NEXT: fmv.s fs0, fa0
+; ZVFBMIN32-NEXT: addi a0, sp, 16
+; ZVFBMIN32-NEXT: vs1r.v v8, (a0) # Unknown-size Folded Spill
+; ZVFBMIN32-NEXT: fmv.s fa0, fa1
+; ZVFBMIN32-NEXT: call __truncsfbf2
+; ZVFBMIN32-NEXT: fmv.x.w s0, fa0
+; ZVFBMIN32-NEXT: fmv.s fa0, fs0
+; ZVFBMIN32-NEXT: call __truncsfbf2
+; ZVFBMIN32-NEXT: fmv.x.w a0, fa0
+; ZVFBMIN32-NEXT: slli a0, a0, 16
+; ZVFBMIN32-NEXT: sw a0, 8(sp)
+; ZVFBMIN32-NEXT: addi a0, sp, 8
+; ZVFBMIN32-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; ZVFBMIN32-NEXT: vle32.v v10, (a0)
+; ZVFBMIN32-NEXT: slli s0, s0, 16
+; ZVFBMIN32-NEXT: sw s0, 12(sp)
+; ZVFBMIN32-NEXT: addi a0, sp, 12
+; ZVFBMIN32-NEXT: vle32.v v9, (a0)
+; ZVFBMIN32-NEXT: addi a0, sp, 16
+; ZVFBMIN32-NEXT: vl1r.v v8, (a0) # Unknown-size Folded Reload
+; ZVFBMIN32-NEXT: vfmacc.vv v8, v10, v9
+; ZVFBMIN32-NEXT: csrr a0, vlenb
+; ZVFBMIN32-NEXT: slli a0, a0, 1
+; ZVFBMIN32-NEXT: add sp, sp, a0
+; ZVFBMIN32-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
+; ZVFBMIN32-NEXT: lw s0, 24(sp) # 4-byte Folded Reload
+; ZVFBMIN32-NEXT: fld fs0, 16(sp) # 8-byte Folded Reload
+; ZVFBMIN32-NEXT: addi sp, sp, 32
+; ZVFBMIN32-NEXT: ret
+;
+; ZVFBMIN64-LABEL: vfwmaccbf16_vv_v1f32:
+; ZVFBMIN64: # %bb.0:
+; ZVFBMIN64-NEXT: addi sp, sp, -64
+; ZVFBMIN64-NEXT: .cfi_def_cfa_offset 64
+; ZVFBMIN64-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
+; ZVFBMIN64-NEXT: sd s0, 48(sp) # 8-byte Folded Spill
+; ZVFBMIN64-NEXT: fsd fs0, 40(sp) # 8-byte Folded Spill
+; ZVFBMIN64-NEXT: .cfi_offset ra, -8
+; ZVFBMIN64-NEXT: .cfi_offset s0, -16
+; ZVFBMIN64-NEXT: .cfi_offset fs0, -24
+; ZVFBMIN64-NEXT: csrr a0, vlenb
+; ZVFBMIN64-NEXT: slli a0, a0, 1
+; ZVFBMIN64-NEXT: sub sp, sp, a0
+; ZVFBMIN64-NEXT: .cfi_escape 0x0f, 0x0e, 0x72, 0x00, 0x11, 0xc0, 0x00, 0x22, 0x11, 0x02, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 64 + 2 * vlenb
+; ZVFBMIN64-NEXT: fmv.s fs0, fa0
+; ZVFBMIN64-NEXT: addi a0, sp, 32
+; ZVFBMIN64-NEXT: vs1r.v v8, (a0) # Unknown-size Folded Spill
+; ZVFBMIN64-NEXT: fmv.s fa0, fa1
+; ZVFBMIN64-NEXT: call __truncsfbf2
+; ZVFBMIN64-NEXT: fmv.x.w s0, fa0
+; ZVFBMIN64-NEXT: fmv.s fa0, fs0
+; ZVFBMIN64-NEXT: call __truncsfbf2
+; ZVFBMIN64-NEXT: fmv.x.w a0, fa0
+; ZVFBMIN64-NEXT: slli a0, a0, 16
+; ZVFBMIN64-NEXT: fmv.w.x fa5, a0
+; ZVFBMIN64-NEXT: fsw fa5, 16(sp)
+; ZVFBMIN64-NEXT: addi a0, sp, 16
+; ZVFBMIN64-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; ZVFBMIN64-NEXT: vle32.v v10, (a0)
+; ZVFBMIN64-NEXT: slli s0, s0, 16
+; ZVFBMIN64-NEXT: fmv.w.x fa5, s0
+; ZVFBMIN64-NEXT: fsw fa5, 20(sp)
+; ZVFBMIN64-NEXT: addi a0, sp, 20
+; ZVFBMIN64-NEXT: vle32.v v9, (a0)
+; ZVFBMIN64-NEXT: addi a0, sp, 32
+; ZVFBMIN64-NEXT: vl1r.v v8, (a0) # Unknown-size Folded Reload
+; ZVFBMIN64-NEXT: vfmacc.vv v8, v10, v9
+; ZVFBMIN64-NEXT: csrr a0, vlenb
+; ZVFBMIN64-NEXT: slli a0, a0, 1
+; ZVFBMIN64-NEXT: add sp, sp, a0
+; ZVFBMIN64-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
+; ZVFBMIN64-NEXT: ld s0, 48(sp) # 8-byte Folded Reload
+; ZVFBMIN64-NEXT: fld fs0, 40(sp) # 8-byte Folded Reload
+; ZVFBMIN64-NEXT: addi sp, sp, 64
+; ZVFBMIN64-NEXT: ret
+ %b.ext = fpext <1 x bfloat> %b to <1 x float>
+ %c.ext = fpext <1 x bfloat> %c to <1 x float>
+ %res = call <1 x float> @llvm.fma.v1f32(<1 x float> %b.ext, <1 x float> %c.ext, <1 x float> %a)
+ ret <1 x float> %res
+}
+
+define <1 x float> @vfwmaccbf16_vf_v1f32(<1 x float> %a, bfloat %b, <1 x bfloat> %c) {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vf_v1f32:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: addi sp, sp, -16
+; ZVFBFWMA-NEXT: .cfi_def_cfa_offset 16
+; ZVFBFWMA-NEXT: fcvt.s.bf16 fa5, fa0
+; ZVFBFWMA-NEXT: fsw fa5, 8(sp)
+; ZVFBFWMA-NEXT: addi a0, sp, 8
+; ZVFBFWMA-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; ZVFBFWMA-NEXT: vle32.v v9, (a0)
+; ZVFBFWMA-NEXT: fcvt.s.bf16 fa5, fa1
+; ZVFBFWMA-NEXT: fsw fa5, 12(sp)
+; ZVFBFWMA-NEXT: addi a0, sp, 12
+; ZVFBFWMA-NEXT: vle32.v v10, (a0)
+; ZVFBFWMA-NEXT: vfmacc.vv v8, v9, v10
+; ZVFBFWMA-NEXT: addi sp, sp, 16
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBMIN32-LABEL: vfwmaccbf16_vf_v1f32:
+; ZVFBMIN32: # %bb.0:
+; ZVFBMIN32-NEXT: addi sp, sp, -48
+; ZVFBMIN32-NEXT: .cfi_def_cfa_offset 48
+; ZVFBMIN32-NEXT: sw ra, 44(sp) # 4-byte Folded Spill
+; ZVFBMIN32-NEXT: fsd fs0, 32(sp) # 8-byte Folded Spill
+; ZVFBMIN32-NEXT: .cfi_offset ra, -4
+; ZVFBMIN32-NEXT: .cfi_offset fs0, -16
+; ZVFBMIN32-NEXT: csrr a0, vlenb
+; ZVFBMIN32-NEXT: slli a0, a0, 1
+; ZVFBMIN32-NEXT: sub sp, sp, a0
+; ZVFBMIN32-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x30, 0x22, 0x11, 0x02, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 48 + 2 * vlenb
+; ZVFBMIN32-NEXT: fmv.s fs0, fa0
+; ZVFBMIN32-NEXT: addi a0, sp, 32
+; ZVFBMIN32-NEXT: vs1r.v v8, (a0) # Unknown-size Folded Spill
+; ZVFBMIN32-NEXT: fmv.s fa0, fa1
+; ZVFBMIN32-NEXT: call __truncsfbf2
+; ZVFBMIN32-NEXT: fmv.x.w a0, fa0
+; ZVFBMIN32-NEXT: fmv.x.w a1, fs0
+; ZVFBMIN32-NEXT: slli a1, a1, 16
+; ZVFBMIN32-NEXT: sw a1, 8(sp)
+; ZVFBMIN32-NEXT: addi a1, sp, 8
+; ZVFBMIN32-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; ZVFBMIN32-NEXT: vle32.v v10, (a1)
+; ZVFBMIN32-NEXT: slli a0, a0, 16
+; ZVFBMIN32-NEXT: sw a0, 12(sp)
+; ZVFBMIN32-NEXT: addi a0, sp, 12
+; ZVFBMIN32-NEXT: vle32.v v9, (a0)
+; ZVFBMIN32-NEXT: addi a0, sp, 32
+; ZVFBMIN32-NEXT: vl1r.v v8, (a0) # Unknown-size Folded Reload
+; ZVFBMIN32-NEXT: vfmacc.vv v8, v10, v9
+; ZVFBMIN32-NEXT: csrr a0, vlenb
+; ZVFBMIN32-NEXT: slli a0, a0, 1
+; ZVFBMIN32-NEXT: add sp, sp, a0
+; ZVFBMIN32-NEXT: lw ra, 44(sp) # 4-byte Folded Reload
+; ZVFBMIN32-NEXT: fld fs0, 32(sp) # 8-byte Folded Reload
+; ZVFBMIN32-NEXT: addi sp, sp, 48
+; ZVFBMIN32-NEXT: ret
+;
+; ZVFBMIN64-LABEL: vfwmaccbf16_vf_v1f32:
+; ZVFBMIN64: # %bb.0:
+; ZVFBMIN64-NEXT: addi sp, sp, -48
+; ZVFBMIN64-NEXT: .cfi_def_cfa_offset 48
+; ZVFBMIN64-NEXT: sd ra, 40(sp) # 8-byte Folded Spill
+; ZVFBMIN64-NEXT: fsd fs0, 32(sp) # 8-byte Folded Spill
+; ZVFBMIN64-NEXT: .cfi_offset ra, -8
+; ZVFBMIN64-NEXT: .cfi_offset fs0, -16
+; ZVFBMIN64-NEXT: csrr a0, vlenb
+; ZVFBMIN64-NEXT: slli a0, a0, 1
+; ZVFBMIN64-NEXT: sub sp, sp, a0
+; ZVFBMIN64-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x30, 0x22, 0x11, 0x02, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 48 + 2 * vlenb
+; ZVFBMIN64-NEXT: fmv.s fs0, fa0
+; ZVFBMIN64-NEXT: addi a0, sp, 32
+; ZVFBMIN64-NEXT: vs1r.v v8, (a0) # Unknown-size Folded Spill
+; ZVFBMIN64-NEXT: fmv.s fa0, fa1
+; ZVFBMIN64-NEXT: call __truncsfbf2
+; ZVFBMIN64-NEXT: fmv.x.w a0, fa0
+; ZVFBMIN64-NEXT: fmv.x.w a1, fs0
+; ZVFBMIN64-NEXT: slli a1, a1, 16
+; ZVFBMIN64-NEXT: fmv.w.x fa5, a1
+; ZVFBMIN64-NEXT: fsw fa5, 24(sp)
+; ZVFBMIN64-NEXT: addi a1, sp, 24
+; ZVFBMIN64-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; ZVFBMIN64-NEXT: vle32.v v10, (a1)
+; ZVFBMIN64-NEXT: slli a0, a0, 16
+; ZVFBMIN64-NEXT: fmv.w.x fa5, a0
+; ZVFBMIN64-NEXT: fsw fa5, 28(sp)
+; ZVFBMIN64-NEXT: addi a0, sp, 28
+; ZVFBMIN64-NEXT: vle32.v v9, (a0)
+; ZVFBMIN64-NEXT: addi a0, sp, 32
+; ZVFBMIN64-NEXT: vl1r.v v8, (a0) # Unknown-size Folded Reload
+; ZVFBMIN64-NEXT: vfmacc.vv v8, v10, v9
+; ZVFBMIN64-NEXT: csrr a0, vlenb
+; ZVFBMIN64-NEXT: slli a0, a0, 1
+; ZVFBMIN64-NEXT: add sp, sp, a0
+; ZVFBMIN64-NEXT: ld ra, 40(sp) # 8-byte Folded Reload
+; ZVFBMIN64-NEXT: fld fs0, 32(sp) # 8-byte Folded Reload
+; ZVFBMIN64-NEXT: addi sp, sp, 48
+; ZVFBMIN64-NEXT: ret
+ %b.head = insertelement <1 x bfloat> poison, bfloat %b, i32 0
+ %b.splat = shufflevector <1 x bfloat> %b.head, <1 x bfloat> poison, <1 x i32> zeroinitializer
+ %b.ext = fpext <1 x bfloat> %b.splat to <1 x float>
+ %c.ext = fpext <1 x bfloat> %c to <1 x float>
+ %res = call <1 x float> @llvm.fma.v1f32(<1 x float> %b.ext, <1 x float> %c.ext, <1 x float> %a)
+ ret <1 x float> %res
+}
+
+define <2 x float> @vfwmaccbf16_vv_v2f32(<2 x float> %a, <2 x bfloat> %b, <2 x bfloat> %c) {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v2f32:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
+; ZVFBFWMA-NEXT: vfwmaccbf16.vv v8, v9, v10
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBFMIN-LABEL: vfwmaccbf16_vv_v2f32:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v11, v9
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v9, v10
+; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
+; ZVFBFMIN-NEXT: vfmacc.vv v8, v11, v9
+; ZVFBFMIN-NEXT: ret
+ %b.ext = fpext <2 x bfloat> %b to <2 x float>
+ %c.ext = fpext <2 x bfloat> %c to <2 x float>
+ %res = call <2 x float> @llvm.fma.v2f32(<2 x float> %b.ext, <2 x float> %c.ext, <2 x float> %a)
+ ret <2 x float> %res
+}
+
+define <2 x float> @vfwmaccbf16_vf_v2f32(<2 x float> %a, bfloat %b, <2 x bfloat> %c) {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vf_v2f32:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: fmv.x.h a0, fa0
+; ZVFBFWMA-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
+; ZVFBFWMA-NEXT: vmv.v.x v10, a0
+; ZVFBFWMA-NEXT: vfwmaccbf16.vv v8, v10, v9
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBFMIN-LABEL: vfwmaccbf16_vf_v2f32:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: fmv.x.w a0, fa0
+; ZVFBFMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
+; ZVFBFMIN-NEXT: vmv.v.x v10, a0
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v11, v10
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v10, v9
+; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
+; ZVFBFMIN-NEXT: vfmacc.vv v8, v11, v10
+; ZVFBFMIN-NEXT: ret
+ %b.head = insertelement <2 x bfloat> poison, bfloat %b, i32 0
+ %b.splat = shufflevector <2 x bfloat> %b.head, <2 x bfloat> poison, <2 x i32> zeroinitializer
+ %b.ext = fpext <2 x bfloat> %b.splat to <2 x float>
+ %c.ext = fpext <2 x bfloat> %c to <2 x float>
+ %res = call <2 x float> @llvm.fma.v2f32(<2 x float> %b.ext, <2 x float> %c.ext, <2 x float> %a)
+ ret <2 x float> %res
+}
+
+define <4 x float> @vfwmaccbf16_vv_v4f32(<4 x float> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v4f32:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFWMA-NEXT: vfwmaccbf16.vv v8, v9, v10
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBFMIN-LABEL: vfwmaccbf16_vv_v4f32:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v11, v9
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v9, v10
+; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFMIN-NEXT: vfmacc.vv v8, v11, v9
+; ZVFBFMIN-NEXT: ret
+ %b.ext = fpext <4 x bfloat> %b to <4 x float>
+ %c.ext = fpext <4 x bfloat> %c to <4 x float>
+ %res = call <4 x float> @llvm.fma.v4f32(<4 x float> %b.ext, <4 x float> %c.ext, <4 x float> %a)
+ ret <4 x float> %res
+}
+
+define <4 x float> @vfwmaccbf16_vf_v4f32(<4 x float> %a, bfloat %b, <4 x bfloat> %c) {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vf_v4f32:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: fmv.x.h a0, fa0
+; ZVFBFWMA-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFWMA-NEXT: vmv.v.x v10, a0
+; ZVFBFWMA-NEXT: vfwmaccbf16.vv v8, v10, v9
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBFMIN-LABEL: vfwmaccbf16_vf_v4f32:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: fmv.x.w a0, fa0
+; ZVFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFMIN-NEXT: vmv.v.x v10, a0
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v11, v10
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v10, v9
+; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFMIN-NEXT: vfmacc.vv v8, v11, v10
+; ZVFBFMIN-NEXT: ret
+ %b.head = insertelement <4 x bfloat> poison, bfloat %b, i32 0
+ %b.splat = shufflevector <4 x bfloat> %b.head, <4 x bfloat> poison, <4 x i32> zeroinitializer
+ %b.ext = fpext <4 x bfloat> %b.splat to <4 x float>
+ %c.ext = fpext <4 x bfloat> %c to <4 x float>
+ %res = call <4 x float> @llvm.fma.v4f32(<4 x float> %b.ext, <4 x float> %c.ext, <4 x float> %a)
+ ret <4 x float> %res
+}
+
+define <8 x float> @vfwmaccbf16_vv_v8f32(<8 x float> %a, <8 x bfloat> %b, <8 x bfloat> %c) {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v8f32:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; ZVFBFWMA-NEXT: vfwmaccbf16.vv v8, v10, v11
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBFMIN-LABEL: vfwmaccbf16_vv_v8f32:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v12, v10
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v14, v11
+; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m2, ta, ma
+; ZVFBFMIN-NEXT: vfmacc.vv v8, v12, v14
+; ZVFBFMIN-NEXT: ret
+ %b.ext = fpext <8 x bfloat> %b to <8 x float>
+ %c.ext = fpext <8 x bfloat> %c to <8 x float>
+ %res = call <8 x float> @llvm.fma.v8f32(<8 x float> %b.ext, <8 x float> %c.ext, <8 x float> %a)
+ ret <8 x float> %res
+}
+
+define <8 x float> @vfwmaccbf16_vf_v8f32(<8 x float> %a, bfloat %b, <8 x bfloat> %c) {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vf_v8f32:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: fmv.x.h a0, fa0
+; ZVFBFWMA-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; ZVFBFWMA-NEXT: vmv.v.x v11, a0
+; ZVFBFWMA-NEXT: vfwmaccbf16.vv v8, v11, v10
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBFMIN-LABEL: vfwmaccbf16_vf_v8f32:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: fmv.x.w a0, fa0
+; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; ZVFBFMIN-NEXT: vmv.v.x v11, a0
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v12, v11
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v14, v10
+; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m2, ta, ma
+; ZVFBFMIN-NEXT: vfmacc.vv v8, v12, v14
+; ZVFBFMIN-NEXT: ret
+ %b.head = insertelement <8 x bfloat> poison, bfloat %b, i32 0
+ %b.splat = shufflevector <8 x bfloat> %b.head, <8 x bfloat> poison, <8 x i32> zeroinitializer
+ %b.ext = fpext <8 x bfloat> %b.splat to <8 x float>
+ %c.ext = fpext <8 x bfloat> %c to <8 x float>
+ %res = call <8 x float> @llvm.fma.v8f32(<8 x float> %b.ext, <8 x float> %c.ext, <8 x float> %a)
+ ret <8 x float> %res
+}
+
+define <16 x float> @vfwmaccbf16_vv_v16f32(<16 x float> %a, <16 x bfloat> %b, <16 x bfloat> %c) {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v16f32:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; ZVFBFWMA-NEXT: vfwmaccbf16.vv v8, v12, v14
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBFMIN-LABEL: vfwmaccbf16_vv_v16f32:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v16, v12
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v20, v14
+; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; ZVFBFMIN-NEXT: vfmacc.vv v8, v16, v20
+; ZVFBFMIN-NEXT: ret
+ %b.ext = fpext <16 x bfloat> %b to <16 x float>
+ %c.ext = fpext <16 x bfloat> %c to <16 x float>
+ %res = call <16 x float> @llv...
[truncated]
LGTM
This adds VL patterns for vfwmaccbf16.vv so that we can handle fixed-length vectors.
It does this by teaching combineOp_VLToVWOp_VL to emit RISCVISD::VFWMADD_VL for bf16. The change in getOrCreateExtendedOp is needed because getNarrowType is based on the bit width and therefore returns f16; we need to check for bf16 explicitly.
Note that the .vf patterns don't work yet, since the build_vector splat gets lowered to a (vmv_v_x_vl (fmv_x_anyexth x)) instead of a vfmv.v.f, which SplatFP doesn't pick up; see #106637.
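For the .vf limitation above, a minimal sketch of the splat form that is not yet matched (function name illustrative), mirroring the _vf cases in the test:

define <4 x float> @vfwmaccbf16_vf_sketch(<4 x float> %acc, bfloat %b, <4 x bfloat> %c) {
  ; The scalar bf16 splat is built with insertelement + shufflevector; today it lowers
  ; to a vmv.v.x of the bits rather than a vfmv.v.f, so SplatFP doesn't match it and
  ; vfwmaccbf16.vv with an explicit splat is emitted instead of vfwmaccbf16.vf.
  %b.head = insertelement <4 x bfloat> poison, bfloat %b, i32 0
  %b.splat = shufflevector <4 x bfloat> %b.head, <4 x bfloat> poison, <4 x i32> zeroinitializer
  %b.ext = fpext <4 x bfloat> %b.splat to <4 x float>
  %c.ext = fpext <4 x bfloat> %c to <4 x float>
  %res = call <4 x float> @llvm.fma.v4f32(<4 x float> %b.ext, <4 x float> %c.ext, <4 x float> %acc)
  ret <4 x float> %res
}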