Commit 2ecf608
[mlir] Fix compose subview (#80551)
I found a bug in `test-compose-subview`; the following example reproduces it.

```mlir
#map = affine_map<() -> ()>
module {
  func.func private @fun(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
    %c0 = arith.constant 0 : index
    %c5 = arith.constant 5 : index
    %c1 = arith.constant 1 : index
    %subview = memref.subview %arg0[0, 0] [5, 5] [1, 1] : memref<10x10xf32> to memref<5x5xf32, strided<[10, 1]>>
    %alloc = memref.alloc() : memref<5x5xf32>
    scf.for %arg2 = %c0 to %c5 step %c1 {
      scf.for %arg3 = %c0 to %c5 step %c1 {
        %subview_0 = memref.subview %subview[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32, strided<[10, 1]>> to memref<f32, strided<[], offset: ?>>
        %subview_1 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        %alloc_2 = memref.alloc() : memref<f32>
        linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
        ^bb0(%in: f32, %in_4: f32, %out: f32):
          %0 = arith.addf %in, %in_4 : f32
          linalg.yield %0 : f32
        }
        %subview_3 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        memref.copy %alloc_2, %subview_3 : memref<f32> to memref<f32, strided<[], offset: ?>>
      }
    }
    return %alloc : memref<5x5xf32>
  }
  func.func @test(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
    %0 = call @fun(%arg0, %arg1) : (memref<10x10xf32>, memref<5x5xf32>) -> memref<5x5xf32>
    return %0 : memref<5x5xf32>
  }
}
```

Running `mlir-opt test.mlir -test-compose-subview` fails with:

```
test.mlir:14:9: error: 'linalg.generic' op expected operand rank (2) to match the result rank of indexing_map #0 (0)
        linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
        ^
test1.mlir:14:9: note: see current operation: "linalg.generic"(%4, %5, %6) <{indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = [], operandSegmentSizes = array<i32: 2, 1>}> ({
^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
  %8 = "arith.addf"(%arg4, %arg5) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  "linalg.yield"(%8) : (f32) -> ()
}) : (memref<1x1xf32, strided<[10, 1], offset: ?>>, memref<f32, strided<[], offset: ?>>, memref<f32>) -> ()
```

The folded subview loses the rank-reduced result type (the `memref<f32, ...>` operand becomes `memref<1x1xf32, ...>`), which breaks its consumer. This PR fixes that. In the meantime, I have extended this PR to handle cases where the stride is greater than 1.
For example:

```mlir
func.func private @UNknown0(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5 : index
  %c1 = arith.constant 1 : index
  %subview = memref.subview %arg0[0, 0] [5, 5] [2, 2] : memref<10x10xf32> to memref<5x5xf32, strided<[20, 2]>>
  %alloc = memref.alloc() : memref<5x5xf32>
  scf.for %arg2 = %c0 to %c5 step %c1 {
    scf.for %arg3 = %c0 to %c5 step %c1 {
      %subview_0 = memref.subview %subview[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32, strided<[20, 2]>> to memref<f32, strided<[], offset: ?>>
      %subview_1 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
      %alloc_2 = memref.alloc() : memref<f32>
      linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
      ^bb0(%in: f32, %in_4: f32, %out: f32):
        %0 = arith.addf %in, %in_4 : f32
        linalg.yield %0 : f32
      }
      %subview_3 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloc_2, %subview_3 : memref<f32> to memref<f32, strided<[], offset: ?>>
    }
  }
  return %alloc : memref<5x5xf32>
}
```

Running `mlir-opt test.mlir -test-compose-subview` on this now produces:

```mlir
#map = affine_map<()[s0] -> (s0 * 2)>
#map1 = affine_map<() -> ()>
module {
  func.func private @UNknown0(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
    %c0 = arith.constant 0 : index
    %c5 = arith.constant 5 : index
    %c1 = arith.constant 1 : index
    %alloc = memref.alloc() : memref<5x5xf32>
    scf.for %arg2 = %c0 to %c5 step %c1 {
      scf.for %arg3 = %c0 to %c5 step %c1 {
        %0 = affine.apply #map()[%arg2]
        %1 = affine.apply #map()[%arg3]
        %subview = memref.subview %arg0[%0, %1] [1, 1] [2, 2] : memref<10x10xf32> to memref<f32, strided<[], offset: ?>>
        %subview_0 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        %alloc_1 = memref.alloc() : memref<f32>
        linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = []} ins(%subview, %subview_0 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_1 : memref<f32>) {
        ^bb0(%in: f32, %in_3: f32, %out: f32):
          %2 = arith.addf %in, %in_3 : f32
          linalg.yield %2 : f32
        }
        %subview_2 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        memref.copy %alloc_1, %subview_2 : memref<f32> to memref<f32, strided<[], offset: ?>>
      }
    }
    return %alloc : memref<5x5xf32>
  }
}
```
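The composition rule is simple to state per dimension: the folded stride is the product of the two strides, and the folded offset is the source offset plus the op offset scaled by the source stride. Here is a minimal sketch of that rule in plain C++ (the `SubViewParams` struct and `compose` helper are illustrative only, not the pass's data structures; static case only):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-dimension subview parameters (static case only); illustrative only.
struct SubViewParams {
  std::vector<int> offsets, sizes, strides;
};

// Compose an outer subview ('source') with a subview taken of its result
// ('op'), following the rules the pass implements:
//   stride' = sourceStride * opStride
//   offset' = sourceOffset + opOffset * sourceStride
//   size'   = opSize (the innermost subview's size wins)
SubViewParams compose(const SubViewParams &source, const SubViewParams &op) {
  SubViewParams result;
  for (std::size_t d = 0; d < source.offsets.size(); ++d) {
    result.strides.push_back(source.strides[d] * op.strides[d]);
    result.offsets.push_back(source.offsets[d] +
                             op.offsets[d] * source.strides[d]);
    result.sizes.push_back(op.sizes[d]);
  }
  return result;
}

int main() {
  // Mirrors one of the new strided test cases below:
  //   %input[2, 256] [2, 256] [2, 2] followed by %0[1, 64] [1, 64] [2, 2]
  // folds to %input[4, 384] [1, 64] [4, 4].
  SubViewParams source{{2, 256}, {2, 256}, {2, 2}};
  SubViewParams op{{1, 64}, {1, 64}, {2, 2}};
  SubViewParams folded = compose(source, op);
  assert(folded.offsets == (std::vector<int>{4, 384}));
  assert(folded.sizes == (std::vector<int>{1, 64}));
  assert(folded.strides == (std::vector<int>{4, 4}));
  return 0;
}
```

For the dynamic offsets in the example above, the same formula with a source offset of 0 and a source stride of 2 is what yields the `affine_map<()[s0] -> (s0 * 2)>` applied to `%arg2` and `%arg3` in the output.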
1 parent 9c75a98 commit 2ecf608

2 files changed: +122 −53 lines changed
mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp

Lines changed: 51 additions & 38 deletions

```diff
@@ -24,8 +24,7 @@ using namespace mlir;
 
 namespace {
 
-// Replaces a subview of a subview with a single subview. Only supports subview
-// ops with static sizes and static strides of 1 (both static and dynamic
+// Replaces a subview of a subview with a single subview(both static and dynamic
 // offsets are supported).
 struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -51,64 +50,78 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
     }
 
     // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
-    SmallVector<OpFoldResult> offsets, sizes, strides;
-
-    // Because we only support input strides of 1, the output stride is also
-    // always 1.
-    if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
-          Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
-          return attr && cast<IntegerAttr>(attr).getInt() == 1;
-        })) {
-      strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
-                                          rewriter.getI64IntegerAttr(1));
-    } else {
-      return failure();
+    SmallVector<OpFoldResult> offsets, sizes, strides,
+        opStrides = op.getMixedStrides(),
+        sourceStrides = sourceOp.getMixedStrides();
+
+    // The output stride in each dimension is equal to the product of the
+    // dimensions corresponding to source and op.
+    int64_t sourceStrideValue;
+    for (auto &&[opStride, sourceStride] :
+         llvm::zip(opStrides, sourceStrides)) {
+      Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
+      Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
+      if (!opStrideAttr || !sourceStrideAttr)
+        return failure();
+      sourceStrideValue = cast<IntegerAttr>(sourceStrideAttr).getInt();
+      strides.push_back(rewriter.getI64IntegerAttr(
+          cast<IntegerAttr>(opStrideAttr).getInt() * sourceStrideValue));
     }
 
     // The rules for calculating the new offsets and sizes are:
     // * Multiple subview offsets for a given dimension compose additively.
-    //   ("Offset by m" followed by "Offset by n" == "Offset by m + n")
+    //   ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
+    //   m + n * k")
     // * Multiple sizes for a given dimension compose by taking the size of the
     //   final subview and ignoring the rest. ("Take m values" followed by "Take
     //   n values" == "Take n values") This size must also be the smallest one
     //   by definition (a subview needs to be the same size as or smaller than
     //   its source along each dimension; presumably subviews that are larger
     //   than their sources are disallowed by validation).
-    for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
-                             op.getMixedSizes())) {
-      auto opOffset = std::get<0>(it);
-      auto sourceOffset = std::get<1>(it);
-      auto opSize = std::get<2>(it);
-
+    for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
+         llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
+                   sourceOp.getMixedStrides(), op.getMixedSizes())) {
       // We only support static sizes.
       if (opSize.is<Value>()) {
        return failure();
      }
-
       sizes.push_back(opSize);
       Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
                 sourceOffsetAttr =
-                    llvm::dyn_cast_if_present<Attribute>(sourceOffset);
-
+                    llvm::dyn_cast_if_present<Attribute>(sourceOffset),
+                sourceStrideAttr =
+                    llvm::dyn_cast_if_present<Attribute>(sourceStride);
       if (opOffsetAttr && sourceOffsetAttr) {
+
         // If both offsets are static we can simply calculate the combined
         // offset statically.
         offsets.push_back(rewriter.getI64IntegerAttr(
-            cast<IntegerAttr>(opOffsetAttr).getInt() +
+            cast<IntegerAttr>(opOffsetAttr).getInt() *
+                cast<IntegerAttr>(sourceStrideAttr).getInt() +
             cast<IntegerAttr>(sourceOffsetAttr).getInt()));
       } else {
-        // When either offset is dynamic, we must emit an additional affine
-        // transformation to add the two offsets together dynamically.
-        AffineExpr expr = rewriter.getAffineConstantExpr(0);
+        AffineExpr expr;
         SmallVector<Value> affineApplyOperands;
-        for (auto valueOrAttr : {opOffset, sourceOffset}) {
-          if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
-            expr = expr + cast<IntegerAttr>(attr).getInt();
-          } else {
-            expr =
-                expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
-            affineApplyOperands.push_back(valueOrAttr.get<Value>());
-          }
+
+        // Make 'expr' add 'sourceOffset'.
+        if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
+          expr =
+              rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
+        } else {
+          expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+          affineApplyOperands.push_back(sourceOffset.get<Value>());
+        }
+
+        // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
+        // result.
+        if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) {
+          expr = expr + cast<IntegerAttr>(attr).getInt() *
+                            cast<IntegerAttr>(sourceStrideAttr).getInt();
+        } else {
+          expr =
+              expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
+                         cast<IntegerAttr>(sourceStrideAttr).getInt();
+          affineApplyOperands.push_back(opOffset.get<Value>());
         }
 
         AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
@@ -120,8 +133,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
 
     // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
     // uses it can be removed by a (separate) dead code elimination pass.
-    rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.getSource(),
-                                                   offsets, sizes, strides);
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        op, op.getType(), sourceOp.getSource(), offsets, sizes, strides);
     return success();
   }
 };
```
Lines changed: 71 additions & 15 deletions

```diff
@@ -1,18 +1,20 @@
 // RUN: mlir-opt %s -test-compose-subview -split-input-file | FileCheck %s
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
-  // CHECK: subview %arg0[3, 384] [1, 128] [1, 1]
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
+  // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
   %0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>>
   %1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: 2304>> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
   return %1 : memref<1x128xf32, strided<[1024, 1], offset: 3456>>
 }
 
 // -----
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
-  // CHECK: subview %arg0[3, 673] [1, 10] [1, 1]
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
+  // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
   %0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, strided<[1024, 1], offset: 1536>>
   %1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, strided<[1024, 1], offset: 1536>> to memref<2x128xf32, strided<[1024, 1], offset: 2688>>
   %2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, strided<[1024, 1], offset: 2688>> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
@@ -21,27 +23,81 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1
 
 // -----
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
-  // CHECK: [[CST_3:%.*]] = arith.constant 3 : index
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+  // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
   %cst_1 = arith.constant 1 : index
   %cst_2 = arith.constant 2 : index
-  // CHECK: subview %arg0{{\[}}[[CST_3]], 384] [1, 128] [1, 1]
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
+  // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
   %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
   %1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
   return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
 }
 
 // -----
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
-  // CHECK: [[CST_3:%.*]] = arith.constant 3 : index
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+  // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
   %cst_2 = arith.constant 2 : index
-  // CHECK: [[CST_384:%.*]] = arith.constant 384 : index
+  // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index
   %cst_128 = arith.constant 128 : index
-  // CHECK: subview %arg0{{\[}}[[CST_3]], [[CST_384]]] [1, 128] [1, 1]
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
+  // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
   %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
   %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
   return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
+  // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+  %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>>
+  %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+  return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
+func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
+  // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
+  %0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>>
+  %1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>>
+  %2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>>
+  return %2 : memref<2x2xf32, strided<[240, 8], offset: 217>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+  // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+  %cst_2 = arith.constant 2 : index
+  // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index
+  %cst_64 = arith.constant 64 : index
+  // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
+  %1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+  // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+  %cst_1 = arith.constant 1 : index
+  %cst_2 = arith.constant 2 : index
+  // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
+  %1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
+}
```
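As a sanity check on the three-level strided test above (the `memref<30x30xf32>` case), the per-dimension formula (`offset' = sourceOffset + opOffset * sourceStride`, `stride' = sourceStride * opStride`) can be chained by hand; here is a small standalone sketch (plain C++, not part of the commit), assuming the row-major layout of the base memref:

```cpp
#include <cassert>

int main() {
  // The 30x30 test chains three subviews, each with offset 1 and stride 2 in
  // both dimensions. Track one dimension's offset and stride in index space.
  int offset = 1, stride = 2;   // after the first subview
  offset = offset + 1 * stride; // = 3 after the second subview
  stride = stride * 2;          // = 4
  offset = offset + 1 * stride; // = 7 after the third subview
  stride = stride * 2;          // = 8
  assert(offset == 7 && stride == 8);
  // In the row-major 30x30 base this is linear offset 7 * 30 + 7 = 217 and
  // strides [30 * 8, 8] = [240, 8], matching
  // memref<2x2xf32, strided<[240, 8], offset: 217>> in the CHECK line.
  assert(offset * 30 + offset == 217);
  return 0;
}
```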
