Skip to content

Commit 17d9d8e

Browse files
committed
review comments and fix
Signed-off-by: Benoit Jacob <[email protected]>
1 parent 80fc72b commit 17d9d8e

File tree

2 files changed

+20
-30
lines changed

2 files changed

+20
-30
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3777,13 +3777,16 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
37773777
///
37783778
/// Example:
37793779
/// Before:
3780-
/// %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
3781-
/// sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
3782-
/// vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3780+
/// %1 = vector.extract_strided_slice %arg0 {
3781+
/// offsets = [0, 0, 0, 0, 0],
3782+
/// sizes = [1, 1, 1, 1, 8],
3783+
/// strides = [1, 1, 1, 1, 1]
3784+
/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
37833785
/// After:
3784-
/// %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
3785-
/// vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
3786-
/// vector<1x1x1x1x8xi8>
3786+
/// %0 = vector.extract %arg0[0, 0, 0, 0]
3787+
/// : vector<8xi8> from vector<8x1x1x2x8xi8>
3788+
/// %1 = vector.shape_cast %0
3789+
/// : vector<8xi8> to vector<1x1x1x1x8xi8>
37873790
///
37883791
class ContiguousExtractStridedSliceToExtract final
37893792
: public OpRewritePattern<ExtractStridedSliceOp> {
@@ -3792,14 +3795,12 @@ class ContiguousExtractStridedSliceToExtract final
37923795

37933796
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
37943797
PatternRewriter &rewriter) const override {
3795-
if (op.hasNonUnitStrides()) {
3798+
if (op.hasNonUnitStrides())
37963799
return failure();
3797-
}
37983800
Value source = op.getOperand();
37993801
auto sourceType = cast<VectorType>(source.getType());
3800-
if (sourceType.isScalable() || sourceType.getRank() == 0) {
3802+
if (sourceType.isScalable() || sourceType.getRank() == 0)
38013803
return failure();
3802-
}
38033804

38043805
// Compute the number of offsets to pass to ExtractOp::build. That is the
38053806
// difference between the source rank and the desired slice rank. We walk
@@ -3808,30 +3809,26 @@ class ContiguousExtractStridedSliceToExtract final
38083809
SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
38093810
int numOffsets;
38103811
for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
3811-
if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
3812+
if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
38123813
break;
3813-
}
38143814
}
38153815

38163816
// If the created extract op would have no offsets, then this whole
38173817
// extract_strided_slice is the identity and should have been handled by
38183818
// other canonicalizations.
3819-
if (numOffsets == 0) {
3819+
if (numOffsets == 0)
38203820
return failure();
3821-
}
38223821

38233822
// If not even the inner-most dimension is full-size, this op can't be
38243823
// rewritten as an ExtractOp.
38253824
if (numOffsets == sourceType.getRank() &&
3826-
static_cast<int>(sizes.size()) == sourceType.getRank()) {
3825+
static_cast<int>(sizes.size()) == sourceType.getRank())
38273826
return failure();
3828-
}
38293827

38303828
// The outer dimensions must have unit size.
38313829
for (int i = 0; i < numOffsets; ++i) {
3832-
if (sizes[i] != 1) {
3830+
if (sizes[i] != 1)
38333831
return failure();
3834-
}
38353832
}
38363833

38373834
// Avoid generating slices that have leading unit dimensions. The shape_cast
@@ -3841,13 +3838,6 @@ class ContiguousExtractStridedSliceToExtract final
38413838
numOffsets < static_cast<int>(sizes.size()) - 1) {
38423839
++numOffsets;
38433840
}
3844-
// After exhausting the list of slice sizes, we keep checking for unit
3845-
// dimensions in the source shape, to remove corner cases where the result
3846-
// would have a leading unit dimension.
3847-
while (sourceType.getDimSize(numOffsets) == 1 &&
3848-
numOffsets < sourceType.getRank() - 1) {
3849-
++numOffsets;
3850-
}
38513841

38523842
SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
38533843
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2757,12 +2757,12 @@ func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1
27572757
// -----
27582758

27592759
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_shorter_size_list
2760-
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
2761-
// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
2762-
func.func @contiguous_extract_strided_slices_to_extract_shorter_size_list(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
2760+
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<1x4xi32> from vector<8x1x2x1x1x4xi32>
2761+
// CHECK-NEXT: return %[[EXTRACT]] : vector<1x4xi32>
2762+
func.func @contiguous_extract_strided_slices_to_extract_shorter_size_list(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x4xi32> {
27632763
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1], strides = [1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
2764-
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
2765-
return %2 : vector<4xi32>
2764+
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<1x4xi32>
2765+
return %2 : vector<1x4xi32>
27662766
}
27672767

27682768
// -----

0 commit comments

Comments
 (0)