Skip to content

Commit d0541b4

Browse files
author
Jeff Niu
committed
[mlir] Add I1 support to DenseArrayAttr
This patch adds a DenseI1ArrayAttr to support arrays of i1. Importantly, the implementation is as a simple `ArrayRef<bool>` instead of using bit compression, which was problematic in DenseElementsAttr. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D130957
1 parent 74940d2 commit d0541b4

File tree

11 files changed

+66
-33
lines changed

11 files changed

+66
-33
lines changed

mlir/include/mlir/IR/BuiltinAttributes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,8 +791,11 @@ class DenseArrayAttr : public DenseArrayBaseAttr {
791791
static bool classof(Attribute attr);
792792
};
793793
template <>
794+
void DenseArrayAttr<bool>::printWithoutBraces(raw_ostream &os) const;
795+
template <>
794796
void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
795797

798+
extern template class DenseArrayAttr<bool>;
796799
extern template class DenseArrayAttr<int8_t>;
797800
extern template class DenseArrayAttr<int16_t>;
798801
extern template class DenseArrayAttr<int32_t>;
@@ -802,6 +805,7 @@ extern template class DenseArrayAttr<double>;
802805
} // namespace detail
803806

804807
// Public name for all the supported DenseArrayAttr
808+
using DenseBoolArrayAttr = detail::DenseArrayAttr<bool>;
805809
using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
806810
using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
807811
using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;

