Skip to content

Commit dff7359

Browse files
committed
Reapply "[MLIR][TableGen] Error on APInt parameter without custom comparator (llvm#135970)"
This reapplies commit 4bcc414. This reverts commit 450c366.
1 parent e272649 commit dff7359

File tree

7 files changed

+39
-50
lines changed

7 files changed

+39
-50
lines changed

mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,11 @@ def BitVectorAttr : AttrDef<SMTDialect, "BitVector", [
4545
present).
4646
}];
4747

48-
let parameters = (ins "llvm::APInt":$value);
48+
let parameters = (ins APIntParameter<"">:$value);
4949

5050
let hasCustomAssemblyFormat = true;
5151
let genVerifyDecl = true;
5252

53-
// We need to manually define the storage class because the generated one is
54-
// buggy (because the APInt asserts matching bitwidth in the `==` operator and
55-
// the generated storage uses that directly.
56-
// Alternatively: add a type parameter to redundantly store the bitwidth of
57-
// of the attribute type, it it's in the order before the 'value' it will be
58-
// checked before the APInt equality (this is the reason it works for the
59-
// builtin integer attribute), but would be more fragile (and we'd store
60-
// duplicate data).
61-
let genStorageClass = false;
62-
6353
let builders = [
6454
AttrBuilder<(ins "llvm::StringRef":$value)>,
6555
AttrBuilder<(ins "uint64_t":$value, "unsigned":$width)>,

mlir/include/mlir/IR/BuiltinAttributes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
700700
false // A bool, i.e. i1, value.
701701
```
702702
}];
703-
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value);
703+
let parameters = (ins AttributeSelfTypeParameter<"">:$type, APIntParameter<"">:$value);
704704
let builders = [
705705
AttrBuilderWithInferredContext<(ins "Type":$type,
706706
"const APInt &":$value), [{

mlir/include/mlir/TableGen/AttrOrTypeDef.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ class AttrOrTypeParameter {
6868
/// If specified, get the custom allocator code for this parameter.
6969
std::optional<StringRef> getAllocator() const;
7070

71-
/// If specified, get the custom comparator code for this parameter.
71+
/// Return true if user defined comparator is specified.
72+
bool hasCustomComparator() const;
73+
74+
/// Get the custom comparator code for this parameter or fallback to the
75+
/// default.
7276
StringRef getComparator() const;
7377

7478
/// Get the C++ type of this parameter.

mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,42 +21,6 @@ using namespace mlir::smt;
2121
// BitVectorAttr
2222
//===----------------------------------------------------------------------===//
2323

24-
namespace mlir {
25-
namespace smt {
26-
namespace detail {
27-
struct BitVectorAttrStorage : public mlir::AttributeStorage {
28-
using KeyTy = APInt;
29-
BitVectorAttrStorage(APInt value) : value(std::move(value)) {}
30-
31-
KeyTy getAsKey() const { return value; }
32-
33-
// NOTE: the implementation of this operator is the reason we need to define
34-
// the storage manually. The auto-generated version would just do the direct
35-
// equality check of the APInt, but that asserts the bitwidth of both to be
36-
// the same, leading to a crash. This implementation, therefore, checks for
37-
// matching bit-width beforehand.
38-
bool operator==(const KeyTy &key) const {
39-
return (value.getBitWidth() == key.getBitWidth() && value == key);
40-
}
41-
42-
static llvm::hash_code hashKey(const KeyTy &key) {
43-
return llvm::hash_value(key);
44-
}
45-
46-
static BitVectorAttrStorage *
47-
construct(mlir::AttributeStorageAllocator &allocator, KeyTy &&key) {
48-
return new (allocator.allocate<BitVectorAttrStorage>())
49-
BitVectorAttrStorage(std::move(key));
50-
}
51-
52-
APInt value;
53-
};
54-
} // namespace detail
55-
} // namespace smt
56-
} // namespace mlir
57-
58-
APInt BitVectorAttr::getValue() const { return getImpl()->value; }
59-
6024
LogicalResult BitVectorAttr::verify(
6125
function_ref<InFlightDiagnostic()> emitError,
6226
APInt value) { // NOLINT(performance-unnecessary-value-param)

mlir/lib/TableGen/AttrOrTypeDef.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@ std::optional<StringRef> AttrOrTypeParameter::getAllocator() const {
278278
return getDefValue<StringInit>("allocator");
279279
}
280280

281+
bool AttrOrTypeParameter::hasCustomComparator() const {
282+
return getDefValue<StringInit>("comparator").has_value();
283+
}
284+
281285
StringRef AttrOrTypeParameter::getComparator() const {
282286
return getDefValue<StringInit>("comparator").value_or("$_lhs == $_rhs");
283287
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: not mlir-tblgen -gen-attrdef-decls -I %S/../../include %s 2>&1 | FileCheck %s
2+
3+
include "mlir/IR/AttrTypeBase.td"
4+
include "mlir/IR/OpBase.td"
5+
6+
def Test_Dialect: Dialect {
7+
let name = "TestDialect";
8+
let cppNamespace = "::test";
9+
}
10+
11+
def RawAPIntAttr : AttrDef<Test_Dialect, "RawAPInt"> {
12+
let mnemonic = "raw_ap_int";
13+
let parameters = (ins "APInt":$value);
14+
let hasCustomAssemblyFormat = 1;
15+
}
16+
17+
// CHECK: apint-param-error.td:11:5: error: Using a raw APInt parameter

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,8 +678,18 @@ void DefGen::emitStorageClass() {
678678
emitConstruct();
679679
// Emit the storage class members as public, at the very end of the struct.
680680
storageCls->finalize();
681-
for (auto &param : params)
681+
for (auto &param : params) {
682+
if (param.getCppType().contains("APInt") && !param.hasCustomComparator()) {
683+
PrintFatalError(
684+
def.getLoc(),
685+
"Using a raw APInt parameter without a custom comparator is "
686+
"not supported because an assert in the equality operator is "
687+
"triggered when the two APInts have different bit widths. This can "
688+
"lead to unexpected crashes. Use an `APIntParameter` or "
689+
"provide a custom comparator.");
690+
}
682691
storageCls->declare<Field>(param.getCppType(), param.getName());
692+
}
683693
}
684694

685695
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)