Skip to content

[mlir] Fix crash when verifying linalg.transpose #131733

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 18, 2025

Conversation

IanWood1
Copy link
Contributor

@IanWood1 IanWood1 commented Mar 18, 2025

Adds checks in isPermutationVector for indices that are outside of the bounds and removes the assert.

@llvmbot
Copy link
Member

llvmbot commented Mar 18, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Ian Wood (IanWood1)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/131733.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+2-2)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+22)
  • (modified) mlir/test/Dialect/Linalg/transform-op-pack.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 108839a4d90e9..d9edabef6693d 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -220,10 +220,10 @@ bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
 }
 
 bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
-  assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
-         "permutation must be non-negative");
   llvm::SmallDenseSet<int64_t, 4> seenVals;
   for (auto val : interchange) {
+    if (val < 0 || static_cast<uint64_t>(val) >= interchange.size())
+      return false;
     if (seenVals.count(val))
       return false;
     seenVals.insert(val);
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index f2283db8f89b2..90ceadebbc1fa 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -949,6 +949,28 @@ func.func @transpose_invalid_permutation(%input: tensor<16x32x64xf32>,
 
 // -----
 
+func.func @transpose_out_of_range_permutation(%input: tensor<16x32x64xf32>,
+    %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  // expected-error @+1 {{'linalg.transpose' op permutation is not valid}}
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<32x64x16xf32>)
+      permutation = [1, 2, 3]
+  func.return %transpose : tensor<32x64x16xf32>
+}
+
+// -----
+
+func.func @transpose_negative_permutation(%input: tensor<16x32x64xf32>,
+    %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  // expected-error @+1 {{'linalg.transpose' op permutation is not valid}}
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<32x64x16xf32>)
+      permutation = [1, 2, -1]
+  func.return %transpose : tensor<32x64x16xf32>
+}
+// -----
 func.func @transpose_permutated_dims_mismatch(%input: tensor<16x32x64xf32>,
     %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
   // expected-error @+1 {{'linalg.transpose' op dim(result, 0) = 32 doesn't match dim(input, permutation[0]) = 16}}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
index b3ad73e8df8e7..620a21896b0c1 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
@@ -595,7 +595,7 @@ module attributes {transform.with_named_sequence} {
       %unpack = transform.get_consumers_of_result %1[0]
         : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.unpack">)
       %2, %pack_2, %unpack_2 =
-        // expected-error @below {{invalid outer_perm}}
+        // expected-error @below {{"outer_perm" is not a valid permutation}}
         transform.structured.pack_transpose %unpack with_compute_op(%1)
         outer_perm = [1]
         : (!transform.op<"linalg.unpack">, !transform.op<"linalg.generic">)
@@ -623,7 +623,7 @@ module attributes {transform.with_named_sequence} {
       %unpack = transform.get_consumers_of_result %1[0]
         : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.unpack">)
       %2, %pack_2, %unpack_2 =
-        // expected-error @below {{invalid inner_perm}}
+        // expected-error @below {{"inner_perm" is not a valid permutation}}
         transform.structured.pack_transpose %unpack with_compute_op(%1)
         inner_perm = [1]
         : (!transform.op<"linalg.unpack">, !transform.op<"linalg.generic">)

@llvmbot
Copy link
Member

llvmbot commented Mar 18, 2025

@llvm/pr-subscribers-mlir

Author: Ian Wood (IanWood1)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/131733.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+2-2)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+22)
  • (modified) mlir/test/Dialect/Linalg/transform-op-pack.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 108839a4d90e9..d9edabef6693d 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -220,10 +220,10 @@ bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
 }
 
 bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
-  assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
-         "permutation must be non-negative");
   llvm::SmallDenseSet<int64_t, 4> seenVals;
   for (auto val : interchange) {
+    if (val < 0 || static_cast<uint64_t>(val) >= interchange.size())
+      return false;
     if (seenVals.count(val))
       return false;
     seenVals.insert(val);
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index f2283db8f89b2..90ceadebbc1fa 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -949,6 +949,28 @@ func.func @transpose_invalid_permutation(%input: tensor<16x32x64xf32>,
 
 // -----
 
+func.func @transpose_out_of_range_permutation(%input: tensor<16x32x64xf32>,
+    %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  // expected-error @+1 {{'linalg.transpose' op permutation is not valid}}
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<32x64x16xf32>)
+      permutation = [1, 2, 3]
+  func.return %transpose : tensor<32x64x16xf32>
+}
+
+// -----
+
+func.func @transpose_negative_permutation(%input: tensor<16x32x64xf32>,
+    %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  // expected-error @+1 {{'linalg.transpose' op permutation is not valid}}
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<32x64x16xf32>)
+      permutation = [1, 2, -1]
+  func.return %transpose : tensor<32x64x16xf32>
+}
+// -----
 func.func @transpose_permutated_dims_mismatch(%input: tensor<16x32x64xf32>,
     %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
   // expected-error @+1 {{'linalg.transpose' op dim(result, 0) = 32 doesn't match dim(input, permutation[0]) = 16}}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
index b3ad73e8df8e7..620a21896b0c1 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
@@ -595,7 +595,7 @@ module attributes {transform.with_named_sequence} {
       %unpack = transform.get_consumers_of_result %1[0]
         : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.unpack">)
       %2, %pack_2, %unpack_2 =
-        // expected-error @below {{invalid outer_perm}}
+        // expected-error @below {{"outer_perm" is not a valid permutation}}
         transform.structured.pack_transpose %unpack with_compute_op(%1)
         outer_perm = [1]
         : (!transform.op<"linalg.unpack">, !transform.op<"linalg.generic">)
@@ -623,7 +623,7 @@ module attributes {transform.with_named_sequence} {
       %unpack = transform.get_consumers_of_result %1[0]
         : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.unpack">)
       %2, %pack_2, %unpack_2 =
-        // expected-error @below {{invalid inner_perm}}
+        // expected-error @below {{"inner_perm" is not a valid permutation}}
         transform.structured.pack_transpose %unpack with_compute_op(%1)
         inner_perm = [1]
         : (!transform.op<"linalg.unpack">, !transform.op<"linalg.generic">)

@IanWood1 IanWood1 merged commit fbbb33f into llvm:main Mar 18, 2025
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants