@@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
82
82
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
83
83
Tosa_QuantizedInt, AnyFloat]>;
84
84
85
+ //===----------------------------------------------------------------------===//
86
+ // TOSA Tensor Conformance
87
+ //===----------------------------------------------------------------------===//
88
+
89
+ def HasNo0Dimensions : And<[
90
+ IsRankedTensorTypePred,
91
+ CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
92
+
93
+ class TosaTensorOf<
94
+ list<Type> allowedTypes, string summary = "tosa-conformant tensor">
95
+ : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
96
+
97
+ class TosaRankedTensorOf<
98
+ list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant ranked tensor">
99
+ : RankedTensorOf<allowedTypes, !listconcat([HasNo0Dimensions], preds), summary>;
100
+
101
+ class TosaUnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant unranked tensor">
102
+ : UnrankedTensorOf<allowedTypes, preds, summary>;
103
+
104
+ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
105
+ : TosaRankedTensorOf<allowedTypes,
106
+ [HasAnyRankOfPred<ranks>],
107
+ !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
108
+
85
109
//===----------------------------------------------------------------------===//
86
110
// Tensor types
87
111
//===----------------------------------------------------------------------===//
88
112
89
- def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
90
- def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
113
+ def Tosa_I1Tensor : TosaTensorOf<[I1]>;
114
+ def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
115
+ def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;
91
116
92
- def Tosa_FloatTensor : TensorOf <[AnyFloat]>;
117
+ def Tosa_FloatTensor : TosaTensorOf <[AnyFloat]>;
93
118
94
119
// Either ranked or unranked tensor of TOSA supported element types.
95
- def Tosa_Tensor : TensorOf <[Tosa_AnyNumber]>;
120
+ def Tosa_Tensor : TosaTensorOf <[Tosa_AnyNumber]>;
96
121
97
122
// Must be ranked but no further constraints
98
- def Tosa_RankedTensor : RankedTensorOf <[Tosa_AnyNumber]>;
123
+ def Tosa_RankedTensor : TosaRankedTensorOf <[Tosa_AnyNumber]>;
99
124
100
125
// Any tensor element type allowed in Tosa ops.
101
126
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
102
127
AnyFloat.predicate]>, "tosa.dtype">;
103
128
104
129
class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
105
- AnyTypeOf<[TensorOf <allowedTypes>, NoneType], description>;
130
+ AnyTypeOf<[TosaTensorOf <allowedTypes>, NoneType], description>;
106
131
107
132
//===----------------------------------------------------------------------===//
108
133
// Tensor types with constrained ranks.
109
134
//===----------------------------------------------------------------------===//
110
135
111
136
// Rank-0 (scalar) tensor
112
- def Tosa_ScalarTensor : TensorRankOf <[Tosa_AnyNumber], [0]>;
137
+ def Tosa_ScalarTensor : TosaTensorRankOf <[Tosa_AnyNumber], [0]>;
113
138
114
139
// We include unranked tensors as a supported type for all possible tosa
115
140
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
116
141
// they should be shape propagate used Tosa's shape inference pass and verified
117
142
// to not include any remaining unranked tensors.
118
- def Tosa_UnrankedTensor : UnrankedTensorOf <[Tosa_AnyNumber]>;
143
+ def Tosa_UnrankedTensor : TosaUnrankedTensorOf <[Tosa_AnyNumber]>;
119
144
120
- def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf <[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
121
- def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf <[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
122
- def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf <[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
123
- def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf <[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
124
- def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf <[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
145
+ def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf <[Tosa_AnyNumber], [1] >], "1-d tosa-conformant tensor", "::mlir::TensorType">;
146
+ def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf <[Tosa_AnyNumber], [2] >], "2-d tosa-conformant tensor", "::mlir::TensorType">;
147
+ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf <[Tosa_AnyNumber], [3] >], "3-d tosa-conformant tensor", "::mlir::TensorType">;
148
+ def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf <[Tosa_AnyNumber], [4] >], "4-d tosa-conformant tensor", "::mlir::TensorType">;
149
+ def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf <[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
125
150
126
151
// Ranked tensors up to given rank.
127
152
def Tosa_Tensor1Dto4D : AnyTypeOf<[
128
- Tosa_UnrankedTensor, TensorRankOf <[Tosa_AnyNumber], [1,2,3,4]>]>;
153
+ Tosa_UnrankedTensor, TosaTensorRankOf <[Tosa_AnyNumber], [1,2,3,4]>]>;
129
154
def Tosa_Tensor1Dto6D : AnyTypeOf<[
130
- Tosa_UnrankedTensor, TensorRankOf <[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
155
+ Tosa_UnrankedTensor, TosaTensorRankOf <[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
131
156
132
157
def Tosa_TensorUpto4D : AnyTypeOf<[
133
- Tosa_UnrankedTensor, TensorRankOf <[Tosa_AnyNumber], [0,1,2,3,4]>]>;
158
+ Tosa_UnrankedTensor, TosaTensorRankOf <[Tosa_AnyNumber], [0,1,2,3,4]>]>;
134
159
135
160
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
136
- Tosa_UnrankedTensor, TensorRankOf <[Tosa_Int32], [0,1,2,3,4]>]>;
161
+ Tosa_UnrankedTensor, TosaTensorRankOf <[Tosa_Int32], [0,1,2,3,4]>]>;
137
162
138
163
//===----------------------------------------------------------------------===//
139
164
// Generic scalar, vector, or tensor of a particular type.
@@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
142
167
class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
143
168
AnyTypeOf<types>.predicate,
144
169
VectorOf<types>.predicate,
145
- TensorOf <types>.predicate]>,
170
+ TosaTensorOf <types>.predicate]>,
146
171
description>;
147
172
148
173
def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
0 commit comments