Skip to content

Commit 151af4a

Browse files
committed
[mlir] Add a MappableContainer trait.
This is needed for downstream users to define their custom vector and tensor types that can work with the arith/math dialect. RFC https://discourse.llvm.org/t/rfc-mlir-types-with-encoding/80189
1 parent ae2e66b commit 151af4a

File tree

3 files changed

+52
-7
lines changed

3 files changed

+52
-7
lines changed

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ struct IntegerTypeStorage;
3939
struct TupleTypeStorage;
4040
} // namespace detail
4141

42+
/// Type trait indicating that the type can be an operand to an elementwise op.
43+
template <typename ConcreteType>
44+
class MappableContainer
45+
: public TypeTrait::TraitBase<ConcreteType, MappableContainer> {};
46+
47+
/// Type trait indicating that the type has value semantics.
48+
template <typename ConcreteType>
49+
class ValueSemantics
50+
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
51+
4252
//===----------------------------------------------------------------------===//
4353
// FloatType
4454
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,20 @@ class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
3030
let typeName = "builtin." # typeMnemonic;
3131
}
3232

33+
//===----------------------------------------------------------------------===//
34+
// Traits
35+
//===----------------------------------------------------------------------===//
36+
37+
/// Type trait indicating that the type can be an operand to an elementwise op.
38+
def MappableContainer : NativeTypeTrait<"MappableContainer"> {
39+
let cppNamespace = "::mlir";
40+
}
41+
42+
/// Type trait indicating that the type has value semantics.
43+
def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
44+
let cppNamespace = "::mlir";
45+
}
46+
3347
//===----------------------------------------------------------------------===//
3448
// ComplexType
3549
//===----------------------------------------------------------------------===//
@@ -745,7 +759,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
745759
//===----------------------------------------------------------------------===//
746760

747761
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
748-
ShapedTypeInterface
762+
MappableContainer, ShapedTypeInterface, ValueSemantics
749763
], "TensorType"> {
750764
let summary = "Multi-dimensional array with a fixed number of dimensions";
751765
let description = [{
@@ -1049,7 +1063,8 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
10491063
// VectorType
10501064
//===----------------------------------------------------------------------===//
10511065

1052-
def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Type"> {
1066+
def Builtin_Vector : Builtin_Type<"Vector", "vector",
1067+
[MappableContainer, ShapedTypeInterface, ValueSemantics], "Type"> {
10531068
let summary = "Multi-dimensional SIMD vector type";
10541069
let description = [{
10551070
Syntax:

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ def HasStaticShapePred :
8989
// Whether a type is a TupleType.
9090
def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;
9191

92+
// Whether a type has a MappableContainer trait.
93+
def IsMappableContainerPred : CPred<"$_self.hasTrait<MappableContainer>()">;
94+
95+
// Whether a type has a ValueSemantics trait.
96+
def HasValueSemanticsPred : CPred<"$_self.hasTrait<ValueSemantics>()">;
97+
9298
//===----------------------------------------------------------------------===//
9399
// Type definitions
94100
//===----------------------------------------------------------------------===//
@@ -403,6 +409,12 @@ class HasRankGreaterOrEqualPred<int rank> : And<[
403409
CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
404410
]>;
405411

412+
// Mappable types with value semantics.
413+
class ValueSemanticsMappableContainerOf<list<Type> allowedTypes> :
414+
ShapedContainerType<allowedTypes,
415+
And<[HasValueSemanticsPred, IsMappableContainerPred]>,
416+
"mappable container with value semantics">;
417+
406418
// Vector types.
407419

408420
class VectorOf<list<Type> allowedTypes> :
@@ -842,10 +854,18 @@ class NestedTupleOf<list<Type> allowedTypes> :
842854
// Common type constraints
843855
//===----------------------------------------------------------------------===//
844856
// Type constraint for types that are "like" some type or set of types T, that is
845-
// they're either a T, a vector of Ts, or a tensor of Ts
857+
// they're either a T, a vector of Ts, or a tensor of Ts.
846858
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
847-
allowedType.predicate, VectorOf<[allowedType]>.predicate,
848-
TensorOf<[allowedType]>.predicate]>,
859+
allowedType.predicate,
860+
ValueSemanticsMappableContainerOf<[allowedType]>.predicate]>,
861+
name>;
862+
863+
// Type constraint for types that are "like" some type or set of types T, that is
864+
// they're either a T or a mapable container of Ts.
865+
class TypeOrValueSemanticsMappableContainer<Type allowedType, string name>
866+
: TypeConstraint<Or<[
867+
allowedType.predicate,
868+
ValueSemanticsMappableContainerOf<[allowedType]>.predicate]>,
849869
name>;
850870

851871
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
@@ -864,8 +884,8 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
864884

865885
// Type constraint for signless-integer-like types: signless integers, indices,
866886
// vectors of signless integers or indices, tensors of signless integers.
867-
def SignlessIntegerLike : TypeOrContainer<AnySignlessIntegerOrIndex,
868-
"signless-integer-like">;
887+
def SignlessIntegerLike : TypeOrValueSemanticsMappableContainer<
888+
AnySignlessIntegerOrIndex, "signless-integer-like">;
869889

870890
def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
871891
AnySignlessIntegerOrIndex,

0 commit comments

Comments
 (0)