Skip to content

Commit e0e8edf

Browse files
committed
[RISCV] Add isel patterns for masked RISCVISD::FMA_VL with RISCVISD::FNEG_VL.
This helps us form vfnmsub, vfnmadd, and vfmusb from masked VP intrinsics. I've used "srcvalue" for the mask parameter in the fneg nodes. We can't match "V0" because that doesn't ensure the mask the is the same. Instead it matches two different nodes and generates two copies to V0 of those separate values. Reviewed By: rogfer01 Differential Revision: https://reviews.llvm.org/D120287
1 parent 3491f2f commit e0e8edf

File tree

2 files changed

+6410
-10
lines changed

2 files changed

+6410
-10
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 96 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,29 +1165,59 @@ foreach vti = AllFloatVectors in {
11651165
(!cast<Instruction>("PseudoVFMSUB_VV_"# suffix)
11661166
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
11671167
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
1168+
def : Pat<(vti.Vector (riscv_fma_vl vti.RegClass:$rs1, vti.RegClass:$rd,
1169+
(riscv_fneg_vl vti.RegClass:$rs2,
1170+
(vti.Mask srcvalue),
1171+
VLOpFrag),
1172+
(vti.Mask V0),
1173+
VLOpFrag)),
1174+
(!cast<Instruction>("PseudoVFMSUB_VV_"# suffix #"_MASK")
1175+
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
1176+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
11681177

11691178
def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl vti.RegClass:$rs1,
1170-
(vti.Mask true_mask),
1179+
(vti.Mask srcvalue),
11711180
VLOpFrag),
11721181
vti.RegClass:$rd,
11731182
(riscv_fneg_vl vti.RegClass:$rs2,
1174-
(vti.Mask true_mask),
1183+
(vti.Mask srcvalue),
11751184
VLOpFrag),
11761185
(vti.Mask true_mask),
11771186
VLOpFrag)),
11781187
(!cast<Instruction>("PseudoVFNMADD_VV_"# suffix)
11791188
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
11801189
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
1190+
def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl vti.RegClass:$rs1,
1191+
(vti.Mask srcvalue),
1192+
VLOpFrag),
1193+
vti.RegClass:$rd,
1194+
(riscv_fneg_vl vti.RegClass:$rs2,
1195+
(vti.Mask srcvalue),
1196+
VLOpFrag),
1197+
(vti.Mask V0),
1198+
VLOpFrag)),
1199+
(!cast<Instruction>("PseudoVFNMADD_VV_"# suffix #"_MASK")
1200+
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
1201+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
11811202

11821203
def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl vti.RegClass:$rs1,
1183-
(vti.Mask true_mask),
1204+
(vti.Mask srcvalue),
11841205
VLOpFrag),
11851206
vti.RegClass:$rd, vti.RegClass:$rs2,
11861207
(vti.Mask true_mask),
11871208
VLOpFrag)),
11881209
(!cast<Instruction>("PseudoVFNMSUB_VV_"# suffix)
11891210
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
11901211
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
1212+
def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl vti.RegClass:$rs1,
1213+
(vti.Mask srcvalue),
1214+
VLOpFrag),
1215+
vti.RegClass:$rd, vti.RegClass:$rs2,
1216+
(vti.Mask V0),
1217+
VLOpFrag)),
1218+
(!cast<Instruction>("PseudoVFNMSUB_VV_"# suffix #"_MASK")
1219+
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
1220+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
11911221

11921222
// The choice of VFMADD here is arbitrary, vfmadd.vf and vfmacc.vf are equally
11931223
// commutable.
@@ -1209,19 +1239,30 @@ foreach vti = AllFloatVectors in {
12091239
def : Pat<(vti.Vector (riscv_fma_vl (SplatFPOp vti.ScalarRegClass:$rs1),
12101240
vti.RegClass:$rd,
12111241
(riscv_fneg_vl vti.RegClass:$rs2,
1212-
(vti.Mask true_mask),
1242+
(vti.Mask srcvalue),
12131243
VLOpFrag),
12141244
(vti.Mask true_mask),
12151245
VLOpFrag)),
12161246
(!cast<Instruction>("PseudoVFMSUB_V" # vti.ScalarSuffix # "_" # suffix)
12171247
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
12181248
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
1249+
def : Pat<(vti.Vector (riscv_fma_vl (SplatFPOp vti.ScalarRegClass:$rs1),
1250+
vti.RegClass:$rd,
1251+
(riscv_fneg_vl vti.RegClass:$rs2,
1252+
(vti.Mask srcvalue),
1253+
VLOpFrag),
1254+
(vti.Mask V0),
1255+
VLOpFrag)),
1256+
(!cast<Instruction>("PseudoVFMSUB_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
1257+
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
1258+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
1259+
12191260
def : Pat<(vti.Vector (riscv_fma_vl (SplatFPOp vti.ScalarRegClass:$rs1),
12201261
(riscv_fneg_vl vti.RegClass:$rd,
1221-
(vti.Mask true_mask),
1262+
(vti.Mask srcvalue),
12221263
VLOpFrag),
12231264
(riscv_fneg_vl vti.RegClass:$rs2,
1224-
(vti.Mask true_mask),
1265+
(vti.Mask srcvalue),
12251266
VLOpFrag),
12261267
(vti.Mask true_mask),
12271268
VLOpFrag)),
@@ -1230,37 +1271,82 @@ foreach vti = AllFloatVectors in {
12301271
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
12311272
def : Pat<(vti.Vector (riscv_fma_vl (SplatFPOp vti.ScalarRegClass:$rs1),
12321273
(riscv_fneg_vl vti.RegClass:$rd,
1233-
(vti.Mask true_mask),
1274+
(vti.Mask srcvalue),
1275+
VLOpFrag),
1276+
(riscv_fneg_vl vti.RegClass:$rs2,
1277+
(vti.Mask srcvalue),
1278+
VLOpFrag),
1279+
(vti.Mask V0),
1280+
VLOpFrag)),
1281+
(!cast<Instruction>("PseudoVFNMADD_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
1282+
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
1283+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
1284+
1285+
def : Pat<(vti.Vector (riscv_fma_vl (SplatFPOp vti.ScalarRegClass:$rs1),
1286+
(riscv_fneg_vl vti.RegClass:$rd,
1287+
(vti.Mask srcvalue),
12341288
VLOpFrag),
12351289
vti.RegClass:$rs2,
12361290
(vti.Mask true_mask),
12371291
VLOpFrag)),
12381292
(!cast<Instruction>("PseudoVFNMSUB_V" # vti.ScalarSuffix # "_" # suffix)
12391293
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
12401294
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
1295+
def : Pat<(vti.Vector (riscv_fma_vl (SplatFPOp vti.ScalarRegClass:$rs1),
1296+
(riscv_fneg_vl vti.RegClass:$rd,
1297+
(vti.Mask srcvalue),
1298+
VLOpFrag),
1299+
vti.RegClass:$rs2,
1300+
(vti.Mask V0),
1301+
VLOpFrag)),
1302+
(!cast<Instruction>("PseudoVFNMSUB_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
1303+
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
1304+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
12411305

12421306
// The splat might be negated.
12431307
def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl (SplatFPOp vti.ScalarRegClass:$rs1),
1244-
(vti.Mask true_mask),
1308+
(vti.Mask srcvalue),
12451309
VLOpFrag),
12461310
vti.RegClass:$rd,
12471311
(riscv_fneg_vl vti.RegClass:$rs2,
1248-
(vti.Mask true_mask),
1312+
(vti.Mask srcvalue),
12491313
VLOpFrag),
12501314
(vti.Mask true_mask),
12511315
VLOpFrag)),
12521316
(!cast<Instruction>("PseudoVFNMADD_V" # vti.ScalarSuffix # "_" # suffix)
12531317
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
12541318
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
12551319
def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl (SplatFPOp vti.ScalarRegClass:$rs1),
1256-
(vti.Mask true_mask),
1320+
(vti.Mask srcvalue),
1321+
VLOpFrag),
1322+
vti.RegClass:$rd,
1323+
(riscv_fneg_vl vti.RegClass:$rs2,
1324+
(vti.Mask srcvalue),
1325+
VLOpFrag),
1326+
(vti.Mask V0),
1327+
VLOpFrag)),
1328+
(!cast<Instruction>("PseudoVFNMADD_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
1329+
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
1330+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
1331+
1332+
def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl (SplatFPOp vti.ScalarRegClass:$rs1),
1333+
(vti.Mask srcvalue),
12571334
VLOpFrag),
12581335
vti.RegClass:$rd, vti.RegClass:$rs2,
12591336
(vti.Mask true_mask),
12601337
VLOpFrag)),
12611338
(!cast<Instruction>("PseudoVFNMSUB_V" # vti.ScalarSuffix # "_" # suffix)
12621339
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
12631340
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
1341+
def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl (SplatFPOp vti.ScalarRegClass:$rs1),
1342+
(vti.Mask srcvalue),
1343+
VLOpFrag),
1344+
vti.RegClass:$rd, vti.RegClass:$rs2,
1345+
(vti.Mask V0),
1346+
VLOpFrag)),
1347+
(!cast<Instruction>("PseudoVFNMSUB_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
1348+
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
1349+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
12641350
}
12651351

12661352
// 14.11. Vector Floating-Point MIN/MAX Instructions

0 commit comments

Comments
 (0)