
Commit f096e72
[mlir] switch bufferization to use strided layout attribute
Bufferization already assumes that buffers cross function boundaries in strided form, and encoded this with the corresponding affine-map layouts. Switch it to the recently introduced strided layout attribute instead, which avoids unnecessary casts when further operations are bufferized to their memref dialect counterparts, as those now largely rely on the strided layout attribute.

Depends On D133947

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D133951
1 parent 46b90a7, commit f096e72

16 files changed: +134 additions, -207 deletions
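Before the per-file diffs, a minimal sketch of the two layout encodings may help; it is illustrative only (not part of this commit) and uses only APIs that appear in the diffs below.

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Build the same fully dynamic rank-1 strided layout in both encodings.
void sketchLayoutEncodings() {
  MLIRContext ctx;
  int64_t dyn = ShapedType::kDynamicStrideOrOffset;

  // Old encoding: an AffineMap, which prints as a #map alias such as
  //   affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
  AffineMap oldLayout = makeStridedLinearLayoutMap({dyn}, /*offset=*/dyn, &ctx);

  // New encoding: a dedicated attribute, which prints inline as
  //   strided<[?], offset: ?>
  auto newLayout = StridedLayoutAttr::get(&ctx, /*offset=*/dyn, /*strides=*/{dyn});
  (void)oldLayout;
  (void)newLayout;
}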

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 0 additions & 8 deletions
@@ -432,10 +432,6 @@ LogicalResult getStridesAndOffset(MemRefType t,
 /// `t` with simplified layout.
 MemRefType canonicalizeStridedLayout(MemRefType t);
 
-/// Return a version of `t` with a layout that has all dynamic offset and
-/// strides. This is used to erase the static layout.
-MemRefType eraseStridedLayout(MemRefType t);
-
 /// Given MemRef `sizes` that are either static or dynamic, returns the
 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
 /// once a dynamic dimension is encountered, all canonical strides become
@@ -462,10 +458,6 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
 /// Return true if the layout for `t` is compatible with strided semantics.
 bool isStrided(MemRefType t);
 
-/// Return the layout map in strided linear layout AffineMap form.
-/// Return null if the layout is not compatible with a strided layout.
-AffineMap getStridedLinearLayoutMap(MemRefType t);
-
 } // namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H

mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp

Lines changed: 9 additions & 2 deletions
@@ -25,6 +25,13 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::linalg;
 
+static MemRefType makeStridedLayoutDynamic(MemRefType type) {
+  return MemRefType::Builder(type).setLayout(StridedLayoutAttr::get(
+      type.getContext(), ShapedType::kDynamicStrideOrOffset,
+      SmallVector<int64_t>(type.getRank(),
+                           ShapedType::kDynamicStrideOrOffset)));
+}
+
 /// Helper function to extract the operand types that are passed to the
 /// generated CallOp. MemRefTypes have their layout canonicalized since the
 /// information is not used in signature generation.
@@ -37,7 +44,7 @@ static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
     // information. Canonicalizing the type at the level of std when going into
     // a library call avoids needing to introduce DialectCastOp.
     if (auto memrefType = type.dyn_cast<MemRefType>())
-      result.push_back(eraseStridedLayout(memrefType));
+      result.push_back(makeStridedLayoutDynamic(memrefType));
     else
       result.push_back(type);
   }
@@ -95,7 +102,7 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
       continue;
     }
     Value cast =
-        b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op);
+        b.create<memref::CastOp>(loc, makeStridedLayoutDynamic(memrefType), op);
     res.push_back(cast);
   }
   return res;
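To illustrate the new helper's effect, here is a hedged before/after of the types it produces; the examples are an assumption based on the attribute's printed form, not output captured from this patch.

// makeStridedLayoutDynamic(memref<4x8xf32>)
//     -> memref<4x8xf32, strided<[?, ?], offset: ?>>
// makeStridedLayoutDynamic(memref<?xf32, strided<[2], offset: 4>>)
//     -> memref<?xf32, strided<[?], offset: ?>>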

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 2 additions & 2 deletions
@@ -758,8 +758,8 @@ bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
   int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
   SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                       ShapedType::kDynamicStrideOrOffset);
-  AffineMap stridedLayout = makeStridedLinearLayoutMap(
-      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
+  auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(),
+                                              dynamicOffset, dynamicStrides);
   return MemRefType::get(rankedTensorType.getShape(),
                          rankedTensorType.getElementType(), stridedLayout,
                          memorySpaceAttr);
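One detail worth noting when reading this hunk: the two APIs take their arguments in different orders, so the migration is more than a mechanical rename.

// Old API: strides first, then offset, then context.
//   makeStridedLinearLayoutMap(dynamicStrides, dynamicOffset, ctx);
// New API: context first, then offset, then strides.
//   StridedLayoutAttr::get(ctx, dynamicOffset, dynamicStrides);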

mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
       (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
   return MemRefType::get(
       resShape, aT.getElementType(),
-      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
+      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
 }
 
 /// Operates under a scoped context to build the intersection between the
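For context, getCastCompatibleMemRefType computes a joined type that both inputs can be cast to. A sketch of the stride join is below, reconstructed as an assumption rather than copied from the file; the offset join is visible in the hunk above.

// Keep a static stride only where both operands agree; otherwise the joined
// type must use a dynamic stride, mirroring the offset computation above.
SmallVector<int64_t, 4> resStrides;
for (unsigned i = 0, e = aStrides.size(); i != e; ++i)
  resStrides.push_back(aStrides[i] == bStrides[i]
                           ? aStrides[i]
                           : ShapedType::kDynamicStrideOrOffset);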

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 0 additions & 19 deletions
@@ -992,15 +992,6 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
   return simplifyAffineExpr(expr, numDims, nSymbols);
 }
 
-/// Return a version of `t` with a layout that has all dynamic offset and
-/// strides. This is used to erase the static layout.
-MemRefType mlir::eraseStridedLayout(MemRefType t) {
-  auto val = ShapedType::kDynamicStrideOrOffset;
-  return MemRefType::Builder(t).setLayout(
-      AffineMapAttr::get(makeStridedLinearLayoutMap(
-          SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
-}
-
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                 MLIRContext *context) {
   SmallVector<AffineExpr, 4> exprs;
@@ -1017,13 +1008,3 @@ bool mlir::isStrided(MemRefType t) {
   auto res = getStridesAndOffset(t, strides, offset);
   return succeeded(res);
 }
-
-/// Return the layout map in strided linear layout AffineMap form.
-/// Return null if the layout is not compatible with a strided layout.
-AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(t, strides, offset)))
-    return AffineMap();
-  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
-}
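Out-of-tree callers of the removed getStridedLinearLayoutMap can compute an equivalent attribute themselves from getStridesAndOffset, which this commit keeps. A hypothetical replacement is sketched below; the name getStridedLayoutAttr is illustrative, not part of this commit.

// Hypothetical helper: derive a StridedLayoutAttr from a memref type, or a
// null attribute if the layout is not strided.
static StridedLayoutAttr getStridedLayoutAttr(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return {};
  return StridedLayoutAttr::get(t.getContext(), offset, strides);
}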

mlir/test/Dialect/Arithmetic/one-shot-bufferize.mlir

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 // RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" -split-input-file -o /dev/null
 
 // CHECK-LABEL: func @write_to_select_op_source
-// CHECK-SAME: %[[t1:.*]]: memref<?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>
+// CHECK-SAME: %[[t1:.*]]: memref<?xf32, strided{{.*}}>, %[[t2:.*]]: memref<?xf32, strided{{.*}}>
 func.func @write_to_select_op_source(
     %t1 : tensor<?xf32> {bufferization.writable = true},
     %t2 : tensor<?xf32> {bufferization.writable = true},
@@ -34,7 +34,7 @@ func.func @write_to_select_op_source(
 // maps are passed to arith.select. A cast must be inserted.
 
 // CHECK-LABEL: func @write_after_select_read_one
-// CHECK-SAME: %[[t1:.*]]: memref<?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>
+// CHECK-SAME: %[[t1:.*]]: memref<?xf32, strided{{.*}}>, %[[t2:.*]]: memref<?xf32, strided{{.*}}>
 func.func @write_after_select_read_one(
     %t1 : tensor<?xf32> {bufferization.writable = true},
     %t2 : tensor<?xf32> {bufferization.writable = true},
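The test updates in this and the following files follow from how the two layouts print: an affine-map layout is emitted as a #map alias at the top of the module and referenced by name, while the strided attribute prints inline, so the separate #[[$MAP]] capture lines disappear. An illustrative before/after, assuming a fully dynamic 1-D layout:

// Before: #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
//         func.func @f(%arg0: memref<?xf32, #map>)
// After:  func.func @f(%arg0: memref<?xf32, strided<[?], offset: ?>>)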

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,7 @@ func.func @buffer_forwarding_no_conflict(
 // -----
 
 // CHECK: func @insertion_point_inside_loop(
-// CHECK-SAME: %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index)
+// CHECK-SAME: %[[t:.*]]: memref<?xf32, strided{{.*}}>, %[[sz:.*]]: index)
 func.func @insertion_point_inside_loop(%t : tensor<?xf32>, %sz : index) -> (tensor<?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -92,7 +92,7 @@ func.func @insertion_point_inside_loop(%t : tens
 // -----
 
 // CHECK: func @insertion_point_outside_loop(
-// CHECK-SAME: %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index, %[[idx:.*]]: index)
+// CHECK-SAME: %[[t:.*]]: memref<?xf32, strided{{.*}}>, %[[sz:.*]]: index, %[[idx:.*]]: index)
 func.func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
                                         %idx : index) -> (tensor<?xf32>) {
   %c0 = arith.constant 0 : index

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir

Lines changed: 4 additions & 8 deletions
@@ -11,8 +11,6 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="dialect-filter=tensor,bufferization allow-unknown-ops allow-return-allocs" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="dialect-filter=scf,bufferization allow-unknown-ops allow-return-allocs" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-SCF
 
-// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-
 // CHECK-LABEL: func @use_of_unknown_op_1(
 // CHECK-SAME: %[[t1:.*]]: tensor<?xf32>
 // CHECK-NO-LAYOUT-MAP-LABEL: func @use_of_unknown_op_1(
@@ -27,8 +25,8 @@ func.func @use_of_unknown_op_1(%t1: tensor<?xf32>)
 
   %idx = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
-  // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : memref<?xf32, #[[$MAP]]>
-  // CHECK: vector.transfer_read %[[dummy_memref]][%{{.*}}], %{{.*}} : memref<?xf32, #[[$MAP]]>
+  // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : memref<?xf32, strided<[?], offset: ?>>
+  // CHECK: vector.transfer_read %[[dummy_memref]][%{{.*}}], %{{.*}} : memref<?xf32, strided<[?], offset: ?>>
   // CHECK-NO-LAYOUT-MAP: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : memref<?xf32>
   // CHECK-NO-LAYOUT-MAP: vector.transfer_read %[[dummy_memref]][%{{.*}}], %{{.*}} : memref<?xf32>
   %1 = vector.transfer_read %0[%idx], %cst : tensor<?xf32>, vector<5xf32>
@@ -51,8 +49,6 @@ func.func @use_of_unknown_op_2(%t1: tensor<?xf32>) -> tensor<?xf32> {
 
 // -----
 
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-
 // CHECK-LABEL: func @use_of_unknown_op_3(
 // CHECK-SAME: %[[t1:.*]]: tensor<?xf32>
 func.func @use_of_unknown_op_3(%t1: tensor<?xf32>)
@@ -65,7 +61,7 @@ func.func @use_of_unknown_op_3(%t1: tensor<?xf32>)
 
   // CHECK: %[[dummy:.*]] = "test.dummy_op"(%[[t1]])
   %0 = "test.dummy_op"(%t1) : (tensor<?xf32>) -> tensor<?xf32>
-  // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : memref<?xf32, #[[$MAP2]]>
+  // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : memref<?xf32, strided<[?], offset: ?>>
   // CHECK: %[[v2:.*]] = vector.transfer_read %[[dummy_memref]]
   %2 = vector.transfer_read %0[%idx], %cst : tensor<?xf32>, vector<5xf32>
 
@@ -207,7 +203,7 @@ func.func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
 func.func @simple_scf_if(%t1: tensor<?xf32> {bufferization.writable = true}, %c: i1, %pos: index, %f: f32)
   -> (tensor<?xf32>, index) {
   // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
-  // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref<?xf32, #{{.*}}>) {
+  // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref<?xf32, strided{{.*}}>) {
   %r1, %r2 = scf.if %c -> (tensor<?xf32>, index) {
     // CHECK-SCF: scf.yield %[[t1_memref]]
     scf.yield %t1, %pos : tensor<?xf32>, index

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir

Lines changed: 2 additions & 2 deletions
@@ -134,12 +134,12 @@ func.func @copy_deallocated() -> tensor<10xf32> {
 // CHECK-LABEL: func @select_different_tensors(
 // CHECK-SAME: %[[t:.*]]: tensor<?xf32>
 func.func @select_different_tensors(%t: tensor<?xf32>, %sz: index, %c: i1) -> tensor<?xf32> {
-  // CHECK-DAG: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?xf32, #{{.*}}>
+  // CHECK-DAG: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?xf32, strided{{.*}}>
   // CHECK-DAG: %[[alloc:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref<?xf32>
   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
 
   // A cast must be inserted because %t and %0 have different memref types.
-  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<?xf32> to memref<?xf32, #{{.*}}>
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<?xf32> to memref<?xf32, strided{{.*}}>
   // CHECK: arith.select %{{.*}}, %[[casted]], %[[m]]
   %1 = arith.select %c, %0, %t : tensor<?xf32>
   return %1 : tensor<?xf32>

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir

Lines changed: 7 additions & 10 deletions
@@ -7,10 +7,9 @@
 
 // Note: This bufferization is not very efficient yet, but it works.
 
-// CHECK: #[[$map1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 // CHECK-LABEL: func @callee(
-// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, #[[$map1]]>,
-// CHECK-SAME: %[[arg1:.*]]: memref<5xf32, #[[$map1]]>) {
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>,
+// CHECK-SAME: %[[arg1:.*]]: memref<5xf32, strided<[?], offset: ?>>) {
 // This alloc is not needed, but it is inserted due to the out-of-place
 // bufferization of the tensor.insert. With a better layering of the out param
 // promotion pass, this alloc could be avoided.
@@ -32,9 +31,8 @@
 // CHECK-NO-LAYOUT: memref.copy %[[alloc]], %[[arg1]]
 // CHECK-NO-LAYOUT: memref.dealloc %[[alloc]]
 
-// CHECK-BASELINE: #[[$map1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 // CHECK-BASELINE-LABEL: func @callee(
-// CHECK-BASELINE-SAME: %[[arg0:.*]]: memref<5xf32, #[[$map1]]>) -> memref<5xf32> {
+// CHECK-BASELINE-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32> {
 // CHECK-BASELINE: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
 // CHECK-BASELINE: memref.copy %[[arg0]], %[[alloc]]
 // CHECK-BASELINE: memref.store {{.*}}, %[[alloc]]
@@ -49,9 +47,9 @@ func.func @callee(%t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
   return %t, %1 : tensor<5xf32>, tensor<5xf32>
 }
 
-// CHECK: func @main(%[[arg0:.*]]: memref<5xf32, #[[$map1]]>) -> (f32, f32) {
+// CHECK: func @main(%[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> (f32, f32) {
 // CHECK: %[[alloc:.*]] = memref.alloc() : memref<5xf32>
-// CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<5xf32> to memref<5xf32, #[[$map1]]>
+// CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
 // CHECK: call @callee(%[[arg0]], %[[casted]])
 // CHECK: %[[l1:.*]] = memref.load %[[arg0]]
 // CHECK: %[[l2:.*]] = memref.load %[[casted]]
@@ -73,10 +71,9 @@ func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
 
 // -----
 
-// CHECK: #[[$map2a:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 // CHECK-LABEL: func @callee(
 // CHECK-SAME: %{{.*}}: index,
-// CHECK-SAME: %[[r:.*]]: memref<2x5xf32, #[[$map2a]]>) {
+// CHECK-SAME: %[[r:.*]]: memref<2x5xf32, strided<[?, ?], offset: ?>>) {
 // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10x20xf32>
 // CHECK: %[[subview:.*]] = memref.subview %[[alloc]]{{.*}} : memref<10x20xf32> to memref<2x5xf32, strided<[20, 1], offset: ?>>
 // CHECK: %[[casted:.*]] = memref.cast %[[subview]]
@@ -110,7 +107,7 @@ func.func @callee(%idx: index) -> tensor<2x5xf32> {
 
 // CHECK: func @main(
 // CHECK: %[[alloc:.*]] = memref.alloc() : memref<2x5xf32>
-// CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<2x5xf32> to memref<2x5xf32, #[[$map2a]]>
+// CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<2x5xf32> to memref<2x5xf32, strided<[?, ?], offset: ?>>
 // CHECK: call @callee(%{{.*}}, %[[casted]])
 // CHECK: memref.load %[[casted]]
 // CHECK: memref.dealloc %[[alloc]]
