
Commit a94b304

momchil-velikov authored and Jaddyen committed
[MLIR] Determine contiguousness of memrefs with dynamic dimensions (llvm#142421)
This patch enhances `MemRefType::areTrailingDimsContiguous` to also handle memrefs with dynamic dimensions. The implementation itself is based on a new member function `MemRefType::getMaxCollapsableTrailingDims` that returns the maximum number of trailing dimensions that can be collapsed: trivially all dimensions for memrefs with identity layout, or otherwise determined by examining the memref strides, stopping at the first discontiguous or statically unknown stride.
1 parent 8ec0137 commit a94b304

File tree: 7 files changed (+251, -34 lines)

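To make the new behaviour concrete, here is a minimal sketch (not part of the patch) that exercises the two helpers through the MLIR C++ API; the type and the expected values mirror the examples documented in BuiltinTypes.td and the unit test added in mlir/unittests/IR/MemrefLayoutTest.cpp below.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// Illustrative sketch only: memref<?x3x2xf32, strided<[6, 2, 1]>> has three
// contiguous trailing dimensions; the leading dynamic extent does not matter
// because its stride already equals the product of the inner sizes.
void contiguityExample() {
  mlir::MLIRContext ctx;
  mlir::OpBuilder b(&ctx);
  const int64_t dyn = mlir::ShapedType::kDynamic;

  auto layout = mlir::StridedLayoutAttr::get(&ctx, /*offset=*/0, {6, 2, 1});
  auto m = mlir::MemRefType::get({dyn, 3, 2}, b.getF32Type(), layout);

  // Both queries rely on the behaviour introduced by this commit.
  int64_t numContig = m.getNumContiguousTrailingDims(); // expected: 3
  bool ok = m.areTrailingDimsContiguous(2);             // expected: true
  (void)numContig;
  (void)ok;
}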

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ class ArrayAttr;
 /// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
 /// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
 ///
-/// `sizes` elements are asserted to be non-negative.
+/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
 ///
 /// Return an empty vector if `sizes` is empty.
 SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 19 additions & 0 deletions
@@ -839,6 +839,25 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     ///
     bool areTrailingDimsContiguous(int64_t n);

+    /// Return the number of trailing dimensions that are contiguous.
+    ///
+    /// Examples:
+    ///   - memref<5x3x2xi8, strided<[6,2,1]>>, the number of collapsable
+    ///     trailing dimensions is 3
+    ///   - memref<5x3x2xi8, strided<[12,2,1]>>, the number of collapsable
+    ///     trailing dimensions is 2 (dimension 0 is non-contiguous)
+    ///   - memref<5x3x2xi8, strided<[12,4,1]>>, the number of collapsable
+    ///     trailing dimensions is 1 (dimension 1 is non-contiguous)
+    ///   - memref<5x3x2xi8, strided<[12,4,2]>>, the number of collapsable
+    ///     trailing dimensions is 0 (dimension 2 is non-contiguous)
+    ///   - memref<?x3x2xi8, strided<[6,2,1]>>, the number of collapsable
+    ///     trailing dimensions is 3
+    ///   - memref<?x3x2xi8, strided<[12,2,1]>>, the number of collapsable
+    ///     trailing dimensions is 2 (dimension 0 is non-contiguous)
+    ///   - memref<5x?x2xi8, strided<[?,2,1]>>, the number of collapsable
+    ///     trailing dimensions is 2 (stride 0 is dynamic)
+    int64_t getNumContiguousTrailingDims();
+
     /// Return a version of this type with identity layout if it can be
     /// determined statically that the layout is the canonical contiguous
     /// strided layout. Otherwise pass the layout into `simplifyAffineMap`

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
 //===----------------------------------------------------------------------===//

 SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
-  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
+  assert((sizes.empty() ||
+          llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
          "sizes must be nonnegative");
   int64_t unit = 1;
   return ::computeSuffixProductImpl(sizes, unit);
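For context, a small standalone example (not part of the diff) of the suffix product this function computes; only `s1 ... sn` enter the product, which is why the relaxed assertion above skips the leading element (it may now be negative, e.g. ShapedType::kDynamic).

#include <cassert>
#include <cstdint>
#include <vector>

// Standalone restatement of the suffix-product computation, for illustration
// only; the real implementation is mlir::computeSuffixProduct above.
static std::vector<int64_t> suffixProduct(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> result(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i)
    result[i] = result[i + 1] * sizes[i + 1];
  return result;
}

int main() {
  // For sizes [s0, s1, s2] = [5, 3, 2] the result is [s1*s2, s2, 1] = [6, 2, 1],
  // i.e. the row-major strides of memref<5x3x2xi8> from the examples above.
  assert((suffixProduct({5, 3, 2}) == std::vector<int64_t>{6, 2, 1}));
  // The leading element never enters the product, so a negative sentinel such
  // as ShapedType::kDynamic in s0 does not affect the result.
  assert((suffixProduct({-1, 3, 2}) == std::vector<int64_t>{6, 2, 1}));
  return 0;
}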

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 31 additions & 21 deletions
@@ -660,35 +660,45 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
 }

 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
-  if (!isLastDimUnitStride())
-    return false;
+  assert(n <= getRank() &&
+         "number of dimensions to check must not exceed rank");
+  return n <= getNumContiguousTrailingDims();
+}

-  auto memrefShape = getShape().take_back(n);
-  if (ShapedType::isDynamicShape(memrefShape))
-    return false;
+int64_t MemRefType::getNumContiguousTrailingDims() {
+  const int64_t n = getRank();

+  // memrefs with identity layout are entirely contiguous.
   if (getLayout().isIdentity())
-    return true;
+    return n;

+  // Get the strides (if any). Failing to do that, conservatively assume a
+  // non-contiguous layout.
   int64_t offset;
-  SmallVector<int64_t> stridesFull;
-  if (!succeeded(getStridesAndOffset(stridesFull, offset)))
-    return false;
-  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
-  if (strides.empty())
-    return true;
+  SmallVector<int64_t> strides;
+  if (!succeeded(getStridesAndOffset(strides, offset)))
+    return 0;

-  // Check whether strides match "flattened" dims.
-  SmallVector<int64_t> flattenedDims;
-  auto dimProduct = 1;
-  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
-    dimProduct *= dim;
-    flattenedDims.push_back(dimProduct);
+  ArrayRef<int64_t> shape = getShape();
+
+  // A memref with dimensions `d0, d1, ..., dn-1` and strides
+  // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
+  // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
+  // for `i` in `[k, n-1]`.
+  // Ignore stride elements if the corresponding dimension is 1, as they are
+  // of no consequence.
+  int64_t dimProduct = 1;
+  for (int64_t i = n - 1; i >= 0; --i) {
+    if (shape[i] == 1)
+      continue;
+    if (strides[i] != dimProduct)
+      return n - i - 1;
+    if (shape[i] == ShapedType::kDynamic)
+      return n - i;
+    dimProduct *= shape[i];
   }

-  strides = strides.drop_back(1);
-  return llvm::equal(strides, llvm::reverse(flattenedDims));
+  return n;
 }

 MemRefType MemRefType::canonicalizeStridedLayout() {
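To make the new criterion concrete, here is a compact standalone restatement (illustrative only, not part of the patch) of the loop above on plain vectors, with kDyn standing in for ShapedType::kDynamic; it reproduces two of the examples documented in BuiltinTypes.td.

#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Sentinel for a dynamic extent, mirroring ShapedType::kDynamic.
constexpr int64_t kDyn = std::numeric_limits<int64_t>::min();

// Illustrative restatement of MemRefType::getNumContiguousTrailingDims for
// known strides; the real method additionally handles identity layouts and
// layouts whose strides cannot be computed.
static int64_t numContiguousTrailingDims(const std::vector<int64_t> &shape,
                                         const std::vector<int64_t> &strides) {
  const int64_t n = static_cast<int64_t>(shape.size());
  int64_t dimProduct = 1;
  for (int64_t i = n - 1; i >= 0; --i) {
    if (shape[i] == 1) // Unit dimensions never break contiguity.
      continue;
    if (strides[i] != dimProduct) // Discontiguous or dynamic stride found.
      return n - i - 1;
    if (shape[i] == kDyn) // Cannot extend the product past a dynamic extent.
      return n - i;
    dimProduct *= shape[i];
  }
  return n;
}

int main() {
  // memref<5x3x2xi8, strided<[12, 2, 1]>>: dimension 0 is non-contiguous.
  assert((numContiguousTrailingDims({5, 3, 2}, {12, 2, 1}) == 2));
  // memref<?x3x2xi8, strided<[6, 2, 1]>>: the leading dynamic dim still counts.
  assert((numContiguousTrailingDims({kDyn, 3, 2}, {6, 2, 1}) == 3));
  return 0;
}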

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 86 additions & 11 deletions
@@ -188,9 +188,35 @@ func.func @transfer_read_leading_dynamic_dims(
 
 // -----
 
-// One of the dims to be flattened is dynamic - not supported ATM.
+// The vector is a non-contiguous slice of the input
+// memref.
 
 func.func @negative_transfer_read_dynamic_dim_to_flatten(
+    %mem : memref<4x?x?x2xi8>) -> vector<2x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<4x?x?x2xi8>, vector<2x2x2xi8>
+  return %res : vector<2x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten(
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
+// When collapsing memref dimensions, we may include the rightmost dynamic
+// dimension (e.g., at position `k`) provided that the strides for dimensions
+// `k+1`, `k+2`, etc., ensure contiguity in memory. The stride at position `k`
+// itself does not factor into this. (Here "strides" mean both explicit and
+// implied by identity map)
+
+func.func @transfer_read_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +229,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
   return %res : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]
+// CHECK-SAME: %[[IDX_2:arg1]]
+// CHECK-SAME: %[[MEM:arg2]]
+// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
+
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
 // CHECK-128B-NOT: memref.collapse_shape
 
 // -----
@@ -451,9 +491,31 @@ func.func @transfer_write_leading_dynamic_dims(
 
 // -----
 
-// One of the dims to be flattened is dynamic - not supported ATM.
+// The vector is a non-contiguous slice of the input
+// memref.
 
 func.func @negative_transfer_write_dynamic_to_flatten(
+    %mem : memref<4x?x?x2xi8>,
+    %vec : vector<2x2x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0]
+    : vector<2x2x2xi8>, memref<4x?x?x2xi8>
+  return
+}
+
+// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten(
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
+// See the comment in front of @transfer_read_dynamic_dim_to_flatten.
+
+func.func @transfer_write_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
@@ -466,11 +528,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
   return
 }
 
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_write_dynamic_dim_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]: index
+// CHECK-SAME: %[[IDX_2:arg1]]: index
+// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
+// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
 
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
 // CHECK-128B-NOT: memref.collapse_shape
 
 // -----

mlir/unittests/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ add_mlir_unittest(MLIRIRTests
   IRMapping.cpp
   InterfaceAttachmentTest.cpp
   LocationTest.cpp
+  MemrefLayoutTest.cpp
   OperationSupportTest.cpp
   PatternMatchTest.cpp
   ShapedTypeTest.cpp
mlir/unittests/IR/MemrefLayoutTest.cpp

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+//===- MemrefLayoutTest.cpp - unit tests related to memref layout ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+//
+// Test the correctness of `memref::getNumContiguousTrailingDims`
+//
+TEST(MemRefLayout, numContigDim) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const int64_t _ = ShapedType::kDynamic;
+  const FloatType f32 = b.getF32Type();
+  auto strided = [&ctx](ArrayRef<int64_t> s) {
+    return StridedLayoutAttr::get(&ctx, 0, s);
+  };
+
+  // Special case for identity maps and no explicit `strided` attribute - the
+  // memref is entirely contiguous even if the strides cannot be determined
+  // statically.
+
+  // memref<?x?x?xf32>
+  auto m0 = MemRefType::get({_, _, _}, f32);
+  EXPECT_EQ(m0.getNumContiguousTrailingDims(), 3);
+
+  // Conservatively assume the memref is discontiguous everywhere if the
+  // strides cannot be determined.
+
+  // memref<2x2x2xf32, (i,j,k)->(i,k,j)>
+  auto m1 = MemRefType::get(
+      {2, 2, 2}, f32,
+      AffineMap::getPermutationMap(ArrayRef<int64_t>{0, 2, 1}, &ctx));
+  EXPECT_EQ(m1.getNumContiguousTrailingDims(), 0);
+
+  // A base case of a fixed memref with the usual strides.
+
+  // memref<2x2x2xf32, strided<[4, 2, 1]>>
+  auto m3 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m3.getNumContiguousTrailingDims(), 3);
+
+  // A fixed memref with a discontinuity in the rightmost dimension.
+
+  // memref<2x2x2xf32, strided<[8, 4, 2]>>
+  auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
+
+  // A fixed memref with a discontinuity in the "middle".
+
+  // memref<2x2x2xf32, strided<[8, 2, 1]>>
+  auto m5 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_EQ(m5.getNumContiguousTrailingDims(), 2);
+
+  // A dynamic memref where the dynamic dimension breaks continuity.
+
+  // memref<2x?x2xf32, strided<[4, 2, 1]>>
+  auto m6 = MemRefType::get({2, _, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m6.getNumContiguousTrailingDims(), 2);
+
+  // An edge case of a dynamic memref where the dynamic dimension is the first
+  // one.
+
+  // memref<?x2x2xf32, strided<[4, 2, 1]>>
+  auto m7 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m7.getNumContiguousTrailingDims(), 3);
+
+  // A memref with a unit dimension. Unit dimensions do not affect continuity,
+  // even if the corresponding stride is dynamic.
+
+  // memref<2x1x2xf32, strided<[2,?,1]>>
+  auto m8 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
+  EXPECT_EQ(m8.getNumContiguousTrailingDims(), 3);
+}
+
+//
+// Test the member function `memref::areTrailingDimsContiguous`
+//
+TEST(MemRefLayout, contigTrailingDim) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const int64_t _ = ShapedType::kDynamic;
+  const FloatType f32 = b.getF32Type();
+  auto strided = [&ctx](ArrayRef<int64_t> s) {
+    return StridedLayoutAttr::get(&ctx, 0, s);
+  };
+
+  // For a memref that is neither fully contiguous nor fully discontiguous,
+  // ensure `areTrailingDimsContiguous` returns `true` for the value
+  // returned by `getNumContiguousTrailingDims` and `false` for the next bigger
+  // number.
+
+  // memref<2x?x2xf32, strided<[?,2,1]>>
+  auto m = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+  int64_t n = m.getNumContiguousTrailingDims();
+  EXPECT_TRUE(m.areTrailingDimsContiguous(n));
+  ASSERT_TRUE(n + 1 <= m.getRank());
+  EXPECT_FALSE(m.areTrailingDimsContiguous(n + 1));
+}
