Skip to content

Commit a79af6e

Browse files
committed
[mlir] Add a ValueSemantics trait.
We need to distinguish ShapedTypes with and without value semantics. This is needed for downstream users to define their custom vector and tensor types that can work with the arith/math dialect. RFC https://discourse.llvm.org/t/rfc-mlir-types-with-encoding/80189
1 parent dcebe29 commit a79af6e

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ struct IntegerTypeStorage;
3939
struct TupleTypeStorage;
4040
} // namespace detail
4141

42+
/// Type trait indicating that the type has value semantics.
43+
template <typename ConcreteType>
44+
class ValueSemantics
45+
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
46+
4247
//===----------------------------------------------------------------------===//
4348
// FloatType
4449
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
3030
let typeName = "builtin." # typeMnemonic;
3131
}
3232

33+
//===----------------------------------------------------------------------===//
34+
// Traits
35+
//===----------------------------------------------------------------------===//
36+
37+
/// Type trait indicating that the type has value semantics.
38+
def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
39+
let cppNamespace = "::mlir";
40+
}
41+
3342
//===----------------------------------------------------------------------===//
3443
// ComplexType
3544
//===----------------------------------------------------------------------===//
@@ -745,7 +754,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
745754
//===----------------------------------------------------------------------===//
746755

747756
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
748-
ShapedTypeInterface
757+
ShapedTypeInterface, ValueSemantics
749758
], "TensorType"> {
750759
let summary = "Multi-dimensional array with a fixed number of dimensions";
751760
let description = [{
@@ -1001,7 +1010,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
10011010
//===----------------------------------------------------------------------===//
10021011

10031012
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
1004-
ShapedTypeInterface
1013+
ShapedTypeInterface, ValueSemantics
10051014
], "TensorType"> {
10061015
let summary = "Multi-dimensional array with unknown dimensions";
10071016
let description = [{
@@ -1049,7 +1058,8 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
10491058
// VectorType
10501059
//===----------------------------------------------------------------------===//
10511060

1052-
def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Type"> {
1061+
def Builtin_Vector : Builtin_Type<"Vector", "vector",
1062+
[ShapedTypeInterface, ValueSemantics], "Type"> {
10531063
let summary = "Multi-dimensional SIMD vector type";
10541064
let description = [{
10551065
Syntax:

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def HasStaticShapePred :
8989
// Whether a type is a TupleType.
9090
def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;
9191

92+
// Whether a type has a ValueSemantics trait.
93+
def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;
94+
9295
//===----------------------------------------------------------------------===//
9396
// Type definitions
9497
//===----------------------------------------------------------------------===//
@@ -403,6 +406,11 @@ class HasRankGreaterOrEqualPred<int rank> : And<[
403406
CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
404407
]>;
405408

409+
// Container with value semantics.
410+
class ValueSemanticsContainerOf<list<Type> allowedTypes> :
411+
ShapedContainerType<allowedTypes, HasValueSemanticsPred,
412+
"container with value semantics">;
413+
406414
// Vector types.
407415

408416
class VectorOf<list<Type> allowedTypes> :
@@ -842,10 +850,18 @@ class NestedTupleOf<list<Type> allowedTypes> :
842850
// Common type constraints
843851
//===----------------------------------------------------------------------===//
844852
// Type constraint for types that are "like" some type or set of types T, that is
845-
// they're either a T, a vector of Ts, or a tensor of Ts
853+
// they're either a T, a vector of Ts, or a tensor of Ts.
846854
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
847-
allowedType.predicate, VectorOf<[allowedType]>.predicate,
848-
TensorOf<[allowedType]>.predicate]>,
855+
allowedType.predicate,
856+
ValueSemanticsContainerOf<[allowedType]>.predicate]>,
857+
name>;
858+
859+
// Type constraint for types that are "like" some type or set of types T, that is
860+
// they're either a T or a mapable container of Ts.
861+
class TypeOrValueSemanticsContainer<Type allowedType, string name>
862+
: TypeConstraint<Or<[
863+
allowedType.predicate,
864+
ValueSemanticsContainerOf<[allowedType]>.predicate]>,
849865
name>;
850866

851867
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
@@ -864,8 +880,8 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
864880

865881
// Type constraint for signless-integer-like types: signless integers, indices,
866882
// vectors of signless integers or indices, tensors of signless integers.
867-
def SignlessIntegerLike : TypeOrContainer<AnySignlessIntegerOrIndex,
868-
"signless-integer-like">;
883+
def SignlessIntegerLike : TypeOrValueSemanticsContainer<
884+
AnySignlessIntegerOrIndex, "signless-integer-like">;
869885

870886
def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
871887
AnySignlessIntegerOrIndex,

0 commit comments

Comments
 (0)