Skip to content

Commit d905b1c

Browse files
bjacobbanach-space
andauthored
[MLIR] Vector dialect: Address post-merge review comments on #111541 (#111552)
Co-authored-by: Andrzej Warzyński <[email protected]>
1 parent 3829fd7 commit d905b1c

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,17 @@ class DecomposeNDExtractStridedSlice
331331

332332
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
333333
/// slice is contiguous, into extract and shape_cast.
334+
///
335+
/// Example:
336+
/// Before:
337+
/// %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
338+
/// sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
339+
/// vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
340+
/// After:
341+
/// %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
342+
/// vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
343+
/// vector<1x1x1x1x8xi8>
344+
///
334345
class ContiguousExtractStridedSliceToExtract final
335346
: public OpRewritePattern<ExtractStridedSliceOp> {
336347
public:
Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,24 @@
11
// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
22

3-
// CHECK-LABEL: @extract_strided_slice_to_extract_i8
4-
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
5-
// CHECK: return %[[EXTRACT]] : vector<8xi8>
6-
func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
7-
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
8-
%2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>
9-
return %2 : vector<8xi8>
10-
}
11-
12-
// CHECK-LABEL: @extract_strided_slice_to_extract_i32
3+
// CHECK-LABEL: @contiguous
134
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
14-
// CHECK: return %[[EXTRACT]] : vector<4xi32>
15-
func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
5+
// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
6+
func.func @contiguous(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
167
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
178
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
189
return %2 : vector<4xi32>
1910
}
2011

21-
// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
22-
// CHECK: vector.extract_strided_slice
23-
func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
12+
// CHECK-LABEL: @non_full_size
13+
// CHECK-NEXT: vector.extract_strided_slice
14+
func.func @non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
2415
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
25-
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x2xi32> to vector<2xi32>
26-
return %2 : vector<2xi32>
16+
return %1 : vector<1x1x1x1x1x2xi32>
2717
}
2818

29-
// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
30-
// CHECK: vector.extract_strided_slice
31-
func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
19+
// CHECK-LABEL: @non_full_inner_size
20+
// CHECK-NEXT: vector.extract_strided_slice
21+
func.func @non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
3222
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
33-
%2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32>
34-
return %2 : vector<2xi32>
23+
return %1 : vector<1x1x2x1x1x1xi32>
3524
}

0 commit comments

Comments
 (0)