@@ -89,6 +89,12 @@ def HasStaticShapePred :
89
89
// Whether a type is a TupleType.
90
90
def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;
91
91
92
+ // Whether a type has a MappableContainer trait.
93
+ def IsMappableContainerPred : CPred<"$_self.hasTrait<MappableContainer>()">;
94
+
95
+ // Whether a type has a ValueSemantics trait.
96
+ def HasValueSemanticsPred : CPred<"$_self.hasTrait<ValueSemantics>()">;
97
+
92
98
//===----------------------------------------------------------------------===//
93
99
// Type definitions
94
100
//===----------------------------------------------------------------------===//
@@ -403,6 +409,12 @@ class HasRankGreaterOrEqualPred<int rank> : And<[
403
409
CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
404
410
]>;
405
411
412
+ // Mappable types with value semantics.
413
+ class ValueSemanticsMappableContainerOf<list<Type> allowedTypes> :
414
+ ShapedContainerType<allowedTypes,
415
+ And<[HasValueSemanticsPred, IsMappableContainerPred]>,
416
+ "mappable container with value semantics">;
417
+
406
418
// Vector types.
407
419
408
420
class VectorOf<list<Type> allowedTypes> :
@@ -842,10 +854,18 @@ class NestedTupleOf<list<Type> allowedTypes> :
842
854
// Common type constraints
843
855
//===----------------------------------------------------------------------===//
844
856
// Type constraint for types that are "like" some type or set of types T, that is
845
- // they're either a T, a vector of Ts, or a tensor of Ts
857
+ // they're either a T, a vector of Ts, or a tensor of Ts.
846
858
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
847
- allowedType.predicate, VectorOf<[allowedType]>.predicate,
848
- TensorOf<[allowedType]>.predicate]>,
859
+ allowedType.predicate,
860
+ ValueSemanticsMappableContainerOf<[allowedType]>.predicate]>,
861
+ name>;
862
+
863
+ // Type constraint for types that are "like" some type or set of types T, that is
864
+ // they're either a T or a mapable container of Ts.
865
+ class TypeOrValueSemanticsMappableContainer<Type allowedType, string name>
866
+ : TypeConstraint<Or<[
867
+ allowedType.predicate,
868
+ ValueSemanticsMappableContainerOf<[allowedType]>.predicate]>,
849
869
name>;
850
870
851
871
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
@@ -864,8 +884,8 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
864
884
865
885
// Type constraint for signless-integer-like types: signless integers, indices,
866
886
// vectors of signless integers or indices, tensors of signless integers.
867
- def SignlessIntegerLike : TypeOrContainer<AnySignlessIntegerOrIndex,
868
- "signless-integer-like">;
887
+ def SignlessIntegerLike : TypeOrValueSemanticsMappableContainer<
888
+ AnySignlessIntegerOrIndex, "signless-integer-like">;
869
889
870
890
def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
871
891
AnySignlessIntegerOrIndex,
0 commit comments