Skip to content

Commit 1632e3a

Browse files
fix nit and update test.
1 parent 64ad5c6 commit 1632e3a

File tree

2 files changed

+47
-39
lines changed

2 files changed

+47
-39
lines changed

mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
3131

3232
LogicalResult matchAndRewrite(memref::SubViewOp op,
3333
PatternRewriter &rewriter) const override {
34-
3534
// 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
3635
// produces the input of the op we're rewriting (for 'SubViewOp' the input
3736
// is called the "source" value). We can only combine them if both 'op' and
@@ -51,13 +50,14 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
5150
}
5251

5352
// Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
54-
SmallVector<OpFoldResult> offsets, sizes, strides;
55-
auto opStrides = op.getMixedStrides();
56-
auto sourceStrides = sourceOp.getMixedStrides();
53+
SmallVector<OpFoldResult> offsets, sizes, strides,
54+
opStrides = op.getMixedStrides(),
55+
sourceStrides = sourceOp.getMixedStrides();
5756

5857
// The output stride in each dimension is equal to the product of the
5958
// dimensions corresponding to source and op.
60-
for (auto [opStride, sourceStride] : llvm::zip(opStrides, sourceStrides)) {
59+
for (auto &&[opStride, sourceStride] :
60+
llvm::zip(opStrides, sourceStrides)) {
6161
Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
6262
Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
6363
if (!opStrideAttr || !sourceStrideAttr)
@@ -77,7 +77,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
7777
// by definition (a subview needs to be the same size as or smaller than
7878
// its source along each dimension; presumably subviews that are larger
7979
// than their sources are disallowed by validation).
80-
for (auto [opOffset, sourceOffset, sourceStride, opSize] :
80+
for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
8181
llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
8282
sourceOp.getMixedStrides(), op.getMixedSizes())) {
8383
// We only support static sizes.
@@ -103,13 +103,13 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
103103
AffineExpr expr1 = rewriter.getAffineConstantExpr(0);
104104
SmallVector<Value> affineApplyOperands;
105105
SmallVector<OpFoldResult> opOffsets{sourceOffset, opOffset};
106-
for (auto [idx, offset] : llvm::enumerate(opOffsets)) {
106+
for (auto &&[idx, offset] : llvm::enumerate(opOffsets)) {
107107
if (auto attr = llvm::dyn_cast_if_present<Attribute>(offset)) {
108108
if (idx == 0) {
109109
expr0 = expr0 + cast<IntegerAttr>(attr).getInt();
110110
} else if (idx == 1) {
111-
expr1 = expr1 + cast<IntegerAttr>(attr).getInt();
112-
expr1 = expr1 * cast<IntegerAttr>(sourceStrideAttr).getInt();
111+
expr1 = expr1 + cast<IntegerAttr>(attr).getInt() *
112+
cast<IntegerAttr>(sourceStrideAttr).getInt();
113113
expr0 = expr0 + expr1;
114114
}
115115
} else {

mlir/test/Transforms/compose-subview.mlir

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
// RUN: mlir-opt %s -test-compose-subview -split-input-file | FileCheck %s
22

3-
func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
4-
// CHECK: subview %arg0[3, 384] [1, 128] [1, 1]
5-
// CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
3+
// CHECK-LABEL: func.func @subview_strided(
4+
// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
5+
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
6+
// CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
67
%0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>>
78
%1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: 2304>> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
89
return %1 : memref<1x128xf32, strided<[1024, 1], offset: 3456>>
910
}
1011

1112
// -----
1213

13-
func.func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
14-
// CHECK: subview %arg0[3, 673] [1, 10] [1, 1]
15-
// CHECK-SAME: memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
14+
// CHECK-LABEL: func.func @subview_strided(
15+
// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
16+
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
17+
// CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
1618
%0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, strided<[1024, 1], offset: 1536>>
1719
%1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, strided<[1024, 1], offset: 1536>> to memref<2x128xf32, strided<[1024, 1], offset: 2688>>
1820
%2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, strided<[1024, 1], offset: 2688>> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
@@ -21,46 +23,50 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1
2123

2224
// -----
2325

24-
func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
25-
// CHECK: [[CST_3:%.*]] = arith.constant 3 : index
26+
// CHECK-LABEL: func.func @subview_strided(
27+
// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
28+
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
29+
// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
2630
%cst_1 = arith.constant 1 : index
2731
%cst_2 = arith.constant 2 : index
28-
// CHECK: subview %arg0{{\[}}[[CST_3]], 384] [1, 128] [1, 1]
29-
// CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
32+
// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
3033
%0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
3134
%1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
3235
return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
3336
}
3437

3538
// -----
3639

37-
func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
38-
// CHECK: [[CST_3:%.*]] = arith.constant 3 : index
40+
// CHECK-LABEL: func.func @subview_strided(
41+
// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
42+
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
43+
// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
3944
%cst_2 = arith.constant 2 : index
40-
// CHECK: [[CST_384:%.*]] = arith.constant 384 : index
45+
// CHECK: %[[VAL_2:.*]] = arith.constant 384 : index
4146
%cst_128 = arith.constant 128 : index
42-
// CHECK: subview %arg0{{\[}}[[CST_3]], [[CST_384]]] [1, 128] [1, 1]
43-
// CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
47+
// CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
4448
%0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
4549
%1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
4650
return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
4751
}
4852

4953
// -----
5054

51-
func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
52-
// CHECK: subview %arg0[4, 384] [1, 64] [4, 4]
53-
// CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
55+
// CHECK-LABEL: func.func @subview_strided(
56+
// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
57+
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
58+
// CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
5459
%0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>>
5560
%1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
5661
return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>>
5762
}
5863

5964
// -----
6065

61-
func.func @main(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
62-
// CHECK: subview %arg0[7, 7] [2, 2] [8, 8]
63-
// CHECK-SAME: memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
66+
// CHECK-LABEL: func.func @subview_strided(
67+
// CHECK-SAME: %[[VAL_0:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
68+
func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
69+
// CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
6470
%0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>>
6571
%1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>>
6672
%2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>>
@@ -69,26 +75,28 @@ func.func @main(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8],
6975

7076
// -----
7177

72-
func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
73-
// CHECK:%[[VAL_1:.*]] = arith.constant 4 : index
78+
// CHECK-LABEL: func.func @subview_strided(
79+
// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
80+
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
81+
// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
7482
%cst_2 = arith.constant 2 : index
75-
// CHECK:%[[VAL_2:.*]] = arith.constant 384 : index
83+
// CHECK: %[[VAL_2:.*]] = arith.constant 384 : index
7684
%cst_64 = arith.constant 64 : index
77-
// CHECK: subview %arg0{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4]
78-
// CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
85+
// CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
7986
%0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
8087
%1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
8188
return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
8289
}
8390

8491
// -----
8592

86-
func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
87-
// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
93+
// CHECK-LABEL: func.func @subview_strided(
94+
// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
95+
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
96+
// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
8897
%cst_1 = arith.constant 1 : index
8998
%cst_2 = arith.constant 2 : index
90-
// CHECK: subview %arg0{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4]
91-
// CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
99+
// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
92100
%0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
93101
%1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
94102
return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>

0 commit comments

Comments
 (0)