Skip to content

Commit 7714b40

Browse files
committed
[mlir] introduce "encoding" attribute to tensor type
This CL introduces a generic attribute (called "encoding") on tensors. The attribute currently does not carry any concrete information, but the type system already correctly determines that tensor<8xi1,123> != tensor<8xi1,321>. The attribute will be given meaning through an interface in subsequent CLs.

See ongoing discussion on discourse:

[RFC] Introduce a sparse tensor type to core MLIR
https://llvm.discourse.group/t/rfc-introduce-a-sparse-tensor-type-to-core-mlir/2944

A sparse tensor will look something like this:

```
// named alias with all properties we hold dear:
#CSR = {
  // individual named attributes
}

// actual sparse tensor type:
tensor<?x?xf64, #CSR>
```

I see the following rough 5 step plan going forward:

(1) introduce this format attribute in this CL, currently still empty
(2) introduce attribute interface that gives it "meaning", focused on sparse in first phase
(3) rewrite sparse compiler to use new type, remove linalg interface and "glue"
(4) teach passes to deal with new attribute, by rejecting/asserting on non-empty attribute as simplest solution, or doing meaningful rewrite in the longer run
(5) add FE support, document, test, publicize new features, extend "format" meaning to other domains if useful

Reviewed By: stellaraccident, bondhugula

Differential Revision: https://reviews.llvm.org/D99548
1 parent 8508a63 commit 7714b40

File tree

13 files changed

+114
-49
lines changed

13 files changed

+114
-49
lines changed

mlir/include/mlir-c/BuiltinAttributes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
extern "C" {
2323
#endif
2424

25+
/// Returns an empty attribute.
26+
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull();
27+
2528
//===----------------------------------------------------------------------===//
2629
// Affine map attribute.
2730
//===----------------------------------------------------------------------===//

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,20 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type);
188188
/// Checks whether the given type is an unranked tensor type.
189189
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type);
190190

191-
/// Creates a tensor type of a fixed rank with the given shape and element type
192-
/// in the same context as the element type. The type is owned by the context.
191+
/// Creates a tensor type of a fixed rank with the given shape, element type,
192+
/// and optional encoding in the same context as the element type. The type is
193+
/// owned by the context. Tensor types without any specific encoding field
194+
/// should assign mlirAttributeGetNull() to this parameter.
193195
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank,
194196
const int64_t *shape,
195-
MlirType elementType);
197+
MlirType elementType,
198+
MlirAttribute encoding);
196199

197200
/// Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on
198201
/// illegal arguments, emitting appropriate diagnostics.
199-
MLIR_CAPI_EXPORTED MlirType
200-
mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
201-
const int64_t *shape, MlirType elementType);
202+
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(
203+
MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType,
204+
MlirAttribute encoding);
202205

203206
/// Creates an unranked tensor type with the given element type in the same
204207
/// context as the element type. The type is owned by the context.

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -636,9 +636,10 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "TensorType"> {
636636
Syntax:
637637

638638
```
639-
tensor-type ::= `tensor` `<` dimension-list type `>`
639+
tensor-type ::= `tensor` `<` dimension-list type (`,` encoding)? `>`
640640
dimension-list ::= (dimension `x`)*
641641
dimension ::= `?` | decimal-literal
642+
encoding ::= attribute-value
642643
```
643644

644645
Values with tensor type represent aggregate N-dimensional data values, and
@@ -654,6 +655,14 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "TensorType"> {
654655
[`dim` operation](Dialects/Standard.md#dim-operation) returns the size of a
655656
dimension from a value of tensor type.
656657

658+
The `encoding` attribute provides additional information on the tensor.
659+
An empty attribute denotes a straightforward tensor without any specific
660+
structure. But particular properties, like sparsity or other specific
661+
characteristics of the data of the tensor, can be encoded through this
662+
attribute. The semantics are defined by a type and attribute interface
663+
and must be respected by all passes that operate on tensor types.
664+
TODO: provide this interface, and document it further.
665+
657666
Note: hexadecimal integer literals are not allowed in tensor type
658667
declarations to avoid confusion between `0xf32` and `0 x f32`. Zero sizes
659668
are allowed in tensors and treated as other sizes, e.g.,
@@ -681,18 +690,24 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "TensorType"> {
681690

682691
// Zero-element tensor of f32 type (hexadecimal literals not allowed here).
683692
tensor<0xf32>
693+
694+
// Tensor with an encoding attribute (where #ENCODING is a named alias).
695+
tensor<?x?xf64, #ENCODING>
684696
```
685697
}];
686698
let parameters = (ins
687699
ArrayRefParameter<"int64_t">:$shape,
688-
"Type":$elementType
700+
"Type":$elementType,
701+
"Attribute":$encoding
689702
);
690703

691704
let builders = [
692705
TypeBuilderWithInferredContext<(ins
693-
"ArrayRef<int64_t>":$shape, "Type":$elementType
706+
"ArrayRef<int64_t>":$shape,
707+
"Type":$elementType,
708+
CArg<"Attribute", "{}">:$encoding
694709
), [{
695-
return $_get(elementType.getContext(), shape, elementType);
710+
return $_get(elementType.getContext(), shape, elementType, encoding);
696711
}]>
697712
];
698713
let skipDefaultBuilders = 1;

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,8 +502,9 @@ class PyDenseElementsAttribute
502502
MlirType mlirElementType, py::buffer_info &arrayInfo) {
503503
SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
504504
arrayInfo.shape.begin() + arrayInfo.ndim);
505-
auto shapedType =
506-
mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
505+
MlirAttribute encodingAttr = mlirAttributeGetNull();
506+
auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
507+
mlirElementType, encodingAttr);
507508
intptr_t numElements = arrayInfo.size;
508509
const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
509510
return ctor(shapedType, numElements, contents);

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "PybindUtils.h"
1212

13+
#include "mlir-c/BuiltinAttributes.h"
1314
#include "mlir-c/BuiltinTypes.h"
1415

1516
namespace py = pybind11;
@@ -381,8 +382,9 @@ class PyRankedTensorType
381382
"get",
382383
[](std::vector<int64_t> shape, PyType &elementType,
383384
DefaultingPyLocation loc) {
385+
MlirAttribute encodingAttr = mlirAttributeGetNull();
384386
MlirType t = mlirRankedTensorTypeGetChecked(
385-
loc, shape.size(), shape.data(), elementType);
387+
loc, shape.size(), shape.data(), elementType, encodingAttr);
386388
// TODO: Rework error reporting once diagnostic engine is exposed
387389
// in C API.
388390
if (mlirTypeIsNull(t)) {

mlir/lib/CAPI/IR/BuiltinAttributes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
using namespace mlir;
1717

18+
MlirAttribute mlirAttributeGetNull() { return {nullptr}; }
19+
1820
//===----------------------------------------------------------------------===//
1921
// Affine map attribute.
2022
//===----------------------------------------------------------------------===//

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,18 +191,19 @@ bool mlirTypeIsAUnrankedTensor(MlirType type) {
191191
}
192192

193193
MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
194-
MlirType elementType) {
194+
MlirType elementType, MlirAttribute encoding) {
195195
return wrap(RankedTensorType::get(
196-
llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
197-
unwrap(elementType)));
196+
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
197+
unwrap(encoding)));
198198
}
199199

200200
MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
201201
const int64_t *shape,
202-
MlirType elementType) {
202+
MlirType elementType,
203+
MlirAttribute encoding) {
203204
return wrap(RankedTensorType::getChecked(
204205
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
205-
unwrap(elementType)));
206+
unwrap(elementType), unwrap(encoding)));
206207
}
207208

208209
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1866,7 +1866,13 @@ void ModulePrinter::printType(Type type) {
18661866
os << dim;
18671867
os << 'x';
18681868
}
1869-
os << tensorTy.getElementType() << '>';
1869+
os << tensorTy.getElementType();
1870+
// Only print the encoding attribute value if set.
1871+
if (tensorTy.getEncoding()) {
1872+
os << ", ";
1873+
printAttribute(tensorTy.getEncoding());
1874+
}
1875+
os << '>';
18701876
})
18711877
.Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
18721878
os << "tensor<*x";

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,10 +441,12 @@ bool TensorType::isValidElementType(Type type) {
441441

442442
LogicalResult
443443
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
444-
ArrayRef<int64_t> shape, Type elementType) {
444+
ArrayRef<int64_t> shape, Type elementType,
445+
Attribute encoding) {
445446
for (int64_t s : shape)
446447
if (s < -1)
447448
return emitError() << "invalid tensor dimension size";
449+
// TODO: verify contents of encoding attribute.
448450
return checkTensorElementType(emitError, elementType);
449451
}
450452

mlir/lib/Parser/TypeParser.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,14 +409,23 @@ Type Parser::parseTensorType() {
409409
// Parse the element type.
410410
auto elementTypeLoc = getToken().getLoc();
411411
auto elementType = parseType();
412+
413+
// Parse an optional encoding attribute.
414+
Attribute encoding;
415+
if (consumeIf(Token::comma))
416+
encoding = parseAttribute();
417+
412418
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
413419
return nullptr;
414420
if (!TensorType::isValidElementType(elementType))
415421
return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
416422

417-
if (isUnranked)
423+
if (isUnranked) {
424+
if (encoding)
425+
return emitError("cannot apply encoding to unranked tensor"), nullptr;
418426
return UnrankedTensorType::get(elementType);
419-
return RankedTensorType::get(dimensions, elementType);
427+
}
428+
return RankedTensorType::get(dimensions, elementType, encoding);
420429
}
421430

422431
/// Parse a tuple type.

mlir/test/CAPI/ir.c

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -438,8 +438,8 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
438438
mlirOperationSetAttributeByName(
439439
operation, mlirStringRefCreateFromCString("elts"),
440440
mlirDenseElementsAttrInt32Get(
441-
mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32)), 4,
442-
eltsData));
441+
mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
442+
mlirAttributeGetNull()), 4, eltsData));
443443
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
444444
mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
445445
mlirOpPrintingFlagsPrintGenericOpForm(flags);
@@ -687,8 +687,8 @@ static int printBuiltinTypes(MlirContext ctx) {
687687
// CHECK: vector<2x3xf32>
688688

689689
// Ranked tensor type.
690-
MlirType rankedTensor =
691-
mlirRankedTensorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
690+
MlirType rankedTensor = mlirRankedTensorTypeGet(
691+
sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
692692
if (!mlirTypeIsATensor(rankedTensor) ||
693693
!mlirTypeIsARankedTensor(rankedTensor))
694694
return 16;
@@ -889,24 +889,30 @@ int printBuiltinAttributes(MlirContext ctx) {
889889
int64_t ints64[] = {0, 1};
890890
float floats[] = {0.0f, 1.0f};
891891
double doubles[] = {0.0, 1.0};
892+
MlirAttribute encoding = mlirAttributeGetNull();
892893
MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
893-
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 2, bools);
894+
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
895+
2, bools);
894896
MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
895-
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32)), 2,
896-
uints32);
897+
mlirRankedTensorTypeGet(2, shape,
898+
mlirIntegerTypeUnsignedGet(ctx, 32), encoding),
899+
2, uints32);
897900
MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
898-
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 2,
899-
ints32);
901+
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
902+
2, ints32);
900903
MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
901-
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64)), 2,
902-
uints64);
904+
mlirRankedTensorTypeGet(2, shape,
905+
mlirIntegerTypeUnsignedGet(ctx, 64), encoding),
906+
2, uints64);
903907
MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
904-
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 2,
905-
ints64);
908+
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
909+
2, ints64);
906910
MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
907-
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 2, floats);
911+
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
912+
2, floats);
908913
MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
909-
mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 2, doubles);
914+
mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding),
915+
2, doubles);
910916

911917
if (!mlirAttributeIsADenseElements(boolElements) ||
912918
!mlirAttributeIsADenseElements(uint32Elements) ||
@@ -943,19 +949,24 @@ int printBuiltinAttributes(MlirContext ctx) {
943949
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
944950

945951
MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
946-
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 1);
952+
mlirRankedTensorTypeGet(2, shape,
953+
mlirIntegerTypeGet(ctx, 1), encoding), 1);
947954
MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
948-
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
955+
mlirRankedTensorTypeGet(2, shape,
956+
mlirIntegerTypeGet(ctx, 32), encoding), 1);
949957
MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
950-
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
958+
mlirRankedTensorTypeGet(2, shape,
959+
mlirIntegerTypeGet(ctx, 32), encoding), 1);
951960
MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
952-
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
961+
mlirRankedTensorTypeGet(2, shape,
962+
mlirIntegerTypeGet(ctx, 64), encoding), 1);
953963
MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
954-
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
964+
mlirRankedTensorTypeGet(2, shape,
965+
mlirIntegerTypeGet(ctx, 64), encoding), 1);
955966
MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
956-
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 1.0f);
967+
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
957968
MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
958-
mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 1.0);
969+
mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 1.0);
959970

960971
if (!mlirAttributeIsADenseElements(splatBool) ||
961972
!mlirDenseElementsAttrIsSplat(splatBool) ||
@@ -1024,13 +1035,14 @@ int printBuiltinAttributes(MlirContext ctx) {
10241035
int64_t indices[] = {4, 7};
10251036
int64_t two = 2;
10261037
MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
1027-
mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64)), 2,
1028-
indices);
1038+
mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), encoding),
1039+
2, indices);
10291040
MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
1030-
mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx)), 2, floats);
1041+
mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding),
1042+
2, floats);
10311043
MlirAttribute sparseAttr = mlirSparseElementsAttribute(
1032-
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), indicesAttr,
1033-
valuesAttr);
1044+
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
1045+
indicesAttr, valuesAttr);
10341046
mlirAttributeDump(sparseAttr);
10351047
// CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>
10361048

mlir/test/IR/invalid.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ func @memref_zero_stride(memref<42x42xi8, offset: ?, strides: [0, ?]>) // expect
100100

101101
// -----
102102

103+
func @tensor_encoding_mismatch(%arg0: tensor<8xi32, "enc">) -> (tensor<8xi32>) { // expected-note {{prior use here}}
104+
return %arg0: tensor<8xi32> // expected-error {{use of value '%arg0' expects different type than prior uses: 'tensor<8xi32>' vs 'tensor<8xi32, "enc">'}}
105+
}
106+
107+
// -----
108+
103109
func @bad_branch() {
104110
^bb12:
105111
br ^missing // expected-error {{reference to an undefined block}}

mlir/test/IR/parser.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ func private @vectors(vector<1 x f32>, vector<2x4xf32>)
7777
func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
7878
tensor<1x?x4x?x?xi32>, tensor<i8>)
7979

80+
// CHECK: func private @tensor_encoding(tensor<16x32xf64, "sparse">)
81+
func private @tensor_encoding(tensor<16x32xf64, "sparse">)
82+
8083
// CHECK: func private @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>)
8184
func private @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>)
8285

0 commit comments

Comments
 (0)