Skip to content

[MLIR][Linalg] Bail out if the tiles provided are more than the number #66007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 13, 2023

Conversation

chelini
Copy link
Contributor

@chelini chelini commented Sep 11, 2023

Currently, the compiler crashes if the number of tiles provided exceeds the number of loops.

@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2023

@llvm/pr-subscribers-mlir

Changes

of loops

Currently, the compiler crashes if the number of tiles provided exceeds the number of loops.

Full diff: https://github.com/llvm/llvm-project/pull/66007.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+6)
  • (modified) mlir/test/Dialect/Linalg/transform-op-tile.mlir (+17)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1b2283c054c7d34..6539641030f905b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2533,6 +2533,12 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
       diag.attachNote(op->getLoc()) << "target op";
       return diag;
     }
+    if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
+      DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                         << "too many tiles for";
+      diag.attachNote(op->getLoc()) << "target op";
+      return diag;
+    }
 
     scf::SCFTilingOptions tilingOptions;
     if (!tileSizes.empty()) {
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index d4629dcb29c3efc..1ed2ff3732fc324 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -220,3 +220,20 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 }
+
+// -----
+
+func.func @matmul(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+                    %arg2: tensor<128x128xf32>) ->  tensor<128x128xf32> {
+  // expected-note @below {{target op}}
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{too many tiles for}}
+  %1, %loops = transform.structured.tile %0 [1, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+}

@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2023

@llvm/pr-subscribers-mlir-linalg

Changes

of loops

Currently, the compiler crashes if the number of tiles provided exceeds the number of loops.

Full diff: https://github.com/llvm/llvm-project/pull/66007.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+6)
  • (modified) mlir/test/Dialect/Linalg/transform-op-tile.mlir (+17)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1b2283c054c7d34..6539641030f905b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2533,6 +2533,12 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
       diag.attachNote(op->getLoc()) << "target op";
       return diag;
     }
+    if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
+      DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                         << "too many tiles for";
+      diag.attachNote(op->getLoc()) << "target op";
+      return diag;
+    }
 
     scf::SCFTilingOptions tilingOptions;
     if (!tileSizes.empty()) {
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index d4629dcb29c3efc..1ed2ff3732fc324 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -220,3 +220,20 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 }
+
+// -----
+
+func.func @matmul(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+                    %arg2: tensor<128x128xf32>) ->  tensor<128x128xf32> {
+  // expected-note @below {{target op}}
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{too many tiles for}}
+  %1, %loops = transform.structured.tile %0 [1, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+}


// -----

func.func @matmul(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using the function name reflecting the purpose of this test case?

Suggested change
func.func @matmul(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
func.func @too_many_tile(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,

@@ -2533,6 +2533,12 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
diag.attachNote(op->getLoc()) << "target op";
return diag;
}
if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
DiagnosedSilenceableFailure diag = emitSilenceableError()
<< "too many tiles for";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rephrase this as too many tile sizes provided, expected at most X, found Y.

of loops

Currently, the compiler crashes if the number of tiles provided exceeds
the number of loops.
@chelini chelini merged commit d65885a into llvm:main Sep 13, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
llvm#66007)

Currently, the compiler crashes if the number of tiles provided exceeds
the number of loops.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants