@@ -917,7 +917,7 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
917
917
// / This accounts for cases where there are multiple unit-dims, but only a
918
918
// / subset of those are dropped. For MemRefTypes these can be disambiguated
919
919
// / using the strides. If a dimension is dropped the stride must be dropped too.
920
- static FailureOr <llvm::SmallBitVector>
920
+ static std::optional <llvm::SmallBitVector>
921
921
computeMemRefRankReductionMask (MemRefType originalType, MemRefType reducedType,
922
922
ArrayRef<OpFoldResult> sizes) {
923
923
llvm::SmallBitVector unusedDims (originalType.getRank ());
@@ -941,7 +941,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
941
941
getStridesAndOffset (originalType, originalStrides, originalOffset)) ||
942
942
failed (
943
943
getStridesAndOffset (reducedType, candidateStrides, candidateOffset)))
944
- return failure () ;
944
+ return std::nullopt ;
945
945
946
946
// For memrefs, a dimension is truly dropped if its corresponding stride is
947
947
// also dropped. This is particularly important when more than one of the dims
@@ -976,22 +976,22 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
976
976
candidateStridesNumOccurences[originalStride]) {
977
977
// This should never happen. Cant have a stride in the reduced rank type
978
978
// that wasnt in the original one.
979
- return failure () ;
979
+ return std::nullopt ;
980
980
}
981
981
}
982
982
983
983
if ((int64_t )unusedDims.count () + reducedType.getRank () !=
984
984
originalType.getRank ())
985
- return failure () ;
985
+ return std::nullopt ;
986
986
return unusedDims;
987
987
}
988
988
989
989
llvm::SmallBitVector SubViewOp::getDroppedDims () {
990
990
MemRefType sourceType = getSourceType ();
991
991
MemRefType resultType = getType ();
992
- FailureOr <llvm::SmallBitVector> unusedDims =
992
+ std::optional <llvm::SmallBitVector> unusedDims =
993
993
computeMemRefRankReductionMask (sourceType, resultType, getMixedSizes ());
994
- assert (succeeded ( unusedDims) && " unable to find unused dims of subview" );
994
+ assert (unusedDims && " unable to find unused dims of subview" );
995
995
return *unusedDims;
996
996
}
997
997
@@ -2745,7 +2745,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2745
2745
// / For ViewLikeOpInterface.
2746
2746
Value SubViewOp::getViewSource () { return getSource (); }
2747
2747
2748
- // / Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2748
+ // / Return true if t1 and t2 have equal offsets (both dynamic or of same
2749
2749
// / static value).
2750
2750
static bool haveCompatibleOffsets (MemRefType t1, MemRefType t2) {
2751
2751
int64_t t1Offset, t2Offset;
@@ -2755,41 +2755,56 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2755
2755
return succeeded (res1) && succeeded (res2) && t1Offset == t2Offset;
2756
2756
}
2757
2757
2758
- // / Return true if `t1` and `t2` have equal strides (both dynamic or of same
2759
- // / static value).
2760
- static bool haveCompatibleStrides (MemRefType t1, MemRefType t2) {
2761
- int64_t t1Offset, t2Offset;
2762
- SmallVector<int64_t > t1Strides, t2Strides;
2763
- auto res1 = getStridesAndOffset (t1, t1Strides, t1Offset);
2764
- auto res2 = getStridesAndOffset (t2, t2Strides, t2Offset);
2765
- if (failed (res1) || failed (res2))
2766
- return false ;
2767
- for (auto [s1, s2] : llvm::zip_equal (t1Strides, t2Strides))
2768
- if (s1 != s2)
2769
- return false ;
2770
- return true ;
2758
+ // / Checks if `original` Type type can be rank reduced to `reduced` type.
2759
+ // / This function is slight variant of `is subsequence` algorithm where
2760
+ // / not matching dimension must be 1.
2761
+ static SliceVerificationResult
2762
+ isRankReducedMemRefType (MemRefType originalType,
2763
+ MemRefType candidateRankReducedType,
2764
+ ArrayRef<OpFoldResult> sizes) {
2765
+ auto partialRes = isRankReducedType (originalType, candidateRankReducedType);
2766
+ if (partialRes != SliceVerificationResult::Success)
2767
+ return partialRes;
2768
+
2769
+ auto optionalUnusedDimsMask = computeMemRefRankReductionMask (
2770
+ originalType, candidateRankReducedType, sizes);
2771
+
2772
+ // Sizes cannot be matched in case empty vector is returned.
2773
+ if (!optionalUnusedDimsMask)
2774
+ return SliceVerificationResult::LayoutMismatch;
2775
+
2776
+ if (originalType.getMemorySpace () !=
2777
+ candidateRankReducedType.getMemorySpace ())
2778
+ return SliceVerificationResult::MemSpaceMismatch;
2779
+
2780
+ // No amount of stride dropping can reconcile incompatible offsets.
2781
+ if (!haveCompatibleOffsets (originalType, candidateRankReducedType))
2782
+ return SliceVerificationResult::LayoutMismatch;
2783
+
2784
+ return SliceVerificationResult::Success;
2771
2785
}
2772
2786
2787
+ template <typename OpTy>
2773
2788
static LogicalResult produceSubViewErrorMsg (SliceVerificationResult result,
2774
- Operation * op, Type expectedType) {
2789
+ OpTy op, Type expectedType) {
2775
2790
auto memrefType = llvm::cast<ShapedType>(expectedType);
2776
2791
switch (result) {
2777
2792
case SliceVerificationResult::Success:
2778
2793
return success ();
2779
2794
case SliceVerificationResult::RankTooLarge:
2780
- return op-> emitError (" expected result rank to be smaller or equal to " )
2795
+ return op. emitError (" expected result rank to be smaller or equal to " )
2781
2796
<< " the source rank. " ;
2782
2797
case SliceVerificationResult::SizeMismatch:
2783
- return op-> emitError (" expected result type to be " )
2798
+ return op. emitError (" expected result type to be " )
2784
2799
<< expectedType
2785
2800
<< " or a rank-reduced version. (mismatch of result sizes) " ;
2786
2801
case SliceVerificationResult::ElemTypeMismatch:
2787
- return op-> emitError (" expected result element type to be " )
2802
+ return op. emitError (" expected result element type to be " )
2788
2803
<< memrefType.getElementType ();
2789
2804
case SliceVerificationResult::MemSpaceMismatch:
2790
- return op-> emitError (" expected result and source memory spaces to match." );
2805
+ return op. emitError (" expected result and source memory spaces to match." );
2791
2806
case SliceVerificationResult::LayoutMismatch:
2792
- return op-> emitError (" expected result type to be " )
2807
+ return op. emitError (" expected result type to be " )
2793
2808
<< expectedType
2794
2809
<< " or a rank-reduced version. (mismatch of result layout) " ;
2795
2810
}
@@ -2811,46 +2826,13 @@ LogicalResult SubViewOp::verify() {
2811
2826
if (!isStrided (baseType))
2812
2827
return emitError (" base type " ) << baseType << " is not strided" ;
2813
2828
2814
- // Compute the expected result type, assuming that there are no rank
2815
- // reductions.
2816
- auto expectedType = cast<MemRefType>(SubViewOp::inferResultType (
2817
- baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ()));
2818
-
2819
- // Verify all properties of a shaped type: rank, element type and dimension
2820
- // sizes. This takes into account potential rank reductions.
2821
- auto shapedTypeVerification = isRankReducedType (
2822
- /* originalType=*/ expectedType, /* candidateReducedType=*/ subViewType);
2823
- if (shapedTypeVerification != SliceVerificationResult::Success)
2824
- return produceSubViewErrorMsg (shapedTypeVerification, *this , expectedType);
2825
-
2826
- // Make sure that the memory space did not change.
2827
- if (expectedType.getMemorySpace () != subViewType.getMemorySpace ())
2828
- return produceSubViewErrorMsg (SliceVerificationResult::MemSpaceMismatch,
2829
- *this , expectedType);
2830
-
2831
- // Verify the offset of the layout map.
2832
- if (!haveCompatibleOffsets (expectedType, subViewType))
2833
- return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2834
- *this , expectedType);
2835
-
2836
- // The only thing that's left to verify now are the strides. First, compute
2837
- // the unused dimensions due to rank reductions. We have to look at sizes and
2838
- // strides to decide which dimensions were dropped. This function also
2839
- // partially verifies strides in case of rank reductions.
2840
- auto unusedDims = computeMemRefRankReductionMask (expectedType, subViewType,
2841
- getMixedSizes ());
2842
- if (failed (unusedDims))
2843
- return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2844
- *this , expectedType);
2845
-
2846
- // Strides must match if there are no rank reductions.
2847
- // TODO: Verify strides when there are rank reductions. Strides are partially
2848
- // checked in `computeMemRefRankReductionMask`.
2849
- if (unusedDims->none () && !haveCompatibleStrides (expectedType, subViewType))
2850
- return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2851
- *this , expectedType);
2829
+ // Verify result type against inferred type.
2830
+ auto expectedType = SubViewOp::inferResultType (
2831
+ baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ());
2852
2832
2853
- return success ();
2833
+ auto result = isRankReducedMemRefType (llvm::cast<MemRefType>(expectedType),
2834
+ subViewType, getMixedSizes ());
2835
+ return produceSubViewErrorMsg (result, *this , expectedType);
2854
2836
}
2855
2837
2856
2838
raw_ostream &mlir::operator <<(raw_ostream &os, const Range &range) {
@@ -2900,9 +2882,11 @@ static MemRefType getCanonicalSubViewResultType(
2900
2882
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2901
2883
auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType (
2902
2884
sourceType, mixedOffsets, mixedSizes, mixedStrides));
2903
- FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask (
2904
- currentSourceType, currentResultType, mixedSizes);
2905
- if (failed (unusedDims))
2885
+ std::optional<llvm::SmallBitVector> unusedDims =
2886
+ computeMemRefRankReductionMask (currentSourceType, currentResultType,
2887
+ mixedSizes);
2888
+ // Return nullptr as failure mode.
2889
+ if (!unusedDims)
2906
2890
return nullptr ;
2907
2891
2908
2892
auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout ());
0 commit comments