Skip to content

Commit 568845a

Browse files
authored
[mlir] Add a ValueSemantics trait. (#99493)
We need to distinguish ShapedTypes with and without value semantics. This is needed for downstream users to define their custom vector and tensor types that can work with the arith/math dialect. RFC https://discourse.llvm.org/t/rfc-mlir-types-with-encoding/80189
1 parent dcebe29 commit 568845a

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ struct IntegerTypeStorage;
3939
struct TupleTypeStorage;
4040
} // namespace detail
4141

42+
/// Type trait indicating that the type has value semantics.
43+
template <typename ConcreteType>
44+
class ValueSemantics
45+
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
46+
4247
//===----------------------------------------------------------------------===//
4348
// FloatType
4449
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
3030
let typeName = "builtin." # typeMnemonic;
3131
}
3232

33+
//===----------------------------------------------------------------------===//
34+
// Traits
35+
//===----------------------------------------------------------------------===//
36+
37+
/// Type trait indicating that the type has value semantics.
38+
def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
39+
let cppNamespace = "::mlir";
40+
}
41+
3342
//===----------------------------------------------------------------------===//
3443
// ComplexType
3544
//===----------------------------------------------------------------------===//
@@ -745,7 +754,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
745754
//===----------------------------------------------------------------------===//
746755

747756
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
748-
ShapedTypeInterface
757+
ShapedTypeInterface, ValueSemantics
749758
], "TensorType"> {
750759
let summary = "Multi-dimensional array with a fixed number of dimensions";
751760
let description = [{
@@ -1001,7 +1010,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
10011010
//===----------------------------------------------------------------------===//
10021011

10031012
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
1004-
ShapedTypeInterface
1013+
ShapedTypeInterface, ValueSemantics
10051014
], "TensorType"> {
10061015
let summary = "Multi-dimensional array with unknown dimensions";
10071016
let description = [{
@@ -1049,7 +1058,8 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
10491058
// VectorType
10501059
//===----------------------------------------------------------------------===//
10511060

1052-
def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Type"> {
1061+
def Builtin_Vector : Builtin_Type<"Vector", "vector",
1062+
[ShapedTypeInterface, ValueSemantics], "Type"> {
10531063
let summary = "Multi-dimensional SIMD vector type";
10541064
let description = [{
10551065
Syntax:

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def HasStaticShapePred :
8989
// Whether a type is a TupleType.
9090
def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;
9191

92+
// Whether a type has a ValueSemantics trait.
93+
def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;
94+
9295
//===----------------------------------------------------------------------===//
9396
// Type definitions
9497
//===----------------------------------------------------------------------===//
@@ -403,6 +406,11 @@ class HasRankGreaterOrEqualPred<int rank> : And<[
403406
CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
404407
]>;
405408

409+
// Container with value semantics.
410+
class ValueSemanticsContainerOf<list<Type> allowedTypes> :
411+
ShapedContainerType<allowedTypes, HasValueSemanticsPred,
412+
"container with value semantics">;
413+
406414
// Vector types.
407415

408416
class VectorOf<list<Type> allowedTypes> :
@@ -842,10 +850,18 @@ class NestedTupleOf<list<Type> allowedTypes> :
842850
// Common type constraints
843851
//===----------------------------------------------------------------------===//
844852
// Type constraint for types that are "like" some type or set of types T, that is
845-
// they're either a T, a vector of Ts, or a tensor of Ts
853+
// they're either a T, a vector of Ts, or a tensor of Ts.
846854
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
847-
allowedType.predicate, VectorOf<[allowedType]>.predicate,
848-
TensorOf<[allowedType]>.predicate]>,
855+
allowedType.predicate,
856+
ValueSemanticsContainerOf<[allowedType]>.predicate]>,
857+
name>;
858+
859+
// Type constraint for types that are "like" some type or set of types T, that is
860+
// they're either a T or a mapable container of Ts.
861+
class TypeOrValueSemanticsContainer<Type allowedType, string name>
862+
: TypeConstraint<Or<[
863+
allowedType.predicate,
864+
ValueSemanticsContainerOf<[allowedType]>.predicate]>,
849865
name>;
850866

851867
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
@@ -864,8 +880,8 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
864880

865881
// Type constraint for signless-integer-like types: signless integers, indices,
866882
// vectors of signless integers or indices, tensors of signless integers.
867-
def SignlessIntegerLike : TypeOrContainer<AnySignlessIntegerOrIndex,
868-
"signless-integer-like">;
883+
def SignlessIntegerLike : TypeOrValueSemanticsContainer<
884+
AnySignlessIntegerOrIndex, "signless-integer-like">;
869885

870886
def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
871887
AnySignlessIntegerOrIndex,

0 commit comments

Comments
 (0)