-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Determine contiguousness of memrefs with dynamic dimensions #142421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir-ods Author: Momchil Velikov (momchil-velikov) ChangesThis patch enhances The implementation itself is based on a new member function Full diff: https://github.com/llvm/llvm-project/pull/142421.diff 5 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..1d12f70882176 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,6 +838,20 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
///
bool areTrailingDimsContiguous(int64_t n);
+ /// Return the maximum number of trailing dimensions that can be
+ /// collapsed.
+ ///
+ /// Examples:
+ /// - memref<2x3x2xi8, strided<[24, 12, 2]>, the number of collapsable
+ /// trailing dimensions is 0
+ /// - memref<2x3x2xi8, strided<[12, 6, 1]>, the number of collapsable
+ /// trailing dimensions is 3
+ /// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>, the number of
+ /// collapsable trailing dimensions is 2.
+ /// - memref<5x4x?x2xi8>, the number of collapsable trailing dimensions
+ /// is 4.
+ int64_t getMaxCollapsableTrailingDims();
+
/// Return a version of this type with identity layout if it can be
/// determined statically that the layout is the canonical contiguous
/// strided layout. Otherwise pass the layout into `simplifyAffineMap`
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d47e360e9dc13..cc23d08515ff3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,35 +646,40 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
}
bool MemRefType::areTrailingDimsContiguous(int64_t n) {
- if (!isLastDimUnitStride())
- return false;
+ return getLayout().isIdentity() ||
+ getMaxCollapsableTrailingDims() >= std::min(n, getRank());
+}
- auto memrefShape = getShape().take_back(n);
- if (ShapedType::isDynamicShape(memrefShape))
- return false;
+int64_t MemRefType::getMaxCollapsableTrailingDims() {
+ const int64_t n = getRank();
+ // memrefs with identity layout are entirely contiguous.
if (getLayout().isIdentity())
- return true;
+ return n;
+ // Get the strides (if any). Failing to do that, conservatively assume a
+ // non-contiguous layout.
int64_t offset;
- SmallVector<int64_t> stridesFull;
- if (!succeeded(getStridesAndOffset(stridesFull, offset)))
- return false;
- auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
- if (strides.empty())
- return true;
+ SmallVector<int64_t> strides;
+ if (!succeeded(getStridesAndOffset(strides, offset)))
+ return 0;
- // Check whether strides match "flattened" dims.
- SmallVector<int64_t> flattenedDims;
- auto dimProduct = 1;
- for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
- dimProduct *= dim;
- flattenedDims.push_back(dimProduct);
+ auto shape = getShape();
+
+ // A memref with dimensions `d0, d1, ..., dn-1` and strides
+ // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
+ // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
+ // for `i` in `[k, n-1]`.
+ int64_t dimProduct = 1;
+ for (int64_t i = n - 1; i >= 0; --i) {
+ if (strides[i] != dimProduct)
+ return n - i - 1;
+ if (shape[i] == ShapedType::kDynamic)
+ return n - i;
+ dimProduct *= shape[i];
}
- strides = strides.drop_back(1);
- return llvm::equal(strides, llvm::reverse(flattenedDims));
+ return n;
}
MemRefType MemRefType::canonicalizeStridedLayout() {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e840dc6bbf224..5b2f2ab1f2cef 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -190,7 +190,7 @@ func.func @transfer_read_leading_dynamic_dims(
// One of the dims to be flattened is dynamic - not supported ATM.
-func.func @negative_transfer_read_dynamic_dim_to_flatten(
+func.func @transfer_read_dynamic_dim_to_flatten(
%idx_1: index,
%idx_2: index,
%mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +203,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
return %res : vector<1x2x6xi32>
}
-// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]
+// CHECK-SAME: %[[IDX_2:arg1]]
+// CHECK-SAME: %[[MEM:arg2]]
+// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
+
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
@@ -453,7 +467,7 @@ func.func @transfer_write_leading_dynamic_dims(
// One of the dims to be flattened is dynamic - not supported ATM.
-func.func @negative_transfer_write_dynamic_to_flatten(
+func.func @transfer_write_dynamic_to_flatten(
%idx_1: index,
%idx_2: index,
%vec : vector<1x2x6xi32>,
@@ -466,11 +480,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
return
}
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]: index
+// CHECK-SAME: %[[IDX_2:arg1]]: index
+// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
+// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index dede3ba0a885c..1f6df1024f430 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRMemRefTests
InferShapeTest.cpp
+ LayoutTest.cpp
)
mlir_target_link_libraries(MLIRMemRefTests
PRIVATE
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
new file mode 100644
index 0000000000000..e01c0056d5cec
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -0,0 +1,190 @@
+//===- LayoutTest.cpp - unit tests related to memref layout --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+TEST(MemRefLayout, maxCollapseDim) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+ auto strided = [&ctx](ArrayRef<int64_t> s) {
+ return StridedLayoutAttr::get(&ctx, 0, s);
+ };
+
+ // memref<2x2x2xf32, strided<[4,2,1]>
+ auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+
+ // memref<2x2x2xf32, strided<[8,2,1]>
+ auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 2);
+
+ // memref<2x2x2xf32, strided<[8,4,1]>
+ auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_EQ(m3.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<2x2x2xf32, strided<[8,4,2]>
+ auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_EQ(m4.getMaxCollapsableTrailingDims(), 0);
+
+ // memref<2x2x?xf32, strided<[?,?,1]>
+ auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+ EXPECT_EQ(m5.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<2x2x?xf32, strided<[?,?,2]>
+ auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+ EXPECT_EQ(m6.getMaxCollapsableTrailingDims(), 0);
+
+ // memref<2x?x2xf32, strided<[?,2,1]>
+ auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+ EXPECT_EQ(m7.getMaxCollapsableTrailingDims(), 2);
+
+ // memref<2x?x2xf32, strided<[?,4,1]>
+ auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+ EXPECT_EQ(m8.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<2x?x2xf32, strided<[?,4,2]>
+ auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+ EXPECT_EQ(m9.getMaxCollapsableTrailingDims(), 0);
+
+ // memref<?x2x2xf32, strided<[4,2,1]>
+ auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_EQ(m10.getMaxCollapsableTrailingDims(), 3);
+
+ // memref<?x2x2xf32, strided<[8,2,1]>
+ auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_EQ(m11.getMaxCollapsableTrailingDims(), 2);
+
+ // memref<?x2x2xf32, strided<[8,4,1]>
+ auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_EQ(m12.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<?x2x2xf32, strided<[8,4,2]>
+ auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_EQ(m13.getMaxCollapsableTrailingDims(), 0);
+}
+
+TEST(MemRefLayout, contigTrailingDim) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+ auto strided = [&ctx](ArrayRef<int64_t> s) {
+ return StridedLayoutAttr::get(&ctx, 0, s);
+ };
+
+ // memref<2x2x2xf32, strided<[4,2,1]>
+ auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+ // memref<2x2x2xf32, strided<[8,2,1]>
+ auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m2.areTrailingDimsContiguous(3));
+
+ // memref<2x2x2xf32, strided<[8,4,1]>
+ auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_TRUE(m3.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m3.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m3.areTrailingDimsContiguous(3));
+
+ // memref<2x2x2xf32, strided<[8,4,2]>
+ auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(3));
+
+ // memref<2x2x?xf32, strided<[?,?,1]>
+ auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+ EXPECT_TRUE(m5.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m5.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m5.areTrailingDimsContiguous(3));
+
+ // memref<2x2x?xf32, strided<[?,?,2]>
+ auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(3));
+
+ // memref<2x?x2xf32, strided<[?,2,1]>
+ auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+ EXPECT_TRUE(m7.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m7.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m7.areTrailingDimsContiguous(3));
+
+ // memref<2x?x2xf32, strided<[?,4,1]>
+ auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+ EXPECT_TRUE(m8.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m8.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m8.areTrailingDimsContiguous(3));
+
+ // memref<2x?x2xf32, strided<[?,4,2]>
+ auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[4,2,1]>
+ auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[8,2,1]>
+ auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_TRUE(m11.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m11.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m11.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[8,4,1]>
+ auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_TRUE(m12.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m12.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m12.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[8,4,2]>
+ auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(3));
+}
+
+TEST(MemRefLayout, identityMaps) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+
+ // memref<2x2x2xf32>
+ auto m1 = MemRefType::get({2, 2, 2}, f32);
+ EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+ // memref<?x?x?xf32>
+ auto m2 = MemRefType::get({_, _, _}, f32);
+ EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, this looks good! I like the overall design, especially the new c++ testing approach. My comments are all minor.
1d025f3
to
1fe6866
Compare
c97c6c1
to
e3b6e21
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank for for adding the stride = 1 cases, looks good to me! I've added a few minor comments.
e3b6e21
to
ab4681a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Maybe allow others with more experience here than me some time to review
ab4681a
to
1cee446
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We usually don't write such unit tests when there exists a way to check functionality via a pass. Individual unit tests create a lot of additional compiler build time. In this case, one could add a test pass that annotates the operation with an attribute containing the number of contiguous dimensions on the result type, for example.
Also, I see this is carried over from a pre-existing structure, but this doesn't belong to unittest/Dialect/MemRef
. The memref type is in the builtin dialect, so the related tests should live in unittest/IR
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps a pass like this one which tests the another method of memref:
llvm::outs() << "\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I consider those test-
passes an abuse of both the pass pipeline and the testing infrastructure.
Of course, if I get an unambiguous statement "this PR is not going in with those unit tests", I'll do it (with great displeasure).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll do it
Or rather, resurrect it, I had that pass earlier that did
/// Traverse AllocOp and compute the maximum number of contiguous trailing dimensions.
void TestMemRefContigCalculation::runOnOperation() {
getOperation().walk([&](func::CallOp callOp) {
auto memrefType = cast<MemRefType>(callOp.getResult(0).getType());
memrefType.print(llvm::outs());
int64_t n = memrefType.getMaxContiguousTrailingDims();
llvm::outs() << ": " << n << " contiguous trailing dims(s)\n";
});
llvm::outs().flush();
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The google tests added here are nicer tests. But they increase compile time more than the pass tests. How much nicer, and how much slower? I don't know, one's preference depends on how much weight each is given. Third dimension is consistency with other tests, but there it's a tie as both exist for memref type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I'm fine with either)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me chime in briefly.
We usually don't write such unit tests when there exists a way to check functionality via a pass.
Just to clarify: there isn’t a way to test getNumContiguousTrailingDims
in isolation today. That said, we do have transformations that use it.
Question 1: Is it valuable to test getNumContiguousTrailingDims
in isolation?
I think so - it checks a fairly fundamental property, and the test effectively documents what qualifies as a “contiguous memref slice.”
Question 2: How do we test it?
There are two main options: unit tests or a test pass (possibly with a TD op):
- Unit tests involve less boilerplate but increase compile times.
- A test pass avoids compile-time overhead but requires more setup. Runtime cost is unclear - though honestly, we’re hand-waving a bit here without measurements.
Personally, I don’t have a strong preference. I’ve tried to lay out the trade-offs, and I think either direction is valid. In cases like this, I’d suggest just sticking with the approach that’s already implemented.
@ftynse, I read your comment more as a suggestion than a hard requirement - please let us know if we misread that. (As I said, I’m happy with either option.)
This patch enhances `MemRefType::areTrailingDimsContiguous` to also handle memrefs with dynamic dimensions. The implementation itself is based on a new member function `MemRefType::getMaxCollapsableTrailingDims` that return the maximum number of trailing dimensions that can be collapsed - trivially all dimensions for memrefs with identity layout, or by examining the memref strides stopping at discontguous or statically unknown strides.
`computeStrides` does not acccess the first element of `sizes`
- rename `getMaxCollapsabelTrailingDims` to `getMaxContiguousTrailingDims` - new set of examples - remove redundant call to `isIdentify()` - make sure a variable type is visible on the declaration line - some micro-optimisation
d3458a6
to
20b495d
Compare
Just a drive-by comment - when I started contributing here, a reviewer kindly pointed out to me in this #135096 (comment) that force pushing isn't recommended https://llvm.org/docs/GitHub.html#rebasing-pull-requests-and-force-pushes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG % some suggestion re structuring the unit tests.
I suggest leaving ~24hrs for Alex to comment whether he minds you leaving the unit tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few high-level comments on test organization and naming:
maxContigDim
testsgetNumContiguousTrailingDims
,contigTrailingDim
testsareTrailingDimsContiguous
, andidentityMaps
tests both - is there an intended pattern here?- The actual "complex" logic seems to reside in
getNumContiguousTrailingDims
. Would it make sense to focus the tests more narrowly on that hook? - Could the test names (
maxContigDim
,contigTrailingDim
,identityMaps
) be made more descriptive? Alternatively, adding comments explaining what each test is validating would help - it’s not immediately clear to me. - I’d suggest renaming
identityMaps
to something likenoStrides
ordefaultStrides
, since the other tests explicitly include strides. The natural split seems to be "with strides" vs "without strides".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few high-level comments on test organization and naming:
maxContigDim
testsgetNumContiguousTrailingDims
,contigTrailingDim
testsareTrailingDimsContiguous
, andidentityMaps
tests both - is there an intended pattern here?
maxContigDim
is a remnant from when the member function was called getMaxContiguousTrailingDims
.
I'll rename it.
'contigTrailingDimstests
areTrailingDimsContiguous`, so it's named after the thing it tests.
identityMaps
tests the fastpath when the memref has identify maps
...
int64_t MemRefType::getNumContiguousTrailingDims() {
const int64_t n = getRank();
// memrefs with identity layout are entirely contiguous.
if (getLayout().isIdentity())
return n;
...
So, yeah, there's a pattern of naming the tests after the thing they test.
- The actual "complex" logic seems to reside in
getNumContiguousTrailingDims
. Would it make sense to focus the tests more narrowly on that hook?
I find the tests focused enough.
- Could the test names (
maxContigDim
,contigTrailingDim
,identityMaps
) be made more descriptive? Alternatively, adding comments explaining what each test is validating would help - it’s not immediately clear to me.
What each test is validating is abundantly clear from the EXPECT_
lines.
- I’d suggest renaming
identityMaps
to something likenoStrides
ordefaultStrides
, since the other tests explicitly include strides. The natural split seems to be "with strides" vs "without strides".
As explained above, it tests exactly the case for having identity maps, hence it's appropriately named.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What each test is validating is abundantly clear from the EXPECT_ lines.
I see where you’re coming from - the EXPECT_
lines do show what is being tested, but they don’t necessarily explain why those inputs are interesting or grouped together.
Please document so that the intent is clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maxContigDim
testsgetNumContiguousTrailingDims
,contigTrailingDim
testsareTrailingDimsContiguous
, andidentityMaps
tests both - is there an intended pattern here?
I have either replied or did not understand the question.
- The actual "complex" logic seems to reside in
getNumContiguousTrailingDims
. Would it make sense to focus the tests more narrowly on that hook?
I've reworked the test cases to be more "focused".
- Could the test names (
maxContigDim
,contigTrailingDim
,identityMaps
) be made more descriptive? Alternatively, adding comments explaining what each test is validating would help - it’s not immediately clear to me.
Added comments what each test is validating. Also put comments to each individual test case.
- I’d suggest renaming
identityMaps
to something likenoStrides
ordefaultStrides
, since the other tests explicitly include strides. The natural split seems to be "with strides" vs "without strides".
Test removed and a test case moved.
// memref<2x2x2xf32, strided<[4,2,1]> | ||
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1})); | ||
EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3); | ||
|
||
// memref<2x2x2xf32, strided<[8,2,1]> | ||
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1})); | ||
EXPECT_EQ(m2.getNumContiguousTrailingDims(), 2); | ||
|
||
// memref<2x2x2xf32, strided<[8,4,1]> | ||
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1})); | ||
EXPECT_EQ(m3.getNumContiguousTrailingDims(), 1); | ||
|
||
// memref<2x2x2xf32, strided<[8,4,2]> | ||
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2})); | ||
EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a nice variation of the initial example, memref<2x2x2xf32, strided<[4,2,1]>
, where modifying the strides alters the number of contiguous dims. However, the examples that follow (with dynamic shapes) feel arbitrary. Could you group them somehow, either with block comments or by splitting into multiple test functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are not arbitrary, they are systematically created. Splitting into multiple test functions sounds arbitrary and unnecessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
they are systematically created
The “system” you used isn’t obvious to me :)
Could you add a brief comment to document it? That would really help both current and future contributors. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The system was described in the (now outdated) commit 145b055
// The vector could be a non-contiguous slice of the input | ||
// memref. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain why "could"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you objecting about the word "could"? Perhaps you prefer "is" ?
|
||
// ----- | ||
|
||
// Can flatten the righmost dynamic dimension |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why? Because the corresponding vector dim is 1? Could you expand the comment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As everywhere, can flatten because read/written area is contiguous. This is the whole premise of the transformation and the test files, I didn't think it needed more explanation.
Such an explanation of program logic (or in this case, a theorem "from these and these properties it follows the area is contiguous") is in appropriate in a test file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some more context, the comment and the test were originally
// One of the dims to be flattened is dynamic - not supported ATM.
func.func @negative_transfer_read_dynamic_dim_to_flatten(
...
With this patch the test is no longer negative and the comment reflects that
at the same level of detail as the original comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Such an explanation of program logic (or in this case, a theorem "from these and these properties it follows the area is contiguous") is in appropriate in a test file.
We do document both what and why (as we see appropriate) in these tests, as covered by MLIR's TestingGuide:
We may disagree on whether this case requires documentation (I believe that a comment would be beneficial), but in such case, I recommend removing the comment altogether, it adds no new info (it repeats what test function documents).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why? Because the corresponding vector dim is 1? Could you expand the comment?
Added answer "why"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me chime in briefly.
We usually don't write such unit tests when there exists a way to check functionality via a pass.
Just to clarify: there isn’t a way to test getNumContiguousTrailingDims
in isolation today. That said, we do have transformations that use it.
Question 1: Is it valuable to test getNumContiguousTrailingDims
in isolation?
I think so - it checks a fairly fundamental property, and the test effectively documents what qualifies as a “contiguous memref slice.”
Question 2: How do we test it?
There are two main options: unit tests or a test pass (possibly with a TD op):
- Unit tests involve less boilerplate but increase compile times.
- A test pass avoids compile-time overhead but requires more setup. Runtime cost is unclear - though honestly, we’re hand-waving a bit here without measurements.
Personally, I don’t have a strong preference. I’ve tried to lay out the trade-offs, and I think either direction is valid. In cases like this, I’d suggest just sticking with the approach that’s already implemented.
@ftynse, I read your comment more as a suggestion than a hard requirement - please let us know if we misread that. (As I said, I’m happy with either option.)
+1 That said, that document also states:
I do occasionally rebase to keep my branch up-to-date with main, especially for longer-lived PRs that are still under review. I'm always happy to adapt that habit, of course! For me, the real concern is around “rebase + squash”, since that can wipe out intermediate commits that were already reviewed or discussed - and that history can be valuable. Just posting this as a side note to clarify my current workflow - again, very open to adjusting it if needed. |
Sometime a rebase is not a matter of preference, but unavoidable, for example if you have stacked PRs and add a fixup to a dependent branch (not that case here). |
The latest updates don’t address all of my concerns - I’ve “unresolved” the corresponding threads to keep the discussion visible. Broadly, I agree with Alex’s point that we should test functionality like this via transformations rather than unit tests. From the LLVM Testing Guide:
That said, we should be mindful that this takes a somewhat non-canonical approach to testing in MLIR. With that in mind, I’d suggest trying to follow existing conventions as closely as possible:
We don’t have formal documentation for MLIR unit test conventions, so this is based on experience and precedent. I'm happy for other reviewers to suggest a different direction for this PR. From my side, the priority is to unblock this while:
Thank you all! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the updates, the latest revision addresses my comment, LGTM!
From what I can tell, all comments have been addressed, but please wait till tomorrow before merging. Just in case there's more comments (also being mindful of the time difference between contributors).
Thanks for seeing this through!
…lvm#142421) This patch enhances `MemRefType::areTrailingDimsContiguous` to also handle memrefs with dynamic dimensions. The implementation itself is based on a new member function `MemRefType::getMaxCollapsableTrailingDims` that return the maximum number of trailing dimensions that can be collapsed - trivially all dimensions for memrefs with identity layout, or by examining the memref strides stopping at discontiguous or statically unknown strides.
…lvm#142421) This patch enhances `MemRefType::areTrailingDimsContiguous` to also handle memrefs with dynamic dimensions. The implementation itself is based on a new member function `MemRefType::getMaxCollapsableTrailingDims` that return the maximum number of trailing dimensions that can be collapsed - trivially all dimensions for memrefs with identity layout, or by examining the memref strides stopping at discontiguous or statically unknown strides.
This patch enhances
MemRefType::areTrailingDimsContiguous
to also handle memrefs with dynamic dimensions.The implementation itself is based on a new member function
MemRefType::getMaxCollapsableTrailingDims
that return the maximum number of trailing dimensions that can be collapsed - trivially all dimensions for memrefs with identity layout, or by examining the memref strides stopping at discontiguous or statically unknown strides.(see also #142422)