Skip to content

Commit 1d025f3

Browse files
[MLIR] Determine contiguousness of memrefs with dynamic dimensions
This patch enhances `MemRefType::areTrailingDimsContiguous` to also handle memrefs with dynamic dimensions. The implementation itself is based on a new member function `MemRefType::getMaxCollapsableTrailingDims` that returns the maximum number of trailing dimensions that can be collapsed — trivially all dimensions for memrefs with identity layout, or otherwise determined by examining the memref strides, stopping at the first discontiguous or statically unknown stride.
1 parent 01d4b16 commit 1d025f3

File tree

5 files changed

+269
-32
lines changed

5 files changed

+269
-32
lines changed

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,20 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
838838
///
839839
bool areTrailingDimsContiguous(int64_t n);
840840

841+
/// Return the maximum number of trailing dimensions that can be
842+
/// collapsed.
843+
///
844+
/// Examples:
845+
/// - memref<2x3x2xi8, strided<[24, 12, 2]>, the number of collapsable
846+
/// trailing dimensions is 0
847+
/// - memref<2x3x2xi8, strided<[12, 6, 1]>, the number of collapsable
848+
/// trailing dimensions is 3
849+
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>, the number of
850+
/// collapsable trailing dimensions is 2.
851+
/// - memref<5x4x?x2xi8>, the number of collapsable trailing dimensions
852+
/// is 4.
853+
int64_t getMaxCollapsableTrailingDims();
854+
841855
/// Return a version of this type with identity layout if it can be
842856
/// determined statically that the layout is the canonical contiguous
843857
/// strided layout. Otherwise pass the layout into `simplifyAffineMap`

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -646,35 +646,40 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
646646
}
647647

648648
bool MemRefType::areTrailingDimsContiguous(int64_t n) {
649-
if (!isLastDimUnitStride())
650-
return false;
649+
return getLayout().isIdentity() ||
650+
getMaxCollapsableTrailingDims() >= std::min(n, getRank());
651+
}
651652

652-
auto memrefShape = getShape().take_back(n);
653-
if (ShapedType::isDynamicShape(memrefShape))
654-
return false;
653+
int64_t MemRefType::getMaxCollapsableTrailingDims() {
654+
const int64_t n = getRank();
655655

656+
// memrefs with identity layout are entirely contiguous.
656657
if (getLayout().isIdentity())
657-
return true;
658+
return n;
658659

660+
// Get the strides (if any). Failing to do that, conservatively assume a
661+
// non-contiguous layout.
659662
int64_t offset;
660-
SmallVector<int64_t> stridesFull;
661-
if (!succeeded(getStridesAndOffset(stridesFull, offset)))
662-
return false;
663-
auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
664-
665-
if (strides.empty())
666-
return true;
663+
SmallVector<int64_t> strides;
664+
if (!succeeded(getStridesAndOffset(strides, offset)))
665+
return 0;
667666

668-
// Check whether strides match "flattened" dims.
669-
SmallVector<int64_t> flattenedDims;
670-
auto dimProduct = 1;
671-
for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
672-
dimProduct *= dim;
673-
flattenedDims.push_back(dimProduct);
667+
auto shape = getShape();
668+
669+
// A memref with dimensions `d0, d1, ..., dn-1` and strides
670+
// `s0, s1, ..., sn-1` is contiguous up to dimension `k`
671+
// if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
672+
// for `i` in `[k, n-1]`.
673+
int64_t dimProduct = 1;
674+
for (int64_t i = n - 1; i >= 0; --i) {
675+
if (strides[i] != dimProduct)
676+
return n - i - 1;
677+
if (shape[i] == ShapedType::kDynamic)
678+
return n - i;
679+
dimProduct *= shape[i];
674680
}
675681

676-
strides = strides.drop_back(1);
677-
return llvm::equal(strides, llvm::reverse(flattenedDims));
682+
return n;
678683
}
679684