mlir/include/mlir/IR/BuiltinAttributes.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def Builtin_DenseArrayBase : Builtin_Attr<
180180
ArrayRefParameter<"char">:$elements);
181181
let extraClassDeclaration = [{
182182
// All possible supported element type.
183-
enum class EltType { I8, I16, I32, I64, F32, F64 };
183+
enum class EltType { I1, I8, I16, I32, I64, F32, F64 };
184184

185185
/// Allow implicit conversion to ElementsAttr.
186186
operator ElementsAttr() const {
@@ -189,7 +189,8 @@ def Builtin_DenseArrayBase : Builtin_Attr<
189189

190190
/// ElementsAttr implementation.
191191
using ContiguousIterableTypesT =
192-
std::tuple<int8_t, int16_t, int32_t, int64_t, float, double>;
192+
std::tuple<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
193+
const bool *value_begin_impl(OverloadToken<bool>) const;
193194
const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
194195
const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
195196
const int32_t *value_begin_impl(OverloadToken<int32_t>) const;

mlir/include/mlir/IR/OpBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,7 @@ class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryNam
12821282
let storageType = "::mlir::" # denseAttrName;
12831283
let returnType = "::llvm::ArrayRef<" # cppType # ">";
12841284
}
1285+
def DenseBoolArrayAttr : DenseArrayAttrBase<"DenseBoolArrayAttr", "bool", "i1">;
12851286
def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">;
12861287
def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
12871288
def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">;

mlir/lib/AsmParser/AttributeParser.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,12 @@ Attribute Parser::parseDenseArrayAttr() {
845845

846846
if (auto intType = type.dyn_cast<IntegerType>()) {
847847
switch (type.getIntOrFloatBitWidth()) {
848+
case 1:
849+
if (isEmptyList)
850+
result = DenseBoolArrayAttr::get(parser.getContext(), {});
851+
else
852+
result = DenseBoolArrayAttr::parseWithoutBraces(parser, Type{});
853+
break;
848854
case 8:
849855
if (isEmptyList)
850856
result = DenseI8ArrayAttr::get(parser.getContext(), {});
@@ -870,7 +876,7 @@ Attribute Parser::parseDenseArrayAttr() {
870876
result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
871877
break;
872878
default:
873-
emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
879+
emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type;
874880
return {};
875881
}
876882
} else if (auto floatType = type.dyn_cast<FloatType>()) {

mlir/lib/AsmParser/Parser.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,15 @@ ParseResult Parser::parseToken(Token::Kind expectedToken,
238238

239239
/// Parse an optional integer value from the stream.
240240
OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
241+
// Parse `false` and `true` keywords as 0 and 1 respectively.
242+
if (consumeIf(Token::kw_false)) {
243+
result = false;
244+
return success();
245+
} else if (consumeIf(Token::kw_true)) {
246+
result = true;
247+
return success();
248+
}
249+
241250
Token curToken = getToken();
242251
if (curToken.isNot(Token::integer, Token::minus))
243252
return llvm::None;

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,26 +1860,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
18601860
}
18611861
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
18621862
typeElision = AttrTypeElision::Must;
1863-
switch (denseArrayAttr.getElementType()) {
1864-
case DenseArrayBaseAttr::EltType::I8:
1865-
os << "[:i8";
1866-
break;
1867-
case DenseArrayBaseAttr::EltType::I16:
1868-
os << "[:i16";
1869-
break;
1870-
case DenseArrayBaseAttr::EltType::I32:
1871-
os << "[:i32";
1872-
break;
1873-
case DenseArrayBaseAttr::EltType::I64:
1874-
os << "[:i64";
1875-
break;
1876-
case DenseArrayBaseAttr::EltType::F32:
1877-
os << "[:f32";
1878-
break;
1879-
case DenseArrayBaseAttr::EltType::F64:
1880-
os << "[:f64";
1881-
break;
1882-
}
1863+
os << "[:" << denseArrayAttr.getType().getElementType();
18831864
if (denseArrayAttr.size())
18841865
os << " ";
18851866
denseArrayAttr.printWithoutBraces(os);

mlir/lib/IR/BuiltinAttributes.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,9 @@ DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
732732

733733
ShapedType DenseArrayBaseAttr::getType() const { return getImpl()->type; }
734734

735+
const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const {
736+
return cast<DenseBoolArrayAttr>().asArrayRef().begin();
737+
}
735738
const int8_t *
736739
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
737740
return cast<DenseI8ArrayAttr>().asArrayRef().begin();
@@ -762,6 +765,9 @@ void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
762765

763766
void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
764767
switch (getElementType()) {
768+
case DenseArrayBaseAttr::EltType::I1:
769+
this->cast<DenseBoolArrayAttr>().printWithoutBraces(os);
770+
return;
765771
case DenseArrayBaseAttr::EltType::I8:
766772
this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
767773
return;
@@ -797,15 +803,20 @@ void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
797803

798804
template <typename T>
799805
void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
800-
ArrayRef<T> values{*this};
801-
llvm::interleaveComma(values, os);
806+
llvm::interleaveComma(asArrayRef(), os);
807+
}
808+
809+
/// Specialization for bool to print `true` or `false`.
810+
template <>
811+
void DenseArrayAttr<bool>::printWithoutBraces(raw_ostream &os) const {
812+
llvm::interleaveComma(asArrayRef(), os,
813+
[&](bool v) { os << (v ? "true" : "false"); });
802814
}
803815

804816
/// Specialization for int8_t for forcing printing as number instead of chars.
805817
template <>
806818
void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
807-
ArrayRef<int8_t> values{*this};
808-
llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
819+
llvm::interleaveComma(asArrayRef(), os, [&](int64_t v) { os << v; });
809820
}
810821

811822
template <typename T>
@@ -816,7 +827,7 @@ void DenseArrayAttr<T>::print(raw_ostream &os) const {
816827
}
817828

818829
/// Parse a single element: generic template for int types, specialized for
819-
/// floating points below.
830+
/// floating point and boolean values below.
820831
template <typename T>
821832
static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
822833
return parser.parseInteger(value);
@@ -880,6 +891,14 @@ namespace {
880891
template <typename T>
881892
struct denseArrayAttrEltTypeBuilder;
882893
template <>
894+
struct denseArrayAttrEltTypeBuilder<bool> {
895+
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I1;
896+
static ShapedType getShapedType(MLIRContext *context,
897+
ArrayRef<int64_t> shape) {
898+
return RankedTensorType::get(shape, IntegerType::get(context, 1));
899+
}
900+
};
901+
template <>
883902
struct denseArrayAttrEltTypeBuilder<int8_t> {
884903
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
885904
static ShapedType getShapedType(MLIRContext *context,
@@ -953,6 +972,7 @@ bool DenseArrayAttr<T>::classof(Attribute attr) {
953972
namespace mlir {
954973
namespace detail {
955974
// Explicit instantiation for all the supported DenseArrayAttr.
975+
template class DenseArrayAttr<bool>;
956976
template class DenseArrayAttr<int8_t>;
957977
template class DenseArrayAttr<int16_t>;
958978
template class DenseArrayAttr<int32_t>;

mlir/test/IR/attribute.mlir

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,13 +521,15 @@ func.func @simple_scalar_example() {
521521
//===----------------------------------------------------------------------===//
522522

523523
// CHECK-LABEL: func @dense_array_attr
524-
func.func @dense_array_attr() attributes{
524+
func.func @dense_array_attr() attributes {
525525
// CHECK-SAME: emptyf32attr = [:f32],
526526
emptyf32attr = [:f32],
527527
// CHECK-SAME: emptyf64attr = [:f64],
528528
emptyf64attr = [:f64],
529529
// CHECK-SAME: emptyi16attr = [:i16],
530530
emptyi16attr = [:i16],
531+
// CHECK-SAME: emptyi1attr = [:i1],
532+
emptyi1attr = [:i1],
531533
// CHECK-SAME: emptyi32attr = [:i32],
532534
emptyi32attr = [:i32],
533535
// CHECK-SAME: emptyi64attr = [:i64],
@@ -540,6 +542,8 @@ func.func @dense_array_attr() attributes{
540542
f64attr = [:f64 -142.],
541543
// CHECK-SAME: i16attr = [:i16 3, 5, -4, 10],
542544
i16attr = [:i16 3, 5, -4, 10],
545+
// CHECK-SAME: i1attr = [:i1 true, false, true],
546+
i1attr = [:i1 true, false, true],
543547
// CHECK-SAME: i32attr = [:i32 1024, 453, -6435],
544548
i32attr = [:i32 1024, 453, -6435],
545549
// CHECK-SAME: i64attr = [:i64 -142],
@@ -549,6 +553,8 @@ func.func @dense_array_attr() attributes{
549553
} {
550554
// CHECK: test.dense_array_attr
551555
test.dense_array_attr
556+
// CHECK-SAME: i1attr = [true, false, true]
557+
i1attr = [true, false, true]
552558
// CHECK-SAME: i8attr = [1, -2, 3]
553559
i8attr = [1, -2, 3]
554560
// CHECK-SAME: i16attr = [3, 5, -4, 10]

mlir/test/IR/elements-attr-interface.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
2727
// expected-error@below {{Test iterating `IntegerAttr`: }}
2828
arith.constant dense<> : tensor<0xi64>
2929

30+
// expected-error@below {{Test iterating `bool`: true, false, true, false, true, false}}
31+
arith.constant [:i1 true, false, true, false, true, false]
3032
// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
3133
arith.constant [:i8 10, 11, -12, 13, 14]
3234
// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
272272

273273
def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
274274
let arguments = (ins
275+
DenseBoolArrayAttr:$i1attr,
275276
DenseI8ArrayAttr:$i8attr,
276277
DenseI16ArrayAttr:$i16attr,
277278
DenseI32ArrayAttr:$i32attr,
@@ -281,10 +282,9 @@ def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
281282
DenseI32ArrayAttr:$emptyattr
282283
);
283284
let assemblyFormat = [{
284-
`i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
285-
`i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr
286-
`emptyattr` `=` $emptyattr
287-
attr-dict
285+
`i1attr` `=` $i1attr `i8attr` `=` $i8attr `i16attr` `=` $i16attr
286+
`i32attr` `=` $i32attr `i64attr` `=` $i64attr `f32attr` `=` $f32attr
287+
`f64attr` `=` $f64attr `emptyattr` `=` $emptyattr attr-dict
288288
}];
289289
}
290290

mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ struct TestElementsAttrInterface
4343
if (auto concreteAttr =
4444
attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
4545
switch (concreteAttr.getElementType()) {
46+
case DenseArrayBaseAttr::EltType::I1:
47+
testElementsAttrIteration<bool>(op, elementsAttr, "bool");
48+
break;
4649
case DenseArrayBaseAttr::EltType::I8:
4750
testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
4851
break;

0 commit comments

Comments
 (0)