
[MLIR] Determine contiguousness of memrefs with dynamic dimensions #142421


Merged (11 commits) on Jun 23, 2025
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -40,7 +40,7 @@ class ArrayAttr;
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
///
/// `sizes` elements are asserted to be non-negative.
/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
///
/// Return an empty vector if `sizes` is empty.
SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
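Not part of the diff, just for orientation: a minimal sketch of the documented contract, assuming the declaration above is in scope. Only the leading size may be dynamic (negative); the trailing sizes must still be non-negative.

  // For sizes [s0, s1, s2] = [5, 3, 2] the suffix products are [3*2, 2, 1].
  SmallVector<int64_t> suffix = mlir::computeSuffixProduct({5, 3, 2});
  // suffix == {6, 2, 1}

  // The leading size never enters any of the products, so a dynamic
  // (negative) value such as ShapedType::kDynamic is accepted there:
  SmallVector<int64_t> suffixDyn =
      mlir::computeSuffixProduct({ShapedType::kDynamic, 3, 2});
  // suffixDyn == {6, 2, 1}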
19 changes: 19 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,6 +838,25 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
///
bool areTrailingDimsContiguous(int64_t n);

/// Return the number of trailing dimensions that are contiguous.
///
/// Examples:
/// - memref<5x3x2xi8, strided<[6,2,1]>>, the number of collapsable
/// trailing dimensions is 3
/// - memref<5x3x2xi8, strided<[12,2,1]>>, the number of collapsable
/// trailing dimensions is 2 (dimension 0 is non-contiguous)
/// - memref<5x3x2xi8, strided<[12,4,1]>>, the number of collapsable
/// trailing dimensions is 1 (dimension 1 is non-contiguous)
/// - memref<5x3x2xi8, strided<[12,4,2]>>, the number of collapsable
/// trailing dimensions is 0 (dimension 2 is non-contiguous)
/// - memref<?x3x2xi8, strided<[6,2,1]>>, the number of collapsable
/// trailing dimensions is 3
/// - memref<?x3x2xi8, strided<[12,2,1]>>, the number of collapsable
/// trailing dimensions is 2 (dimension 0 is non-contiguous)
/// - memref<5x?x2xi8, strided<[?,2,1]>>, the number of collapsable
/// trailing dimensions is 2 (stride 0 is dynamic)
int64_t getNumContiguousTrailingDims();

/// Return a version of this type with identity layout if it can be
/// determined statically that the layout is the canonical contiguous
/// strided layout. Otherwise pass the layout into `simplifyAffineMap`
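For orientation (not part of the patch): a short sketch of how the new hook composes with the existing areTrailingDimsContiguous query, assuming an MLIRContext and the builtin-type APIs already used in this PR's unit tests. The memref mirrors one of the documentation examples above.

  MLIRContext ctx;
  OpBuilder b(&ctx);

  // memref<?x3x2xi8, strided<[12, 2, 1]>>: dimension 0 is non-contiguous.
  auto layout = StridedLayoutAttr::get(&ctx, /*offset=*/0, /*strides=*/{12, 2, 1});
  auto mref = MemRefType::get({ShapedType::kDynamic, 3, 2}, b.getI8Type(), layout);

  int64_t k = mref.getNumContiguousTrailingDims();  // k == 2
  assert(mref.areTrailingDimsContiguous(k));        // the trailing 2 dims are contiguous
  assert(!mref.areTrailingDimsContiguous(k + 1));   // but not the trailing 3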
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,7 +69,8 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
//===----------------------------------------------------------------------===//

SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
assert((sizes.empty() ||
llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
"sizes must be nonnegative");
int64_t unit = 1;
return ::computeSuffixProductImpl(sizes, unit);
52 changes: 31 additions & 21 deletions mlir/lib/IR/BuiltinTypes.cpp
@@ -646,35 +646,45 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
}

bool MemRefType::areTrailingDimsContiguous(int64_t n) {
if (!isLastDimUnitStride())
return false;
assert(n <= getRank() &&
"number of dimensions to check must not exceed rank");
return n <= getNumContiguousTrailingDims();
}

auto memrefShape = getShape().take_back(n);
if (ShapedType::isDynamicShape(memrefShape))
return false;
int64_t MemRefType::getNumContiguousTrailingDims() {
const int64_t n = getRank();

// memrefs with identity layout are entirely contiguous.
if (getLayout().isIdentity())
return true;
return n;

// Get the strides (if any). Failing to do that, conservatively assume a
// non-contiguous layout.
int64_t offset;
SmallVector<int64_t> stridesFull;
if (!succeeded(getStridesAndOffset(stridesFull, offset)))
return false;
auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);

if (strides.empty())
return true;
SmallVector<int64_t> strides;
if (!succeeded(getStridesAndOffset(strides, offset)))
return 0;

// Check whether strides match "flattened" dims.
SmallVector<int64_t> flattenedDims;
auto dimProduct = 1;
for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
dimProduct *= dim;
flattenedDims.push_back(dimProduct);
ArrayRef<int64_t> shape = getShape();

// A memref with dimensions `d0, d1, ..., dn-1` and strides
// `s0, s1, ..., sn-1` is contiguous up to dimension `k`
// if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
// for `i` in `[k, n-1]`.
// Ignore stride elements if the corresponding dimension is 1, as they are
// of no consequence.
int64_t dimProduct = 1;
for (int64_t i = n - 1; i >= 0; --i) {
if (shape[i] == 1)
continue;
if (strides[i] != dimProduct)
return n - i - 1;
if (shape[i] == ShapedType::kDynamic)
return n - i;
dimProduct *= shape[i];
}

strides = strides.drop_back(1);
return llvm::equal(strides, llvm::reverse(flattenedDims));
return n;
}

MemRefType MemRefType::canonicalizeStridedLayout() {
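A quick trace of the new scan (not from the patch itself) on the last documentation example above, memref<5x?x2xi8, strided<[?,2,1]>>, where n = 3:

  i = 2: shape[2] = 2, strides[2] = 1 == dimProduct (1), static extent  -> dimProduct = 2
  i = 1: shape[1] = ?, strides[1] = 2 == dimProduct (2), dynamic extent -> return n - i = 2

The dynamic stride at position 0 is never inspected, so it cannot spoil the result, and areTrailingDimsContiguous(n) now reduces to the comparison n <= 2.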
97 changes: 86 additions & 11 deletions mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -188,9 +188,35 @@ func.func @transfer_read_leading_dynamic_dims(

// -----

// One of the dims to be flattened is dynamic - not supported ATM.
// The vector is a non-contiguous slice of the input
// memref.

func.func @negative_transfer_read_dynamic_dim_to_flatten(
%mem : memref<4x?x?x2xi8>) -> vector<2x2x2xi8> {

%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
memref<4x?x?x2xi8>, vector<2x2x2xi8>
return %res : vector<2x2x2xi8>
}

// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten(
// CHECK-128B-NOT: memref.collapse_shape

// -----

// When collapsing memref dimensions, we may include the rightmost dynamic
// dimension (e.g., at position `k`) provided that the strides for dimensions
// `k+1`, `k+2`, etc., ensure contiguity in memory. The stride at position `k`
// itself does not factor into this. (Here "strides" means both explicit strides
// and those implied by an identity map.)
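// For the memref<1x?x4x6xi32> below, for instance, collapsing dims [1, 2, 3]
// yields in-group strides (4 * 6, 6, 1), so the collapsed index becomes
// idx_1 * 24 + idx_2 * 6, which is the affine map checked further down.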

func.func @transfer_read_dynamic_dim_to_flatten(
%idx_1: index,
%idx_2: index,
%mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +229,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
return %res : vector<1x2x6xi32>
}

// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>

// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
// CHECK-SAME: %[[IDX_1:arg0]]
// CHECK-SAME: %[[IDX_2:arg1]]
// CHECK-SAME: %[[MEM:arg2]]
// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
// CHECK: return %[[RESULT]] : vector<1x2x6xi32>


// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape

// -----
@@ -451,9 +491,31 @@ func.func @transfer_write_leading_dynamic_dims(

// -----

// One of the dims to be flattened is dynamic - not supported ATM.
// The vector is a non-contiguous slice of the input
// memref.

func.func @negative_transfer_write_dynamic_to_flatten(
%mem : memref<4x?x?x2xi8>,
%vec : vector<2x2x2xi8>) {

%c0 = arith.constant 0 : index
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0]
: vector<2x2x2xi8>, memref<4x?x?x2xi8>
return
}

// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten(
// CHECK-128B-NOT: memref.collapse_shape

// -----

// See the comment in front of @transfer_read_dynamic_dim_to_flatten.

func.func @transfer_write_dynamic_dim_to_flatten(
%idx_1: index,
%idx_2: index,
%vec : vector<1x2x6xi32>,
@@ -466,11 +528,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
return
}

// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>

// CHECK-LABEL: func.func @transfer_write_dynamic_dim_to_flatten
// CHECK-SAME: %[[IDX_1:arg0]]: index
// CHECK-SAME: %[[IDX_2:arg1]]: index
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>

// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
// CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape

// -----
1 change: 1 addition & 0 deletions mlir/unittests/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_unittest(MLIRIRTests
IRMapping.cpp
InterfaceAttachmentTest.cpp
LocationTest.cpp
MemrefLayoutTest.cpp
OperationSupportTest.cpp
PatternMatchTest.cpp
ShapedTypeTest.cpp
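The new file becomes part of the existing MLIRIRTests unit-test binary. Assuming a standard build, building the MLIRIRTests target and running it with --gtest_filter=MemRefLayout.* should exercise only the tests added below (the exact binary path depends on the build directory layout).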
111 changes: 111 additions & 0 deletions mlir/unittests/IR/MemrefLayoutTest.cpp
Contributor:

A few high-level comments on test organization and naming:

  • maxContigDim tests getNumContiguousTrailingDims, contigTrailingDim tests areTrailingDimsContiguous, and identityMaps tests both - is there an intended pattern here?
  • The actual "complex" logic seems to reside in getNumContiguousTrailingDims. Would it make sense to focus the tests more narrowly on that hook?
  • Could the test names (maxContigDim, contigTrailingDim, identityMaps) be made more descriptive? Alternatively, adding comments explaining what each test is validating would help - it’s not immediately clear to me.
  • I’d suggest renaming identityMaps to something like noStrides or defaultStrides, since the other tests explicitly include strides. The natural split seems to be "with strides" vs "without strides".

Collaborator Author:

A few high-level comments on test organization and naming:

  • maxContigDim tests getNumContiguousTrailingDims, contigTrailingDim tests areTrailingDimsContiguous, and identityMaps tests both - is there an intended pattern here?

maxContigDim is a remnant from when the member function was called getMaxContiguousTrailingDims.
I'll rename it.

`contigTrailingDim` tests `areTrailingDimsContiguous`, so it's named after the thing it tests.

identityMaps tests the fast path when the memref has an identity map

...
int64_t MemRefType::getNumContiguousTrailingDims() {
  const int64_t n = getRank();

  // memrefs with identity layout are entirely contiguous.
  if (getLayout().isIdentity())
    return n;
...

So, yeah, there's a pattern of naming the tests after the thing they test.

  • The actual "complex" logic seems to reside in getNumContiguousTrailingDims. Would it make sense to focus the tests more narrowly on that hook?

I find the tests focused enough.

  • Could the test names (maxContigDim, contigTrailingDim, identityMaps) be made more descriptive? Alternatively, adding comments explaining what each test is validating would help - it’s not immediately clear to me.

What each test is validating is abundantly clear from the EXPECT_ lines.

  • I’d suggest renaming identityMaps to something like noStrides or defaultStrides, since the other tests explicitly include strides. The natural split seems to be "with strides" vs "without strides".

As explained above, it tests exactly the case for having identity maps, hence it's appropriately named.

Contributor:

What each test is validating is abundantly clear from the EXPECT_ lines.

I see where you’re coming from - the EXPECT_ lines do show what is being tested, but they don’t necessarily explain why those inputs are interesting or grouped together.

Please document so that the intent is clear.

Collaborator Author:

  • maxContigDim tests getNumContiguousTrailingDims, contigTrailingDim tests areTrailingDimsContiguous, and identityMaps tests both - is there an intended pattern here?

I have either replied or did not understand the question.

  • The actual "complex" logic seems to reside in getNumContiguousTrailingDims. Would it make sense to focus the tests more narrowly on that hook?

I've reworked the test cases to be more "focused".

  • Could the test names (maxContigDim, contigTrailingDim, identityMaps) be made more descriptive? Alternatively, adding comments explaining what each test is validating would help - it’s not immediately clear to me.

Added comments what each test is validating. Also put comments to each individual test case.

  • I’d suggest renaming identityMaps to something like noStrides or defaultStrides, since the other tests explicitly include strides. The natural split seems to be "with strides" vs "without strides".

Test removed and a test case moved.

@@ -0,0 +1,111 @@
//===- MemrefLayoutTest.cpp - unit tests related to memref layout --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"

using namespace mlir;
using namespace mlir::memref;

//
// Test the correctness of `MemRefType::getNumContiguousTrailingDims`
//
TEST(MemRefLayout, numContigDim) {
MLIRContext ctx;
OpBuilder b(&ctx);

const int64_t _ = ShapedType::kDynamic;
const FloatType f32 = b.getF32Type();
auto strided = [&ctx](ArrayRef<int64_t> s) {
return StridedLayoutAttr::get(&ctx, 0, s);
};

// Special case for identity maps and no explicit `strided` attribute - the
// memref is entirely contiguous even if the strides cannot be determined
// statically.

// memref<?x?x?xf32>
auto m0 = MemRefType::get({_, _, _}, f32);
EXPECT_EQ(m0.getNumContiguousTrailingDims(), 3);

// Conservatively assume the memref is non-contiguous everywhere if the
// strides cannot be determined.

// memref<2x2x2xf32, (i,j,k)->(i,k,j)>
auto m1 = MemRefType::get(
{2, 2, 2}, f32,
AffineMap::getPermutationMap(ArrayRef<int64_t>{0, 2, 1}, &ctx));
EXPECT_EQ(m1.getNumContiguousTrailingDims(), 0);

// A base case: a fixed-shape memref with the usual strides.

// memref<2x2x2xf32, strided<[4, 2, 1]>>
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
EXPECT_EQ(m3.getNumContiguousTrailingDims(), 3);

// A fixed memref with a discontinuity in the rightmost dimension.

// memref<2x2x2xf32, strided<[8, 4, 2]>>
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);

// A fixed memref with a discontinuity at the outermost dimension.

// memref<2x2x2xf32, strided<[8, 2, 1]>>
auto m5 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
EXPECT_EQ(m5.getNumContiguousTrailingDims(), 2);

// A dynamic memref where the dynamic dimension breaks contiguity.

// memref<2x?x2xf32, strided<[4, 2, 1]>>
auto m6 = MemRefType::get({2, _, 2}, f32, strided({4, 2, 1}));
EXPECT_EQ(m6.getNumContiguousTrailingDims(), 2);

// An edge case of a dynamic memref where the dynamic dimension is the
// outermost one. Its stride still equals the product of the inner sizes,
// so all three trailing dimensions are contiguous.

// memref<?x2x2xf32, strided<[4, 2, 1]>>
auto m7 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
EXPECT_EQ(m7.getNumContiguousTrailingDims(), 3);

// A memref with a unit dimension. Unit dimensions do not affect contiguity,
// even if the corresponding stride is dynamic.

// memref<2x1x2xf32, strided<[2,?,1]>>
auto m8 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
EXPECT_EQ(m8.getNumContiguousTrailingDims(), 3);
}

//
// Test the member function `MemRefType::areTrailingDimsContiguous`
//
TEST(MemRefLayout, contigTrailingDim) {
MLIRContext ctx;
OpBuilder b(&ctx);

const int64_t _ = ShapedType::kDynamic;
const FloatType f32 = b.getF32Type();
auto strided = [&ctx](ArrayRef<int64_t> s) {
return StridedLayoutAttr::get(&ctx, 0, s);
};

// A memref that is neither fully contiguous nor fully non-contiguous.
// Ensure `areTrailingDimsContiguous` returns `true` for the value
// returned by `getNumContiguousTrailingDims` and `false` for the next larger
// number.

// memref<2x?x2xf32, strided<[?,2,1]>>
auto m = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
int64_t n = m.getNumContiguousTrailingDims();
EXPECT_TRUE(m.areTrailingDimsContiguous(n));
ASSERT_TRUE(n + 1 <= m.getRank());
EXPECT_FALSE(m.areTrailingDimsContiguous(n + 1));
}