@@ -13316,8 +13316,7 @@ namespace {
13316
13316
// apply a combine.
13317
13317
struct CombineResult;
13318
13318
13319
- enum class ExtKind { ZExt, SExt, FPExt };
13320
-
13319
+ enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
13321
13320
/// Helper class for folding sign/zero extensions.
13322
13321
/// In particular, this class is used for the following combines:
13323
13322
/// add | add_vl -> vwadd(u) | vwadd(u)_w
@@ -13448,13 +13447,11 @@ struct NodeExtensionHelper {
13448
13447
// Determine the narrow size.
13449
13448
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
13450
13449
13451
- unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;
13452
-
13453
13450
MVT EltVT = SupportsExt == ExtKind::FPExt
13454
13451
? MVT::getFloatingPointVT(NarrowSize)
13455
13452
: MVT::getIntegerVT(NarrowSize);
13456
13453
13457
- assert(NarrowSize >= NarrowMinSize &&
13454
+ assert(NarrowSize >= (SupportsExt == ExtKind::FPExt ? 16 : 8) &&
13458
13455
"Trying to extend something we can't represent");
13459
13456
MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
13460
13457
return NarrowVT;
@@ -13823,33 +13820,32 @@ struct CombineResult {
13823
13820
/// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
13824
13821
/// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
13825
13822
/// are zext) and LHS and RHS can be folded into Root.
13826
- /// AllowSExt and AllozZExt define which form `ext` can take in this pattern.
13823
+ /// AllowExtMask define which form `ext` can take in this pattern.
13827
13824
///
13828
13825
/// \note If the pattern can match with both zext and sext, the returned
13829
13826
/// CombineResult will feature the zext result.
13830
13827
///
13831
13828
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
13832
13829
/// can be used to apply the pattern.
13833
- static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
13834
- SDNode *Root, const NodeExtensionHelper &LHS,
13835
- const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
13836
- bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
13837
- assert((AllowSExt || AllowZExt || AllowFPExt) &&
13838
- "Forgot to set what you want?");
13830
+ static std::optional<CombineResult>
13831
+ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
13832
+ const NodeExtensionHelper &RHS,
13833
+ uint8_t AllowExtMask, SelectionDAG &DAG,
13834
+ const RISCVSubtarget &Subtarget) {
13839
13835
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
13840
13836
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
13841
13837
return std::nullopt;
13842
- if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
13838
+ if (AllowExtMask & ExtKind::ZExt && LHS.SupportsZExt && RHS.SupportsZExt)
13843
13839
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13844
13840
Root->getOpcode(), ExtKind::ZExt),
13845
13841
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
13846
13842
/*RHSExt=*/{ExtKind::ZExt});
13847
- if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
13843
+ if (AllowExtMask & ExtKind::SExt && LHS.SupportsSExt && RHS.SupportsSExt)
13848
13844
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13849
13845
Root->getOpcode(), ExtKind::SExt),
13850
13846
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
13851
13847
/*RHSExt=*/{ExtKind::SExt});
13852
- if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
13848
+ if (AllowExtMask & ExtKind::FPExt && RHS.SupportsFPExt)
13853
13849
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13854
13850
Root->getOpcode(), ExtKind::FPExt),
13855
13851
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
@@ -13867,9 +13863,9 @@ static std::optional<CombineResult>
13867
13863
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
13868
13864
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13869
13865
const RISCVSubtarget &Subtarget) {
13870
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
13871
- /*AllowZExt=*/true ,
13872
- /*AllowFPExt=*/true, DAG, Subtarget);
13866
+ return canFoldToVWWithSameExtensionImpl(
13867
+ Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG ,
13868
+ Subtarget);
13873
13869
}
13874
13870
13875
13871
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13911,9 +13907,8 @@ static std::optional<CombineResult>
13911
13907
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
13912
13908
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13913
13909
const RISCVSubtarget &Subtarget) {
13914
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
13915
- /*AllowZExt=*/false,
13916
- /*AllowFPExt=*/false, DAG, Subtarget);
13910
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
13911
+ Subtarget);
13917
13912
}
13918
13913
13919
13914
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13924,9 +13919,8 @@ static std::optional<CombineResult>
13924
13919
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
13925
13920
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13926
13921
const RISCVSubtarget &Subtarget) {
13927
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
13928
- /*AllowZExt=*/true,
13929
- /*AllowFPExt=*/false, DAG, Subtarget);
13922
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
13923
+ Subtarget);
13930
13924
}
13931
13925
13932
13926
/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
@@ -13937,9 +13931,8 @@ static std::optional<CombineResult>
13937
13931
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
13938
13932
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13939
13933
const RISCVSubtarget &Subtarget) {
13940
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
13941
- /*AllowZExt=*/false,
13942
- /*AllowFPExt=*/true, DAG, Subtarget);
13934
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
13935
+ Subtarget);
13943
13936
}
13944
13937
13945
13938
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
0 commit comments