@@ -3777,13 +3777,16 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3777
3777
// /
3778
3778
// / Example:
3779
3779
// / Before:
3780
- // / %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
3781
- // / sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
3782
- // / vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3780
+ // / %1 = vector.extract_strided_slice %arg0 {
3781
+ // / offsets = [0, 0, 0, 0, 0],
3782
+ // / sizes = [1, 1, 1, 1, 8],
3783
+ // / strides = [1, 1, 1, 1, 1]
3784
+ // / } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3783
3785
// / After:
3784
- // / %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
3785
- // / vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
3786
- // / vector<1x1x1x1x8xi8>
3786
+ // / %0 = vector.extract %arg0[0, 0, 0, 0]
3787
+ // / : vector<8xi8> from vector<8x1x1x2x8xi8>
3788
+ // / %1 = vector.shape_cast %0
3789
+ // / : vector<8xi8> to vector<1x1x1x1x8xi8>
3787
3790
// /
3788
3791
class ContiguousExtractStridedSliceToExtract final
3789
3792
: public OpRewritePattern<ExtractStridedSliceOp> {
@@ -3792,14 +3795,12 @@ class ContiguousExtractStridedSliceToExtract final
3792
3795
3793
3796
LogicalResult matchAndRewrite (ExtractStridedSliceOp op,
3794
3797
PatternRewriter &rewriter) const override {
3795
- if (op.hasNonUnitStrides ()) {
3798
+ if (op.hasNonUnitStrides ())
3796
3799
return failure ();
3797
- }
3798
3800
Value source = op.getOperand ();
3799
3801
auto sourceType = cast<VectorType>(source.getType ());
3800
- if (sourceType.isScalable () || sourceType.getRank () == 0 ) {
3802
+ if (sourceType.isScalable () || sourceType.getRank () == 0 )
3801
3803
return failure ();
3802
- }
3803
3804
3804
3805
// Compute the number of offsets to pass to ExtractOp::build. That is the
3805
3806
// difference between the source rank and the desired slice rank. We walk
@@ -3808,30 +3809,26 @@ class ContiguousExtractStridedSliceToExtract final
3808
3809
SmallVector<int64_t > sizes = getI64SubArray (op.getSizes ());
3809
3810
int numOffsets;
3810
3811
for (numOffsets = sizes.size (); numOffsets > 0 ; --numOffsets) {
3811
- if (sizes[numOffsets - 1 ] != sourceType.getDimSize (numOffsets - 1 )) {
3812
+ if (sizes[numOffsets - 1 ] != sourceType.getDimSize (numOffsets - 1 ))
3812
3813
break ;
3813
- }
3814
3814
}
3815
3815
3816
3816
// If the created extract op would have no offsets, then this whole
3817
3817
// extract_strided_slice is the identity and should have been handled by
3818
3818
// other canonicalizations.
3819
- if (numOffsets == 0 ) {
3819
+ if (numOffsets == 0 )
3820
3820
return failure ();
3821
- }
3822
3821
3823
3822
// If not even the inner-most dimension is full-size, this op can't be
3824
3823
// rewritten as an ExtractOp.
3825
3824
if (numOffsets == sourceType.getRank () &&
3826
- static_cast <int >(sizes.size ()) == sourceType.getRank ()) {
3825
+ static_cast <int >(sizes.size ()) == sourceType.getRank ())
3827
3826
return failure ();
3828
- }
3829
3827
3830
3828
// The outer dimensions must have unit size.
3831
3829
for (int i = 0 ; i < numOffsets; ++i) {
3832
- if (sizes[i] != 1 ) {
3830
+ if (sizes[i] != 1 )
3833
3831
return failure ();
3834
- }
3835
3832
}
3836
3833
3837
3834
// Avoid generating slices that have leading unit dimensions. The shape_cast
@@ -3841,13 +3838,6 @@ class ContiguousExtractStridedSliceToExtract final
3841
3838
numOffsets < static_cast <int >(sizes.size ()) - 1 ) {
3842
3839
++numOffsets;
3843
3840
}
3844
- // After exhausting the list of slice sizes, we keep checking for unit
3845
- // dimensions in the source shape, to remove corner cases where the result
3846
- // would have a leading unit dimension.
3847
- while (sourceType.getDimSize (numOffsets) == 1 &&
3848
- numOffsets < sourceType.getRank () - 1 ) {
3849
- ++numOffsets;
3850
- }
3851
3841
3852
3842
SmallVector<int64_t > offsets = getI64SubArray (op.getOffsets ());
3853
3843
auto extractOffsets = ArrayRef (offsets).take_front (numOffsets);
0 commit comments