@@ -215,6 +215,15 @@ def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
215
215
SDTCVecEltisVT<2, i1>,
216
216
SDTCisVT<3, XLenVT>]>>;
217
217
218
+ def SDT_RISCVVWMUL_VL : SDTypeProfile<1, 4, [SDTCisVec<0>,
219
+ SDTCisSameNumEltsAs<0, 1>,
220
+ SDTCisSameAs<1, 2>,
221
+ SDTCisSameNumEltsAs<1, 3>,
222
+ SDTCVecEltisVT<3, i1>,
223
+ SDTCisVT<4, XLenVT>]>;
224
+ def riscv_vwmul_vl : SDNode<"RISCVISD::VWMUL_VL", SDT_RISCVVWMUL_VL, [SDNPCommutative]>;
225
+ def riscv_vwmulu_vl : SDNode<"RISCVISD::VWMULU_VL", SDT_RISCVVWMUL_VL, [SDNPCommutative]>;
226
+
218
227
def SDTRVVVecReduce : SDTypeProfile<1, 4, [
219
228
SDTCisVec<0>, SDTCisVec<1>, SDTCisSameAs<0, 2>, SDTCVecEltisVT<3, i1>,
220
229
SDTCisSameNumEltsAs<1, 3>, SDTCisVT<4, XLenVT>
@@ -226,6 +235,18 @@ def riscv_mul_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D),
226
235
return N->hasOneUse();
227
236
}]>;
228
237
238
+ def riscv_vwmul_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D),
239
+ (riscv_vwmul_vl node:$A, node:$B, node:$C,
240
+ node:$D), [{
241
+ return N->hasOneUse();
242
+ }]>;
243
+
244
+ def riscv_vwmulu_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D),
245
+ (riscv_vwmulu_vl node:$A, node:$B, node:$C,
246
+ node:$D), [{
247
+ return N->hasOneUse();
248
+ }]>;
249
+
229
250
foreach kind = ["ADD", "UMAX", "SMAX", "UMIN", "SMIN", "AND", "OR", "XOR",
230
251
"FADD", "SEQ_FADD", "FMIN", "FMAX"] in
231
252
def rvv_vecreduce_#kind#_vl : SDNode<"RISCVISD::VECREDUCE_"#kind#"_VL", SDTRVVVecReduce>;
@@ -326,6 +347,20 @@ multiclass VPatBinaryVL_VV_VX_VI<SDNode vop, string instruction_name,
326
347
}
327
348
}
328
349
350
+ multiclass VPatBinaryWVL_VV_VX<SDNode vop, string instruction_name> {
351
+ foreach VtiToWti = AllWidenableIntVectors in {
352
+ defvar vti = VtiToWti.Vti;
353
+ defvar wti = VtiToWti.Wti;
354
+ defm : VPatBinaryVL_VV<vop, instruction_name,
355
+ wti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
356
+ vti.LMul, wti.RegClass, vti.RegClass>;
357
+ defm : VPatBinaryVL_XI<vop, instruction_name, "VX",
358
+ wti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
359
+ vti.LMul, wti.RegClass, vti.RegClass,
360
+ SplatPat, GPR>;
361
+ }
362
+ }
363
+
329
364
class VPatBinaryVL_VF<SDNode vop,
330
365
string instruction_name,
331
366
ValueType result_type,
@@ -737,6 +772,10 @@ defm : VPatBinaryVL_VV_VX<riscv_sdiv_vl, "PseudoVDIV">;
737
772
defm : VPatBinaryVL_VV_VX<riscv_urem_vl, "PseudoVREMU">;
738
773
defm : VPatBinaryVL_VV_VX<riscv_srem_vl, "PseudoVREM">;
739
774
775
+ // 12.12. Vector Widening Integer Multiply Instructions
776
+ defm : VPatBinaryWVL_VV_VX<riscv_vwmul_vl, "PseudoVWMUL">;
777
+ defm : VPatBinaryWVL_VV_VX<riscv_vwmulu_vl, "PseudoVWMULU">;
778
+
740
779
// 12.13 Vector Single-Width Integer Multiply-Add Instructions
741
780
foreach vti = AllIntegerVectors in {
742
781
// NOTE: We choose VMADD because it has the most commuting freedom. So it
@@ -784,6 +823,49 @@ foreach vti = AllIntegerVectors in {
784
823
GPR:$vl, vti.Log2SEW)>;
785
824
}
786
825
826
+ // 12.14. Vector Widening Integer Multiply-Add Instructions
827
+ foreach vtiTowti = AllWidenableIntVectors in {
828
+ defvar vti = vtiTowti.Vti;
829
+ defvar wti = vtiTowti.Wti;
830
+ def : Pat<(wti.Vector
831
+ (riscv_add_vl wti.RegClass:$rd,
832
+ (riscv_vwmul_vl_oneuse vti.RegClass:$rs1,
833
+ (vti.Vector vti.RegClass:$rs2),
834
+ (vti.Mask true_mask), VLOpFrag),
835
+ (vti.Mask true_mask), VLOpFrag)),
836
+ (!cast<Instruction>("PseudoVWMACC_VV_"# vti.LMul.MX)
837
+ wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
838
+ GPR:$vl, vti.Log2SEW)>;
839
+ def : Pat<(wti.Vector
840
+ (riscv_add_vl wti.RegClass:$rd,
841
+ (riscv_vwmulu_vl_oneuse vti.RegClass:$rs1,
842
+ (vti.Vector vti.RegClass:$rs2),
843
+ (vti.Mask true_mask), VLOpFrag),
844
+ (vti.Mask true_mask), VLOpFrag)),
845
+ (!cast<Instruction>("PseudoVWMACCU_VV_"# vti.LMul.MX)
846
+ wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
847
+ GPR:$vl, vti.Log2SEW)>;
848
+
849
+ def : Pat<(wti.Vector
850
+ (riscv_add_vl wti.RegClass:$rd,
851
+ (riscv_vwmul_vl_oneuse (SplatPat XLenVT:$rs1),
852
+ (vti.Vector vti.RegClass:$rs2),
853
+ (vti.Mask true_mask), VLOpFrag),
854
+ (vti.Mask true_mask), VLOpFrag)),
855
+ (!cast<Instruction>("PseudoVWMACC_VX_" # vti.LMul.MX)
856
+ wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
857
+ GPR:$vl, vti.Log2SEW)>;
858
+ def : Pat<(wti.Vector
859
+ (riscv_add_vl wti.RegClass:$rd,
860
+ (riscv_vwmulu_vl_oneuse (SplatPat XLenVT:$rs1),
861
+ (vti.Vector vti.RegClass:$rs2),
862
+ (vti.Mask true_mask), VLOpFrag),
863
+ (vti.Mask true_mask), VLOpFrag)),
864
+ (!cast<Instruction>("PseudoVWMACCU_VX_" # vti.LMul.MX)
865
+ wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
866
+ GPR:$vl, vti.Log2SEW)>;
867
+ }
868
+
787
869
// 12.15. Vector Integer Merge Instructions
788
870
foreach vti = AllIntegerVectors in {
789
871
def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask VMV0:$vm),
0 commit comments