Skip to content

Commit f0e6c8b

Browse files
committed
add AllowExtMask
1 parent b7ebaeb commit f0e6c8b

File tree

1 file changed

+20
-27
lines changed

1 file changed

+20
-27
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13316,8 +13316,7 @@ namespace {
1331613316
// apply a combine.
1331713317
struct CombineResult;
1331813318

13319-
enum class ExtKind { ZExt, SExt, FPExt };
13320-
13319+
enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
1332113320
/// Helper class for folding sign/zero extensions.
1332213321
/// In particular, this class is used for the following combines:
1332313322
/// add | add_vl -> vwadd(u) | vwadd(u)_w
@@ -13448,13 +13447,11 @@ struct NodeExtensionHelper {
1344813447
// Determine the narrow size.
1344913448
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1345013449

13451-
unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;
13452-
1345313450
MVT EltVT = SupportsExt == ExtKind::FPExt
1345413451
? MVT::getFloatingPointVT(NarrowSize)
1345513452
: MVT::getIntegerVT(NarrowSize);
1345613453

13457-
assert(NarrowSize >= NarrowMinSize &&
13454+
assert(NarrowSize >= (SupportsExt == ExtKind::FPExt ? 16 : 8) &&
1345813455
"Trying to extend something we can't represent");
1345913456
MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
1346013457
return NarrowVT;
@@ -13823,33 +13820,32 @@ struct CombineResult {
1382313820
/// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
1382413821
/// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
1382513822
/// are zext) and LHS and RHS can be folded into Root.
13826-
/// AllowSExt and AllozZExt define which form `ext` can take in this pattern.
13823+
/// AllowExtMask define which form `ext` can take in this pattern.
1382713824
///
1382813825
/// \note If the pattern can match with both zext and sext, the returned
1382913826
/// CombineResult will feature the zext result.
1383013827
///
1383113828
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1383213829
/// can be used to apply the pattern.
13833-
static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
13834-
SDNode *Root, const NodeExtensionHelper &LHS,
13835-
const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
13836-
bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
13837-
assert((AllowSExt || AllowZExt || AllowFPExt) &&
13838-
"Forgot to set what you want?");
13830+
static std::optional<CombineResult>
13831+
canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
13832+
const NodeExtensionHelper &RHS,
13833+
uint8_t AllowExtMask, SelectionDAG &DAG,
13834+
const RISCVSubtarget &Subtarget) {
1383913835
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
1384013836
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
1384113837
return std::nullopt;
13842-
if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
13838+
if (AllowExtMask & ExtKind::ZExt && LHS.SupportsZExt && RHS.SupportsZExt)
1384313839
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
1384413840
Root->getOpcode(), ExtKind::ZExt),
1384513841
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
1384613842
/*RHSExt=*/{ExtKind::ZExt});
13847-
if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
13843+
if (AllowExtMask & ExtKind::SExt && LHS.SupportsSExt && RHS.SupportsSExt)
1384813844
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
1384913845
Root->getOpcode(), ExtKind::SExt),
1385013846
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
1385113847
/*RHSExt=*/{ExtKind::SExt});
13852-
if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
13848+
if (AllowExtMask & ExtKind::FPExt && RHS.SupportsFPExt)
1385313849
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
1385413850
Root->getOpcode(), ExtKind::FPExt),
1385513851
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
@@ -13867,9 +13863,9 @@ static std::optional<CombineResult>
1386713863
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1386813864
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1386913865
const RISCVSubtarget &Subtarget) {
13870-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
13871-
/*AllowZExt=*/true,
13872-
/*AllowFPExt=*/true, DAG, Subtarget);
13866+
return canFoldToVWWithSameExtensionImpl(
13867+
Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
13868+
Subtarget);
1387313869
}
1387413870

1387513871
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13911,9 +13907,8 @@ static std::optional<CombineResult>
1391113907
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1391213908
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1391313909
const RISCVSubtarget &Subtarget) {
13914-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
13915-
/*AllowZExt=*/false,
13916-
/*AllowFPExt=*/false, DAG, Subtarget);
13910+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
13911+
Subtarget);
1391713912
}
1391813913

1391913914
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13924,9 +13919,8 @@ static std::optional<CombineResult>
1392413919
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1392513920
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1392613921
const RISCVSubtarget &Subtarget) {
13927-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
13928-
/*AllowZExt=*/true,
13929-
/*AllowFPExt=*/false, DAG, Subtarget);
13922+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
13923+
Subtarget);
1393013924
}
1393113925

1393213926
/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
@@ -13937,9 +13931,8 @@ static std::optional<CombineResult>
1393713931
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1393813932
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1393913933
const RISCVSubtarget &Subtarget) {
13940-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
13941-
/*AllowZExt=*/false,
13942-
/*AllowFPExt=*/true, DAG, Subtarget);
13934+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
13935+
Subtarget);
1394313936
}
1394413937

1394513938
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))

0 commit comments

Comments
 (0)