@@ -89,6 +89,9 @@ def HasStaticShapePred :
89
89
// Whether a type is a TupleType.
90
90
def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;
91
91
92
+ // Whether a type has a MappableContainer trait.
93
+ def IsMappableContainerPred : CPred<"$_self.hasTrait<MappableContainer>()">;
94
+
92
95
//===----------------------------------------------------------------------===//
93
96
// Type definitions
94
97
//===----------------------------------------------------------------------===//
@@ -403,6 +406,10 @@ class HasRankGreaterOrEqualPred<int rank> : And<[
403
406
CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
404
407
]>;
405
408
409
+ // Mappable types.
410
+ class MappableContainerOf<list<Type> allowedTypes> :
411
+ ShapedContainerType<allowedTypes, IsMappableContainerPred, "mapable container">;
412
+
406
413
// Vector types.
407
414
408
415
class VectorOf<list<Type> allowedTypes> :
@@ -842,10 +849,15 @@ class NestedTupleOf<list<Type> allowedTypes> :
842
849
// Common type constraints
843
850
//===----------------------------------------------------------------------===//
844
851
// Type constraint for types that are "like" some type or set of types T, that is
845
- // they're either a T, a vector of Ts, or a tensor of Ts
852
+ // they're either a T, a vector of Ts, or a tensor of Ts.
846
853
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
847
- allowedType.predicate, VectorOf<[allowedType]>.predicate,
848
- TensorOf<[allowedType]>.predicate]>,
854
+ allowedType.predicate, MappableContainerOf<[allowedType]>.predicate]>,
855
+ name>;
856
+
857
+ // Type constraint for types that are "like" some type or set of types T, that is
858
+ // they're either a T or a mapable container of Ts.
859
+ class TypeOrMappableContainer<Type allowedType, string name> : TypeConstraint<Or<[
860
+ allowedType.predicate, MappableContainerOf<[allowedType]>.predicate]>,
849
861
name>;
850
862
851
863
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
@@ -864,7 +876,7 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
864
876
865
877
// Type constraint for signless-integer-like types: signless integers, indices,
866
878
// vectors of signless integers or indices, tensors of signless integers.
867
- def SignlessIntegerLike : TypeOrContainer <AnySignlessIntegerOrIndex,
879
+ def SignlessIntegerLike : TypeOrMappableContainer <AnySignlessIntegerOrIndex,
868
880
"signless-integer-like">;
869
881
870
882
def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
0 commit comments