Skip to content

[mlir][arith] Assert preconditions in BitcastOp::fold. #100743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 29, 2024

Conversation

ingomueller-net
Copy link
Contributor

@ingomueller-net ingomueller-net commented Jul 26, 2024

This PR adds an assertion to BitcastOp::fold that fails if that function is called on invalid IR. That can happen when patterns, passes, etc. create (invalid) IR using builders and folding is triggered on that IR before verification, for example, through OpBuilder::createOrFold. The new assert triggers earlier than previously in order to help getting to the root cause faster.

Original description (obsolete): This PR prevents BitcastOp::fold to create invalid IntegerAttr and FloatAttr values, which result in failed assertions. This can happen if the input IR is invalid. The PR adds tests for whether the to-be-created attribute verifies and returns early from foldif it doesn't.

@llvmbot
Copy link
Member

llvmbot commented Jul 26, 2024

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Ingo Müller (ingomueller-net)

Changes

This PR prevents BitcastOp::fold to create invalid IntegerAttr and FloatAttr values, which result in failed assertions. This can happen if the input IR is invalid. The PR adds tests for whether the to-be-created attribute verifies and returns early from foldif it doesn't.


Full diff: https://github.com/llvm/llvm-project/pull/100743.diff

3 Files Affected:

  • (added) mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp (+42)
  • (added) mlir/unittests/Dialect/Arith/CMakeLists.txt (+7)
  • (modified) mlir/unittests/Dialect/CMakeLists.txt (+1)
diff --git a/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp b/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
new file mode 100644
index 0000000000000..2debc7706ce48
--- /dev/null
+++ b/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
@@ -0,0 +1,42 @@
+//===- ArithOpsFoldersTest.cpp - unit tests for arith op folders ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+TEST(BitcastOpTest, FoldInteger) {
+  MLIRContext context;
+  context.loadDialect<arith::ArithDialect>();
+  auto loc = UnknownLoc::get(&context);
+  auto module = ModuleOp::create(loc);
+  OpBuilder builder(module.getBodyRegion());
+  Value i32Val = builder.create<arith::ConstantOp>(
+      loc, builder.getI32Type(), builder.getI32IntegerAttr(0));
+  // This would create an invalid op: `bitcast` can't cast different bitwidths.
+  builder.createOrFold<arith::BitcastOp>(loc, builder.getI64Type(), i32Val);
+  ASSERT_TRUE(failed(verify(module)));
+}
+
+TEST(BitcastOpTest, FoldFloat) {
+  MLIRContext context;
+  context.loadDialect<arith::ArithDialect>();
+  auto loc = UnknownLoc::get(&context);
+  auto module = ModuleOp::create(loc);
+  OpBuilder builder(module.getBodyRegion());
+  Value f32Val = builder.create<arith::ConstantOp>(loc, builder.getF32Type(),
+                                                   builder.getF32FloatAttr(0));
+  // This would create an invalid op: `bitcast` can't cast different bitwidths.
+  builder.createOrFold<arith::BitcastOp>(loc, builder.getF64Type(), f32Val);
+  ASSERT_TRUE(failed(verify(module)));
+}
+} // namespace
diff --git a/mlir/unittests/Dialect/Arith/CMakeLists.txt b/mlir/unittests/Dialect/Arith/CMakeLists.txt
new file mode 100644
index 0000000000000..ac6b701529d3f
--- /dev/null
+++ b/mlir/unittests/Dialect/Arith/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRArithOpsTests
+  ArithOpsFoldersTest.cpp
+)
+target_link_libraries(MLIRArithOpsTests
+  PRIVATE
+  MLIRArithDialect
+)
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 90a75d5a46ad9..e88e8c61fcb13 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
   MLIRIR
   MLIRDialect)
 
+add_subdirectory(Arith)
 add_subdirectory(ArmSME)
 add_subdirectory(Index)
 add_subdirectory(LLVMIR)

This PR prevents `BitcastOp::fold` to create invalid `IntegerAttr` and
`FloatAttr` values, which result in failed assertions. This can happen
if the input IR is invalid. The PR adds tests for whether the
to-be-created attribute verifies and returns early from `fold`if it
doesn't.

Signed-off-by: Ingo Müller <[email protected]>
@ingomueller-net ingomueller-net force-pushed the fix-arith-bitcast-assert branch from 52ef358 to a8b01cd Compare July 26, 2024 13:37
@ingomueller-net
Copy link
Contributor Author

I am not actually sure if this is the correct fix. Maybe a folder can always assume valid IR?

When I ran into this problem, it took me quite a long time to figure out what was actually going on, so I think at the very least, the reporting of the problem should be improved. If folders can indeed assume valid IR, maybe the current early-exit conditions should be turned into asserts?

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which context are we folding invalid IR?

builder.createOrFold<arith::BitcastOp>(loc, builder.getF64Type(), f32Val);
ASSERT_TRUE(failed(verify(module)));
}
} // namespace
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't expect C++ unit-tests for this kind of thing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I can't trigger the assert with a list test because there the input IR would be valid.

Copy link
Contributor Author

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which context are we folding invalid IR?

Before the folder is called, the IR looks like this:

%0 = arith.constant 0 : i32
%1 = arith.bitcast %0 : i32 to i64

That's invalid because i32 is not BitcastOp::areCastCompatible to i64.

The folder doesn't verify that but instead attemps to create an IntegerAttr with type i32 and an APInt value of type i64, which breaks an assertion in the call stack IntegerAttr::get.

builder.createOrFold<arith::BitcastOp>(loc, builder.getF64Type(), f32Val);
ASSERT_TRUE(failed(verify(module)));
}
} // namespace
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I can't trigger the assert with a list test because there the input IR would be valid.

@joker-eph
Copy link
Collaborator

In which context are we folding invalid IR?

Before the folder is called, the IR looks like this:

%0 = arith.constant 0 : i32
%1 = arith.bitcast %0 : i32 to i64

That's invalid because i32 is not BitcastOp::areCastCompatible to i64.

The folder doesn't verify that but instead attemps to create an IntegerAttr with type i32 and an APInt value of type i64, which breaks an assertion in the call stack IntegerAttr::get.

It's not clear to me when should the folder be resilient to invalid IR though: this is what I mean about "In which context are we folding invalid IR?" ; this is a general question beyond the specific IR you have at hand: I need to know how you end up in this situation instead of the specific invalid IR that triggers this crash.

@ingomueller-net
Copy link
Contributor Author

It's not clear to me when should the folder be resilient to invalid IR though: this is what I mean about "In which context are we folding invalid IR?" ; this is a general question beyond the specific IR you have at hand: I need to know how you end up in this situation instead of the specific invalid IR that triggers this crash.

Oh, got it! This was triggered by a pattern that is part of a dialect conversion that calls a helper function castTo(OpBuilder&, Type, Value), which inserts "the right cast" using builder.createOrFold<arith::BitcastOp>(...) like in the unittest. The implementation was erroneous: it did not check for bit widths and inserted arith.bitcast even when the types had different bit width.

Does that give you the context you were looking for?

@joker-eph
Copy link
Collaborator

It's not clear to me when should the folder be resilient to invalid IR though: this is what I mean about "In which context are we folding invalid IR?" ; this is a general question beyond the specific IR you have at hand: I need to know how you end up in this situation instead of the specific invalid IR that triggers this crash.

Oh, got it! This was triggered by a pattern that is part of a dialect conversion that calls a helper function castTo(OpBuilder&, Type, Value), which inserts "the right cast" using builder.createOrFold<arith::BitcastOp>(...) like in the unittest. The implementation was erroneous: it did not check for bit widths and inserted arith.bitcast even when the types had different bit width.

Does that give you the context you were looking for?

Yes, thanks!

Why wouldn't we fix the pattern instead?

@ingomueller-net
Copy link
Contributor Author

Why wouldn't we fix the pattern instead?

Yes, absolutely! Even with the current state of the PR, my pattern would create invalid IR, but instead of crashing, verification would fail.

I take it that folders can generally assume that they work on valid IR, right? Then the current behavior of the folder (before this PR) is actually fine.

What we might still be able to improve is to help finding the root cause easier when a pattern is broken, like mine. Then BitcastOp::fold could assert that the types are cast compatible, with a message that says that the input IR is broken. Would that be helpful?

@joker-eph
Copy link
Collaborator

I take it that folders can generally assume that they work on valid IR, right?

I believe so. This is why you had to create a c++ unit-tests and none exists yet.

Then BitcastOp::fold could assert that the types are cast compatible, with a message that says that the input IR is broken.

Sure!

@ingomueller-net
Copy link
Contributor Author

OK, done: there is only one assert left, that for the bit width. I tried to construct cases to trigger the other checks but ended up convincing myself that they can't be triggered if the bit width is correct, so having that one test should be enough. This reduces the PR to adding just two lines...

@ingomueller-net ingomueller-net changed the title [mlir][arith] Check for valid IR in BitcastOp::fold. [mlir][arith] Assert preconditions in BitcastOp::fold. Jul 29, 2024
@ingomueller-net ingomueller-net merged commit 77655f4 into llvm:main Jul 29, 2024
7 checks passed
@ingomueller-net ingomueller-net deleted the fix-arith-bitcast-assert branch July 29, 2024 12:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants