Skip to content

Commit 2674e4c

Browse files
committed
[mlir] Add a MappableContainer trait.
This is needed for downstream users to define their custom vector and tensor types that can work with the arith/math dialect. RFC https://discourse.llvm.org/t/rfc-mlir-types-with-encoding/80189
1 parent ad7aeb0 commit 2674e4c

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ struct IntegerTypeStorage;
3939
struct TupleTypeStorage;
4040
} // namespace detail
4141

42+
/// Type trait indicating that the type can be an operand to an elementwise op.
43+
template <typename ConcreteType>
44+
class MappableContainer
45+
: public TypeTrait::TraitBase<ConcreteType, MappableContainer> {};
46+
4247
//===----------------------------------------------------------------------===//
4348
// FloatType
4449
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
3030
let typeName = "builtin." # typeMnemonic;
3131
}
3232

33+
//===----------------------------------------------------------------------===//
34+
// Traits
35+
//===----------------------------------------------------------------------===//
36+
37+
/// Type trait indicating that the type can be an operand to an elementwise op.
38+
def MappableContainer : NativeTypeTrait<"MappableContainer"> {
39+
let cppNamespace = "::mlir";
40+
}
41+
3342
//===----------------------------------------------------------------------===//
3443
// ComplexType
3544
//===----------------------------------------------------------------------===//
@@ -745,7 +754,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
745754
//===----------------------------------------------------------------------===//
746755

747756
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
748-
ShapedTypeInterface
757+
MappableContainer, ShapedTypeInterface
749758
], "TensorType"> {
750759
let summary = "Multi-dimensional array with a fixed number of dimensions";
751760
let description = [{
@@ -1049,7 +1058,8 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
10491058
// VectorType
10501059
//===----------------------------------------------------------------------===//
10511060

1052-
def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Type"> {
1061+
def Builtin_Vector : Builtin_Type<"Vector", "vector",
1062+
[MappableContainer, ShapedTypeInterface], "Type"> {
10531063
let summary = "Multi-dimensional SIMD vector type";
10541064
let description = [{
10551065
Syntax:

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def HasStaticShapePred :
8989
// Whether a type is a TupleType.
9090
def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;
9191

92+
// Whether a type has a MappableContainer trait.
93+
def IsMappableContainerPred : CPred<"$_self.hasTrait<MappableContainer>()">;
94+
9295
//===----------------------------------------------------------------------===//
9396
// Type definitions
9497
//===----------------------------------------------------------------------===//
@@ -403,6 +406,10 @@ class HasRankGreaterOrEqualPred<int rank> : And<[
403406
CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
404407
]>;
405408

409+
// Mappable types.
410+
class MappableContainerOf<list<Type> allowedTypes> :
411+
ShapedContainerType<allowedTypes, IsMappableContainerPred, "mapable container">;
412+
406413
// Vector types.
407414

408415
class VectorOf<list<Type> allowedTypes> :
@@ -842,10 +849,15 @@ class NestedTupleOf<list<Type> allowedTypes> :
842849
// Common type constraints
843850
//===----------------------------------------------------------------------===//
844851
// Type constraint for types that are "like" some type or set of types T, that is
845-
// they're either a T, a vector of Ts, or a tensor of Ts
852+
// they're either a T, a vector of Ts, or a tensor of Ts.
846853
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
847-
allowedType.predicate, VectorOf<[allowedType]>.predicate,
848-
TensorOf<[allowedType]>.predicate]>,
854+
allowedType.predicate, MappableContainerOf<[allowedType]>.predicate]>,
855+
name>;
856+
857+
// Type constraint for types that are "like" some type or set of types T, that is
858+
// they're either a T or a mapable container of Ts.
859+
class TypeOrMappableContainer<Type allowedType, string name> : TypeConstraint<Or<[
860+
allowedType.predicate, MappableContainerOf<[allowedType]>.predicate]>,
849861
name>;
850862

851863
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
@@ -864,7 +876,7 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
864876

865877
// Type constraint for signless-integer-like types: signless integers, indices,
866878
// vectors of signless integers or indices, tensors of signless integers.
867-
def SignlessIntegerLike : TypeOrContainer<AnySignlessIntegerOrIndex,
879+
def SignlessIntegerLike : TypeOrMappableContainer<AnySignlessIntegerOrIndex,
868880
"signless-integer-like">;
869881

870882
def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<

0 commit comments

Comments
 (0)