680685
MemRefType MemRefType::canonicalizeStridedLayout() {

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func.func @transfer_read_leading_dynamic_dims(
190190

191191
// One of the dims to be flattened is dynamic - not supported ATM.
192192

193-
func.func @negative_transfer_read_dynamic_dim_to_flatten(
193+
func.func @transfer_read_dynamic_dim_to_flatten(
194194
%idx_1: index,
195195
%idx_2: index,
196196
%mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +203,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
203203
return %res : vector<1x2x6xi32>
204204
}
205205

206-
// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
207-
// CHECK-NOT: memref.collapse_shape
208-
// CHECK-NOT: vector.shape_cast
209-
210-
// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
206+
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
207+
208+
// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
209+
// CHECK-SAME: %[[IDX_1:arg0]]
210+
// CHECK-SAME: %[[IDX_2:arg1]]
211+
// CHECK-SAME: %[[MEM:arg2]]
212+
// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
213+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
214+
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
215+
// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
216+
// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
217+
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
218+
// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
219+
// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
220+
// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
221+
// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
222+
223+
224+
// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
211225
// CHECK-128B-NOT: memref.collapse_shape
212226

213227
// -----
@@ -453,7 +467,7 @@ func.func @transfer_write_leading_dynamic_dims(
453467

454468
// One of the dims to be flattened is dynamic - not supported ATM.
455469

456-
func.func @negative_transfer_write_dynamic_to_flatten(
470+
func.func @transfer_write_dynamic_to_flatten(
457471
%idx_1: index,
458472
%idx_2: index,
459473
%vec : vector<1x2x6xi32>,
@@ -466,11 +480,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
466480
return
467481
}
468482

469-
// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
470-
// CHECK-NOT: memref.collapse_shape
471-
// CHECK-NOT: vector.shape_cast
483+
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
484+
485+
// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
486+
// CHECK-SAME: %[[IDX_1:arg0]]: index
487+
// CHECK-SAME: %[[IDX_2:arg1]]: index
488+
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
489+
// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
490+
491+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
492+
// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
493+
// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
494+
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
495+
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
496+
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
497+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
498+
// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
472499

473-
// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
500+
// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
474501
// CHECK-128B-NOT: memref.collapse_shape
475502

476503
// -----

mlir/unittests/Dialect/MemRef/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_unittest(MLIRMemRefTests
22
InferShapeTest.cpp
3+
LayoutTest.cpp
34
)
45
mlir_target_link_libraries(MLIRMemRefTests
56
PRIVATE
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
//===- LayoutTest.cpp - unit tests related to memref layout --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
10+
#include "mlir/IR/AffineMap.h"
11+
#include "mlir/IR/Builders.h"
12+
#include "mlir/IR/BuiltinTypes.h"
13+
#include "gtest/gtest.h"
14+
15+
using namespace mlir;
16+
using namespace mlir::memref;
17+
18+
TEST(MemRefLayout, maxCollapseDim) {
19+
MLIRContext ctx;
20+
OpBuilder b(&ctx);
21+
22+
const auto _ = ShapedType::kDynamic;
23+
const auto f32 = b.getF32Type();
24+
auto strided = [&ctx](ArrayRef<int64_t> s) {
25+
return StridedLayoutAttr::get(&ctx, 0, s);
26+
};
27+
28+
// memref<2x2x2xf32, strided<[4,2,1]>
29+
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
30+
EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
31+
32+
// memref<2x2x2xf32, strided<[8,2,1]>
33+
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
34+
EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 2);
35+
36+
// memref<2x2x2xf32, strided<[8,4,1]>
37+
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
38+
EXPECT_EQ(m3.getMaxCollapsableTrailingDims(), 1);
39+
40+
// memref<2x2x2xf32, strided<[8,4,2]>
41+
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
42+
EXPECT_EQ(m4.getMaxCollapsableTrailingDims(), 0);
43+
44+
// memref<2x2x?xf32, strided<[?,?,1]>
45+
auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
46+
EXPECT_EQ(m5.getMaxCollapsableTrailingDims(), 1);
47+
48+
// memref<2x2x?xf32, strided<[?,?,2]>
49+
auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
50+
EXPECT_EQ(m6.getMaxCollapsableTrailingDims(), 0);
51+
52+
// memref<2x?x2xf32, strided<[?,2,1]>
53+
auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
54+
EXPECT_EQ(m7.getMaxCollapsableTrailingDims(), 2);
55+
56+
// memref<2x?x2xf32, strided<[?,4,1]>
57+
auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
58+
EXPECT_EQ(m8.getMaxCollapsableTrailingDims(), 1);
59+
60+
// memref<2x?x2xf32, strided<[?,4,2]>
61+
auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
62+
EXPECT_EQ(m9.getMaxCollapsableTrailingDims(), 0);
63+
64+
// memref<?x2x2xf32, strided<[4,2,1]>
65+
auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
66+
EXPECT_EQ(m10.getMaxCollapsableTrailingDims(), 3);
67+
68+
// memref<?x2x2xf32, strided<[8,2,1]>
69+
auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
70+
EXPECT_EQ(m11.getMaxCollapsableTrailingDims(), 2);
71+
72+
// memref<?x2x2xf32, strided<[8,4,1]>
73+
auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
74+
EXPECT_EQ(m12.getMaxCollapsableTrailingDims(), 1);
75+
76+
// memref<?x2x2xf32, strided<[8,4,2]>
77+
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
78+
EXPECT_EQ(m13.getMaxCollapsableTrailingDims(), 0);
79+
}
80+
81+
TEST(MemRefLayout, contigTrailingDim) {
82+
MLIRContext ctx;
83+
OpBuilder b(&ctx);
84+
85+
const auto _ = ShapedType::kDynamic;
86+
const auto f32 = b.getF32Type();
87+
auto strided = [&ctx](ArrayRef<int64_t> s) {
88+
return StridedLayoutAttr::get(&ctx, 0, s);
89+
};
90+
91+
// memref<2x2x2xf32, strided<[4,2,1]>
92+
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
93+
EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
94+
EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
95+
EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
96+
97+
// memref<2x2x2xf32, strided<[8,2,1]>
98+
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
99+
EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
100+
EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
101+
EXPECT_FALSE(m2.areTrailingDimsContiguous(3));
102+
103+
// memref<2x2x2xf32, strided<[8,4,1]>
104+
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
105+
EXPECT_TRUE(m3.areTrailingDimsContiguous(1));
106+
EXPECT_FALSE(m3.areTrailingDimsContiguous(2));
107+
EXPECT_FALSE(m3.areTrailingDimsContiguous(3));
108+
109+
// memref<2x2x2xf32, strided<[8,4,2]>
110+
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
111+
EXPECT_FALSE(m4.areTrailingDimsContiguous(1));
112+
EXPECT_FALSE(m4.areTrailingDimsContiguous(2));
113+
EXPECT_FALSE(m4.areTrailingDimsContiguous(3));
114+
115+
// memref<2x2x?xf32, strided<[?,?,1]>
116+
auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
117+
EXPECT_TRUE(m5.areTrailingDimsContiguous(1));
118+
EXPECT_FALSE(m5.areTrailingDimsContiguous(2));
119+
EXPECT_FALSE(m5.areTrailingDimsContiguous(3));
120+
121+
// memref<2x2x?xf32, strided<[?,?,2]>
122+
auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
123+
EXPECT_FALSE(m6.areTrailingDimsContiguous(1));
124+
EXPECT_FALSE(m6.areTrailingDimsContiguous(2));
125+
EXPECT_FALSE(m6.areTrailingDimsContiguous(3));
126+
127+
// memref<2x?x2xf32, strided<[?,2,1]>
128+
auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
129+
EXPECT_TRUE(m7.areTrailingDimsContiguous(1));
130+
EXPECT_TRUE(m7.areTrailingDimsContiguous(2));
131+
EXPECT_FALSE(m7.areTrailingDimsContiguous(3));
132+
133+
// memref<2x?x2xf32, strided<[?,4,1]>
134+
auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
135+
EXPECT_TRUE(m8.areTrailingDimsContiguous(1));
136+
EXPECT_FALSE(m8.areTrailingDimsContiguous(2));
137+
EXPECT_FALSE(m8.areTrailingDimsContiguous(3));
138+
139+
// memref<2x?x2xf32, strided<[?,4,2]>
140+
auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
141+
EXPECT_FALSE(m9.areTrailingDimsContiguous(1));
142+
EXPECT_FALSE(m9.areTrailingDimsContiguous(2));
143+
EXPECT_FALSE(m9.areTrailingDimsContiguous(3));
144+
145+
// memref<?x2x2xf32, strided<[4,2,1]>
146+
auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
147+
EXPECT_TRUE(m10.areTrailingDimsContiguous(1));
148+
EXPECT_TRUE(m10.areTrailingDimsContiguous(2));
149+
EXPECT_TRUE(m10.areTrailingDimsContiguous(3));
150+
151+
// memref<?x2x2xf32, strided<[8,2,1]>
152+
auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
153+
EXPECT_TRUE(m11.areTrailingDimsContiguous(1));
154+
EXPECT_TRUE(m11.areTrailingDimsContiguous(2));
155+
EXPECT_FALSE(m11.areTrailingDimsContiguous(3));
156+
157+
// memref<?x2x2xf32, strided<[8,4,1]>
158+
auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
159+
EXPECT_TRUE(m12.areTrailingDimsContiguous(1));
160+
EXPECT_FALSE(m12.areTrailingDimsContiguous(2));
161+
EXPECT_FALSE(m12.areTrailingDimsContiguous(3));
162+
163+
// memref<?x2x2xf32, strided<[8,4,2]>
164+
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
165+
EXPECT_FALSE(m13.areTrailingDimsContiguous(1));
166+
EXPECT_FALSE(m13.areTrailingDimsContiguous(2));
167+
EXPECT_FALSE(m13.areTrailingDimsContiguous(3));
168+
}
169+
170+
TEST(MemRefLayout, identityMaps) {
171+
MLIRContext ctx;
172+
OpBuilder b(&ctx);
173+
174+
const auto _ = ShapedType::kDynamic;
175+
const auto f32 = b.getF32Type();
176+
177+
// memref<2x2x2xf32>
178+
auto m1 = MemRefType::get({2, 2, 2}, f32);
179+
EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
180+
EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
181+
EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
182+
EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
183+
184+
// memref<?x?x?xf32>
185+
auto m2 = MemRefType::get({_, _, _}, f32);
186+
EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 3);
187+
EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
188+
EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
189+
EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
190+
}

0 commit comments

Comments
 (0)