Skip to content

Commit a8b01cd

Browse files
[mlir][arith] Check for valid IR in BitcastOp::fold.
This PR prevents `BitcastOp::fold` to create invalid `IntegerAttr` and `FloatAttr` values, which result in failed assertions. This can happen if the input IR is invalid. The PR adds tests for whether the to-be-created attribute verifies and returns early from `fold`if it doesn't. Signed-off-by: Ingo Müller <[email protected]>
1 parent 35dfe80 commit a8b01cd

File tree

4 files changed

+72
-3
lines changed

4 files changed

+72
-3
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,9 +1740,24 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
17401740
? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
17411741
: llvm::cast<IntegerAttr>(operand).getValue();
17421742

1743-
if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1744-
return FloatAttr::get(resType,
1745-
APFloat(resFloatType.getFloatSemantics(), bits));
1743+
/// If bitwidth aren't the same, don't fold.
1744+
if (resType.getIntOrFloatBitWidth() != bits.getBitWidth())
1745+
return {};
1746+
1747+
MLIRContext *ctx = getContext();
1748+
auto emitErrorFn = [=] { return ::emitError(UnknownLoc::get(ctx)); };
1749+
1750+
if (auto resFloatType = llvm::dyn_cast<FloatType>(resType)) {
1751+
/// If bits don't represent a valid float, don't fold.
1752+
APFloat floatBits(resFloatType.getFloatSemantics(), bits);
1753+
if (failed(FloatAttr::verify(emitErrorFn, resType, floatBits)))
1754+
return {};
1755+
return FloatAttr::get(resType, floatBits);
1756+
}
1757+
1758+
/// If bits don't represent a valid integer, don't fold.
1759+
if (failed(IntegerAttr::verify(emitErrorFn, resType, bits)))
1760+
return {};
17461761
return IntegerAttr::get(resType, bits);
17471762
}
17481763

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//===- ArithOpsFoldersTest.cpp - unit tests for arith op folders ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Arith/IR/Arith.h"
10+
#include "mlir/IR/BuiltinOps.h"
11+
#include "mlir/IR/Verifier.h"
12+
#include "gtest/gtest.h"
13+
14+
using namespace mlir;
15+
16+
namespace {
17+
// Tests a regression that made `BitcastOp::fold` crash on invalid input IR, see
18+
// #100743,
19+
TEST(BitcastOpTest, FoldInteger) {
20+
MLIRContext context;
21+
context.loadDialect<arith::ArithDialect>();
22+
auto loc = UnknownLoc::get(&context);
23+
auto module = ModuleOp::create(loc);
24+
OpBuilder builder(module.getBodyRegion());
25+
Value i32Val = builder.create<arith::ConstantOp>(
26+
loc, builder.getI32Type(), builder.getI32IntegerAttr(0));
27+
// This would create an invalid op: `bitcast` can't cast different bitwidths.
28+
builder.createOrFold<arith::BitcastOp>(loc, builder.getI64Type(), i32Val);
29+
ASSERT_TRUE(failed(verify(module)));
30+
}
31+
32+
// Tests a regression that made `BitcastOp::fold` crash on invalid input IR, see
33+
// #100743,
34+
TEST(BitcastOpTest, FoldFloat) {
35+
MLIRContext context;
36+
context.loadDialect<arith::ArithDialect>();
37+
auto loc = UnknownLoc::get(&context);
38+
auto module = ModuleOp::create(loc);
39+
OpBuilder builder(module.getBodyRegion());
40+
Value f32Val = builder.create<arith::ConstantOp>(loc, builder.getF32Type(),
41+
builder.getF32FloatAttr(0));
42+
// This would create an invalid op: `bitcast` can't cast different bitwidths.
43+
builder.createOrFold<arith::BitcastOp>(loc, builder.getF64Type(), f32Val);
44+
ASSERT_TRUE(failed(verify(module)));
45+
}
46+
} // namespace
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
add_mlir_unittest(MLIRArithOpsTests
2+
ArithOpsFoldersTest.cpp
3+
)
4+
target_link_libraries(MLIRArithOpsTests
5+
PRIVATE
6+
MLIRArithDialect
7+
)

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
66
MLIRIR
77
MLIRDialect)
88

9+
add_subdirectory(Arith)
910
add_subdirectory(ArmSME)
1011
add_subdirectory(Index)
1112
add_subdirectory(LLVMIR)

0 commit comments

Comments
 (0)