Skip to content

[mlir] Require folders to produce Values of same type #75887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2023

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Dec 19, 2023

This commit adds extra assertions to OperationFolder and OpBuilder to ensure that the types of the folded SSA values match with the result types of the op. There used to be checks that discard the folded results if the types do not match. This commit makes these checks stricter and turns them into assertions.

Discarding folded results with the wrong type (without failing explicitly) can hide bugs in op folders. Two such bugs became apparent in MLIR (and some more in downstream projects) and are fixed with this change.

Note: The existing type checks were introduced in https://reviews.llvm.org/D95991.

Migration guide: If you see failing assertions (folder produced value of incorrect type; make sure to run with assertions enabled!), run with -debug or dump the operation right before the failing assertion. This will point you to the op that has the broken folder. A common mistake is a mismatch between static/dynamic dimensions (e.g., input has a static dimension but folded result has a dynamic dimension).

@llvmbot
Copy link
Member

llvmbot commented Dec 19, 2023

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds extra assertions to OperationFolder and OpBuilder to ensure that the types of the folded SSA values match with the result types of the op. There used to be checks that discard the folded results if the types do not match. This commit makes these checks stricter and turns them into assertions.

Discarding folded results with the wrong type (without failing explicitly) can hide bugs in op folders. Two such bugs became apparent in MLIR (and some more in downstream projects) and are fixed with this change.

Note: The existing type checks were introduced in https://reviews.llvm.org/D95991.


Full diff: https://github.com/llvm/llvm-project/pull/75887.diff

8 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+3-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-4)
  • (modified) mlir/lib/IR/Builders.cpp (+2-3)
  • (modified) mlir/lib/Transforms/Utils/FoldUtils.cpp (+2-4)
  • (modified) mlir/test/Transforms/test-canonicalize.mlir (-13)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (-10)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (-4)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (-7)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 56d5e0fed76185..ff72becc8dfa77 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
       setOperand(src);
       return getResult();
     }
+
     // trunci(zexti(a)) -> a
     // trunci(sexti(a)) -> a
-    return src;
+    if (srcType == dstType)
+      return src;
   }
 
   // trunci(trunci(a)) -> trunci(a))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..ac9485326a32ed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1600,11 +1600,8 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                        : 0;
   };
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0)
-    return source;
 
+  unsigned broadcastSrcRank = getRank(source.getType());
   unsigned extractResultRank = getRank(extractOp.getType());
   if (extractResultRank >= broadcastSrcRank)
     return Value();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2cabfcd24d3559..c28cbe109c3ffd 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -491,9 +491,8 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
     // Normal values get pushed back directly.
     if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
-      if (value.getType() != expectedType)
-        return cleanupFailure();
-
+      assert(value.getType() == expectedType &&
+             "folder produced value of incorrect type");
       results.push_back(value);
       continue;
     }
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 90ee5ba51de3ad..34b7117a035748 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -247,10 +247,8 @@ OperationFolder::processFoldResults(Operation *op,
 
     // Check if the result was an SSA value.
     if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
-      if (repl.getType() != op->getResult(i).getType()) {
-        results.clear();
-        return failure();
-      }
+      assert(repl.getType() == op->getResult(i).getType() &&
+             "folder produced value of incorrect type");
       results.emplace_back(repl);
       continue;
     }
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index bc463fefe65342..4f0095ed7e8cf4 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -70,19 +70,6 @@ func.func @test_commutative_multi_cst(%arg0: i32, %arg1: i32) -> (i32, i32) {
   return %y, %z: i32, i32
 }
 
-// CHECK-LABEL: func @typemismatch
-
-func.func @typemismatch() -> i32 {
-  %c42 = arith.constant 42.0 : f32
-
-  // The "passthrough_fold" folder will naively return its operand, but we don't
-  // want to fold here because of the type mismatch.
-
-  // CHECK: "test.passthrough_fold"
-  %0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
-  return %0 : i32
-}
-
 // CHECK-LABEL: test_dialect_canonicalizer
 func.func @test_dialect_canonicalizer() -> (i32) {
   %0 = "test.dialect_canonicalizable"() : () -> (i32)
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 6897b6f95f0d05..d8cf6e4719cede 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -310,16 +310,6 @@ builtin.module {
 
 // -----
 
-// The "passthrough_fold" folder will naively return its operand, but we don't
-// want to fold here because of the type mismatch.
-func.func @typemismatch(%arg: f32) -> i32 {
-  // expected-remark@+1 {{op 'test.passthrough_fold' is not legalizable}}
-  %0 = "test.passthrough_fold"(%arg) : (f32) -> (i32)
-  "test.return"(%0) : (i32) -> ()
-}
-
-// -----
-
 // expected-remark @below {{applyPartialConversion failed}}
 module {
   func.func private @callee(%0 : f32) -> f32
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 21400a60e65321..a1b30705f16a98 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -542,10 +542,6 @@ OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) {
-  return getOperand();
-}
-
 OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
   int64_t sum = 0;
   if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 96f66c2ca06ecf..70ccc71883e3c1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1363,13 +1363,6 @@ def TestOpFoldWithFoldAdaptor
   let hasFolder = 1;
 }
 
-// An op that always fold itself.
-def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
-  let arguments = (ins AnyType:$op);
-  let results = (outs AnyType);
-  let hasFolder = 1;
-}
-
 def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
   let arguments = (ins);
   let results = (outs I32);

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix in flang, LGTM

This commit adds extra assertions to `OperationFolder` and `OpBuilder` to ensure that the types of the folded SSA values match with the result types of the op. There used to be checks that discard the folded results if the types do not match. This commit makes these checks stricter and turns them into assertions.

Discarding folded results with the wrong type can hide bugs in op folders. Two such bugs became apparent and are fixed with this change.

Note: The existing type checks were introduced in https://reviews.llvm.org/D95991.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
@matthias-springer matthias-springer merged commit f10302e into llvm:main Dec 20, 2023
AaronStGeorge pushed a commit to AaronStGeorge/torch-mlir that referenced this pull request Jan 28, 2024
renxida pushed a commit to llvm/torch-mlir that referenced this pull request Jan 30, 2024
We were seeing some assertion failures after some checks around folders
were tightened up in LLVM:
llvm/llvm-project#75887 . This PR essentially
moves the logic that used to be applied at the LLVM level into the
folder, which seems to be the suggested fix.

I'm not sure if the IR that caused issues for us _should_ be valid?
```
%1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
```
A better fix might be to create a verifier ensuring the result of
`aten.detach` has the same type as its operand.

---------

Co-authored-by: aaron-stgeorge <[email protected]>
qingyunqu pushed a commit to llvm/torch-mlir that referenced this pull request May 15, 2024
Similar to #2824, we were seeing
some assertion failures after the addition checks around folders were
tightened up in LLVM: llvm/llvm-project#75887 .
This PR essentially moves the logic that used to be applied at the LLVM
level into the folder, which seems to be the suggested fix.
BaneTrifa pushed a commit to BaneTrifa/torch-mlir that referenced this pull request May 24, 2024
Similar to llvm#2824, we were seeing
some assertion failures after the addition checks around folders were
tightened up in LLVM: llvm/llvm-project#75887 .
This PR essentially moves the logic that used to be applied at the LLVM
level into the folder, which seems to be the suggested fix.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:arith mlir:core MLIR Core Infrastructure mlir:tosa mlir:vector mlir:vectorops mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants